MaiBot/src/memory_system/retrieval_tools/tool_registry.py

161 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
工具注册系统
提供统一的工具注册和管理接口
"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType
# Module-level logger; used by MemoryRetrievalToolRegistry.register_tool to report registrations.
logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
    """A single memory-retrieval tool: its metadata plus the async callable that runs it.

    Each parameter spec is a plain dict, e.g.::

        {"name": "keyword", "type": "string", "description": "...", "required": True}

    Missing keys fall back to: ``name`` -> ``""``, ``type`` -> ``"string"``,
    ``description`` -> ``""``, ``required`` -> ``_DEFAULT_REQUIRED``. An optional
    ``enum`` key is forwarded only by :meth:`get_tool_definition`.
    """

    # Default applied when a spec omits "required". Shared by the prompt text and the
    # function-calling definition so the two can no longer disagree (they used to
    # default to True and False respectively).
    _DEFAULT_REQUIRED = False

    def __init__(
        self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
    ):
        """Create a tool.

        Args:
            name: Tool name exposed to the LLM.
            description: Description of what the tool does.
            parameters: List of parameter spec dicts (see class docstring); ``None`` or
                an empty list means the tool takes no parameters.
            execute_func: Async callable awaited by :meth:`execute` with keyword arguments.
        """
        self.name = name
        self.description = description
        # Take a private copy so later mutation of the caller's list cannot change this tool.
        self.parameters = list(parameters) if parameters else []
        self.execute_func = execute_func

    def get_tool_description(self) -> str:
        """Render the tool as plain text for a prompt.

        Returns:
            ``"name(p1, p2): description"`` followed by one line per parameter, or a
            placeholder line when the tool has no parameters.
        """
        names = []
        param_lines = []
        for param in self.parameters:
            param_name = param.get("name", "")
            param_type = param.get("type", "string")
            param_desc = param.get("description", "")
            required = param.get("required", self._DEFAULT_REQUIRED)
            required_str = "必填" if required else "可选"
            names.append(param_name)
            param_lines.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
        params_str = "\n".join(param_lines) if param_lines else " 无参数"
        # Reuse the names gathered above; indexing p["name"] here raised KeyError for
        # specs without a "name" key even though the loop tolerated them.
        return f"{self.name}({', '.join(names)}): {self.description}\n{params_str}"

    async def execute(self, **kwargs) -> str:
        """Await the wrapped ``execute_func`` with the given keyword arguments and return its result."""
        return await self.execute_func(**kwargs)

    def get_tool_definition(self) -> Dict[str, Any]:
        """Build the function-calling definition in the same shape as ``BaseTool.get_tool_definition()``.

        Returns:
            ``{"name": str, "description": str, "parameters": [(name, ToolParamType,
            description, required, enum_values), ...]}``. Type strings are matched
            case-insensitively; unknown ones fall back to ``ToolParamType.STRING``.
        """
        # Built once per call instead of once per parameter.
        type_mapping = {
            "string": ToolParamType.STRING,
            "str": ToolParamType.STRING,
            "integer": ToolParamType.INTEGER,
            "int": ToolParamType.INTEGER,
            "float": ToolParamType.FLOAT,
            "number": ToolParamType.FLOAT,  # JSON-Schema spelling of a floating-point value
            "boolean": ToolParamType.BOOLEAN,
            "bool": ToolParamType.BOOLEAN,
        }
        param_tuples = []
        for param in self.parameters:
            # `or "string"` also covers an explicit None/"" type, which used to crash on .lower().
            type_str = str(param.get("type") or "string").lower()
            param_tuples.append(
                (
                    param.get("name", ""),
                    type_mapping.get(type_str, ToolParamType.STRING),
                    param.get("description", ""),
                    param.get("required", self._DEFAULT_REQUIRED),
                    param.get("enum", None),
                )
            )
        return {
            "name": self.name,
            "description": self.description,
            "parameters": param_tuples,
        }
class MemoryRetrievalToolRegistry:
    """Name-keyed registry of MemoryRetrievalTool instances."""

    def __init__(self):
        # dict preserves insertion order, so listings below follow registration order.
        self.tools: Dict[str, "MemoryRetrievalTool"] = {}

    def register_tool(self, tool: "MemoryRetrievalTool") -> None:
        """Register ``tool`` under ``tool.name``.

        If the name is already taken, the new tool is ignored (logged at debug
        level) and the first registration is kept.
        """
        if tool.name in self.tools:
            logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
            return
        self.tools[tool.name] = tool
        logger.info(f"注册记忆检索工具: {tool.name}")

    def get_tool(self, name: str) -> Optional["MemoryRetrievalTool"]:
        """Return the tool registered under ``name``, or ``None`` if there is none."""
        return self.tools.get(name)

    def get_all_tools(self) -> Dict[str, "MemoryRetrievalTool"]:
        """Return a shallow copy of the name -> tool mapping (safe for callers to mutate)."""
        return self.tools.copy()

    def get_tools_description(self) -> str:
        """Return every tool's prompt description as a 1-based numbered list, one entry per tool."""
        return "\n".join(
            f"{i}. {tool.get_tool_description()}" for i, tool in enumerate(self.tools.values(), 1)
        )

    def get_action_types_list(self) -> str:
        """Deprecated; kept for compatibility. Quoted action types for a prompt.

        Returns:
            Every registered tool name followed by ``final_answer`` and ``no_answer``,
            each double-quoted and comma-separated, e.g. ``"search", "final_answer", "no_answer"``.
        """
        action_types = [tool.name for tool in self.tools.values()]
        action_types += ["final_answer", "no_answer"]
        # The previous "".join produced '"a""b"' with no delimiter between entries.
        return ", ".join(f'"{at}"' for at in action_types)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """Return each registered tool's function-calling definition, in registration order.

        Returns:
            List[Dict[str, Any]]: one ``get_tool_definition()`` dict per tool.
        """
        return [tool.get_tool_definition() for tool in self.tools.values()]
# Global tool-registry instance, created once at import time and returned by get_tool_registry().
_tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
    name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
    """Wrap the arguments in a MemoryRetrievalTool and add it to the global registry.

    Registration goes through MemoryRetrievalToolRegistry.register_tool, which skips
    names that are already registered.

    Args:
        name: Tool name.
        description: Tool description.
        parameters: List of parameter spec dicts.
        execute_func: Async callable that performs the tool's work.
    """
    _tool_registry.register_tool(
        MemoryRetrievalTool(
            name=name,
            description=description,
            parameters=parameters,
            execute_func=execute_func,
        )
    )
def get_tool_registry() -> MemoryRetrievalToolRegistry:
    """Return the shared module-level registry instance (the same object on every call)."""
    registry = _tool_registry
    return registry