mirror of https://github.com/Mai-with-u/MaiBot.git
161 lines
5.9 KiB
Python
161 lines
5.9 KiB
Python
"""
|
||
工具注册系统
|
||
提供统一的工具注册和管理接口
|
||
"""
|
||
|
||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||
from src.common.logger import get_logger
|
||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||
|
||
# Module-level logger used by the registry below to report tool registration.
logger = get_logger("memory_retrieval_tools")
|
||
|
||
|
||
class MemoryRetrievalTool:
    """Base class for memory-retrieval tools.

    A tool bundles a name, a description, a list of parameter specs (plain
    dicts) and an async callable. It can render itself as prompt text
    (``get_tool_description``) or as a function-calling definition
    (``get_tool_definition``).
    """

    # Value assumed when a parameter spec omits "required". Both renderings read
    # it through ``_is_required`` so the prompt text and the function-calling
    # definition can no longer disagree about the same parameter (previously the
    # prompt defaulted to "required" while the definition defaulted to optional).
    # ``False`` keeps what ``get_tool_definition`` already sent to the LLM.
    DEFAULT_REQUIRED = False

    def __init__(
        self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
    ):
        """Initialize the tool.

        Args:
            name: Tool name.
            description: Tool description.
            parameters: Parameter spec list, e.g.
                ``[{"name": "param_name", "type": "string", "description": "...", "required": True}]``.
                An optional ``"enum"`` key lists allowed values; a missing
                ``"required"`` key falls back to ``DEFAULT_REQUIRED``.
            execute_func: Callable that performs the work. It must be async:
                ``execute`` awaits its return value.
        """
        self.name = name
        self.description = description
        self.parameters = parameters
        self.execute_func = execute_func

    def _is_required(self, param: Dict[str, Any]) -> Any:
        """Return the "required" flag of *param*, applying ``DEFAULT_REQUIRED``."""
        return param.get("required", self.DEFAULT_REQUIRED)

    def get_tool_description(self) -> str:
        """Render the tool as plain text for use in a prompt.

        The first line is ``name(p1, p2): description``; it is followed by one
        `` - name (type, 必填|可选): description`` line per parameter, or by
        `` 无参数`` when the tool has no parameters.
        """
        # Tolerant name lookup shared by the header and the per-parameter lines;
        # the header previously used p['name'] and raised KeyError for specs
        # without a name.
        names = [param.get("name", "") for param in self.parameters]

        lines = []
        for param_name, param in zip(names, self.parameters):
            param_type = param.get("type", "string")
            param_desc = param.get("description", "")
            required_str = "必填" if self._is_required(param) else "可选"
            lines.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")

        params_str = "\n".join(lines) if lines else " 无参数"
        return f"{self.name}({', '.join(names)}): {self.description}\n{params_str}"

    async def execute(self, **kwargs) -> str:
        """Run the tool by awaiting ``execute_func`` with the given keyword arguments."""
        return await self.execute_func(**kwargs)

    def get_tool_definition(self) -> Dict[str, Any]:
        """Build the tool definition used for LLM function calling.

        Returns:
            Dict[str, Any]: ``{"name": str, "description": str, "parameters": List[Tuple]}``,
            intended to match BaseTool's format. Each parameter tuple is
            ``(param_name, ToolParamType, description, required, enum_values)``.
        """
        # Case-insensitive type-name lookup, built once per call instead of once
        # per parameter. "str" and the JSON-Schema name "number" are accepted as
        # aliases of STRING and FLOAT respectively.
        type_mapping = {
            "string": ToolParamType.STRING,
            "str": ToolParamType.STRING,
            "integer": ToolParamType.INTEGER,
            "int": ToolParamType.INTEGER,
            "float": ToolParamType.FLOAT,
            "number": ToolParamType.FLOAT,
            "boolean": ToolParamType.BOOLEAN,
            "bool": ToolParamType.BOOLEAN,
        }

        param_tuples = []
        for param in self.parameters:
            param_name = param.get("name", "")
            # ``or "string"`` also covers an explicit "type": None, which used to
            # crash on ``.lower()``.
            param_type_str = str(param.get("type") or "string").strip().lower()
            param_type = type_mapping.get(param_type_str)
            if param_type is None:
                # Unknown type names still degrade to STRING, but no longer silently.
                logger.warning(
                    f"记忆检索工具 {self.name} 的参数 {param_name} 类型 {param_type_str!r} 未知,按 string 处理"
                )
                param_type = ToolParamType.STRING

            param_tuples.append(
                (
                    param_name,
                    param_type,
                    param.get("description", ""),
                    self._is_required(param),
                    param.get("enum", None),
                )
            )

        # Same shape as BaseTool.get_tool_definition().
        return {
            "name": self.name,
            "description": self.description,
            "parameters": param_tuples,
        }
|
||
|
||
|
||
class MemoryRetrievalToolRegistry:
    """Registry holding memory-retrieval tools, keyed by tool name."""

    def __init__(self):
        self.tools: Dict[str, MemoryRetrievalTool] = {}

    def register_tool(self, tool: MemoryRetrievalTool) -> None:
        """Add *tool* to the registry.

        The first tool registered under a name wins; a later tool with the same
        name is ignored and only a debug message is logged.
        """
        if tool.name not in self.tools:
            self.tools[tool.name] = tool
            logger.info(f"注册记忆检索工具: {tool.name}")
            return
        logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")

    def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
        """Return the tool registered under *name*, or None if there is none."""
        try:
            return self.tools[name]
        except KeyError:
            return None

    def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
        """Return a shallow copy of the name -> tool mapping."""
        return dict(self.tools)

    def get_tools_description(self) -> str:
        """Return a numbered, newline-separated list of tool descriptions for a prompt."""
        return "\n".join(
            f"{index}. {tool.get_tool_description()}"
            for index, tool in enumerate(self.tools.values(), start=1)
        )

    def get_action_types_list(self) -> str:
        """Return all action types, quoted and joined by " 或 ", for a prompt.

        Deprecated; kept for backward compatibility. The two terminal actions
        "final_answer" and "no_answer" are always appended after the tool names.
        """
        names = [tool.name for tool in self.tools.values()]
        names += ["final_answer", "no_answer"]
        quoted = [f'"{action}"' for action in names]
        return " 或 ".join(quoted)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """Return every registered tool's function-calling definition.

        Returns:
            List[Dict[str, Any]]: One definition dict per tool, in registration order.
        """
        definitions: List[Dict[str, Any]] = []
        for tool in self.tools.values():
            definitions.append(tool.get_tool_definition())
        return definitions
|
||
|
||
|
||
# 全局工具注册器实例
|
||
# Module-level registry instance: register_memory_retrieval_tool() adds tools to
# it and get_tool_registry() returns it.
_tool_registry = MemoryRetrievalToolRegistry()
|
||
|
||
|
||
def register_memory_retrieval_tool(
    name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
    """Build a MemoryRetrievalTool and add it to the module-level registry.

    Args:
        name: Tool name.
        description: Tool description.
        parameters: Parameter spec list (same format as MemoryRetrievalTool takes).
        execute_func: Async callable that performs the work.
    """
    _tool_registry.register_tool(
        MemoryRetrievalTool(
            name=name,
            description=description,
            parameters=parameters,
            execute_func=execute_func,
        )
    )
|
||
|
||
|
||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
    """Return the module-level registry instance (the same object on every call)."""
    registry = _tool_registry
    return registry
|