为工具调用加入了一点点可拓展性(可选用)

pull/1001/head
SnowindMe 2025-04-19 18:37:48 +08:00
parent b1dc34f7b1
commit 3f293a2010
6 changed files with 91 additions and 19 deletions

Binary file not shown.

View File

@ -1,4 +1,5 @@
from typing import Dict, List, Any, Optional, Type
from abc import ABC, abstractmethod
import inspect
import importlib
import pkgutil
@ -11,7 +12,7 @@ logger = get_module_logger("base_tool")
TOOL_REGISTRY = {}
class BaseTool:
class BaseTool(ABC):
"""所有工具的基类"""
# 工具名称,子类必须重写
@ -36,6 +37,7 @@ class BaseTool:
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
@abstractmethod
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具函数
@ -111,3 +113,17 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
if not tool_class:
return None
return tool_class()
def run_lua_code(lua_code: str):
"""兼容Lua代码运行小工具
Args:
lua_code (str): Lua代码
Returns:
_LuaTable: Lua运行时的全局变量
"""
from lupa import LuaRuntime
lua = LuaRuntime(unpack_returned_tuples=True)
lua.execute(lua_code)
return lua.globals()

View File

@ -1,4 +1,4 @@
from src.do_tool.tool_can_use.base_tool import BaseTool
from src.do_tool.tool_can_use.base_tool import BaseTool,run_lua_code
from src.common.logger import get_module_logger
from typing import Dict, Any
@ -32,13 +32,13 @@ class CompareNumbersTool(BaseTool):
try:
num1 = function_args.get("num1")
num2 = function_args.get("num2")
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:
result = f"{num1} 小于 {num2}"
else:
result = f"{num1} 等于 {num2}"
lua_code = """
function CompareNumbers(a, b)
return a .. (a > b and " 大于 " or a < b and " 小于 " or " 等于 ") .. b
end
"""
CompareNumbers = run_lua_code(lua_code).CompareNumbers
result = CompareNumbers(num1, num2)
return {"name": self.name, "content": result}
except Exception as e:

View File

@ -1,7 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool
from src.do_tool.tool_can_use.base_tool import BaseTool,run_lua_code
from src.common.logger import get_module_logger
from typing import Dict, Any
from datetime import datetime
logger = get_module_logger("get_time_date")
@ -27,12 +26,13 @@ class GetCurrentDateTimeTool(BaseTool):
Returns:
Dict: 工具执行结果
"""
current_time = datetime.now().strftime("%H:%M:%S")
current_date = datetime.now().strftime("%Y-%m-%d")
current_year = datetime.now().strftime("%Y")
current_weekday = datetime.now().strftime("%A")
lua_code = """
GetCurrentDateTime = function()
return ("当前时间: %s, 日期: %s, 年份: %s, 星期: %s"):format(os.date("%H:%M:%S"), os.date("%Y-%m-%d"), os.date("%Y"), os.date("%A"))
end
"""
GetCurrentDateTime = run_lua_code(lua_code).GetCurrentDateTime
return {
"name": "get_current_date_time",
"content": f"当前时间: {current_time}, 日期: {current_date}, 年份: {current_year}, 星期: {current_weekday}",
"content": GetCurrentDateTime(),
}

View File

@ -0,0 +1,56 @@
import re
from src.do_tool.tool_can_use.base_tool import BaseTool, run_lua_code
from src.config.config import global_config
from src.plugins.models.utils_model import LLMRequest
from src.common.logger import get_module_logger
from typing import Dict, Any
logger = get_module_logger("letter_count_tool")
class LetterCountTool(BaseTool):
"""数单词内某字母个数的工具"""
name = "word_letter_count"
description = "当有人询问你或者提到某个英文单词内有多少个某字母时,可以使用这个工具来数字母(如果传入的是中文,传入之前要将中文转为英文)"
parameters = {
"type": "object",
"properties": {
"word": {"type": "string", "description": "英文单词"},
"letter": {"type": "string", "description": "英文字母"},
},
"required": ["word", "letter"],
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""
执行数数该单词的某字母个数的函数
Args:
function_args: 工具参数
message_txt: 原始消息文本
Returns:
Dict: 工具执行结果
"""
try:
word = function_args.get("word")
letter = function_args.get("letter")
if re.match(r"^[a-zA-Z]+$", letter) is None:
raise ValueError("请输入英文字母")
lua_code = """
function LetterCount(inputStr, targetLetter)
local lower = (inputStr:gsub("[^"..targetLetter:lower().."]", "")):len()
local upper = (inputStr:gsub("[^"..targetLetter:upper().."]", "")):len()
return string.format("字母 %s 在字符串 %s 中出现的次数:%d个(小写), %d个(大写)", targetLetter, inputStr, lower, upper)
end
"""
LetterCount = run_lua_code(lua_code).LetterCount
return {"name": self.name, "content": LetterCount(word, letter)}
except Exception as e:
logger.error(f"数字母失败: {str(e)}")
return {"name": self.name, "content": f"数字母失败: {str(e)}"}
# 注册工具
# register_tool(LetterCountTool)

View File

@ -61,7 +61,7 @@ class ToolUser:
prompt += message_txt
# prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}是你的名字。根据之前的聊天记录补充问题信息,搜索时避开你的名字。\n"
prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么"
prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么,如果是数字母的话,含中文词的部分要先转换成英文单词再数字母"
return prompt
@staticmethod