MaiBot/src/tools/tool_can_use/base_tool.py

116 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from typing import List, Any, Optional, Type
import inspect
import importlib
import pkgutil
import os
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("base_tool")
# 工具注册表
TOOL_REGISTRY = {}
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表
Args:
tool_class: 工具类
"""
if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name
if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册: {tool_name}")
def discover_tools():
"""自动发现并注册tool_can_use目录下的所有工具"""
# 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir)
# 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"):
continue
# 导入模块
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
# 查找模块中的工具类
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[dict[str, Any]]:
"""获取所有已注册工具的定义
Returns:
List[dict]: 工具定义列表
"""
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例
Args:
tool_name: 工具名称
Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None
"""
tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class:
return None
return tool_class()