为工具调用加入了一点点可拓展性(可选用)

pull/1001/head
SnowindMe 2025-04-19 18:37:48 +08:00
parent b1dc34f7b1
commit 3f293a2010
6 changed files with 91 additions and 19 deletions

Binary file not shown.

View File

@ -1,4 +1,5 @@
from typing import Dict, List, Any, Optional, Type from typing import Dict, List, Any, Optional, Type
from abc import ABC, abstractmethod
import inspect import inspect
import importlib import importlib
import pkgutil import pkgutil
@ -11,7 +12,7 @@ logger = get_module_logger("base_tool")
TOOL_REGISTRY = {} TOOL_REGISTRY = {}
class BaseTool: class BaseTool(ABC):
"""所有工具的基类""" """所有工具的基类"""
# 工具名称,子类必须重写 # 工具名称,子类必须重写
@ -36,6 +37,7 @@ class BaseTool:
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
} }
@abstractmethod
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具函数 """执行工具函数
@ -111,3 +113,17 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
if not tool_class: if not tool_class:
return None return None
return tool_class() return tool_class()
def run_lua_code(lua_code: str):
"""兼容Lua代码运行小工具
Args:
lua_code (str): Lua代码
Returns:
_LuaTable: Lua运行时的全局变量
"""
from lupa import LuaRuntime
lua = LuaRuntime(unpack_returned_tuples=True)
lua.execute(lua_code)
return lua.globals()

View File

@ -1,4 +1,4 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool,run_lua_code
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Dict, Any
@ -32,13 +32,13 @@ class CompareNumbersTool(BaseTool):
try: try:
num1 = function_args.get("num1") num1 = function_args.get("num1")
num2 = function_args.get("num2") num2 = function_args.get("num2")
lua_code = """
if num1 > num2: function CompareNumbers(a, b)
result = f"{num1} 大于 {num2}" return a .. (a > b and " 大于 " or a < b and " 小于 " or " 等于 ") .. b
elif num1 < num2: end
result = f"{num1} 小于 {num2}" """
else: CompareNumbers = run_lua_code(lua_code).CompareNumbers
result = f"{num1} 等于 {num2}" result = CompareNumbers(num1, num2)
return {"name": self.name, "content": result} return {"name": self.name, "content": result}
except Exception as e: except Exception as e:

View File

@ -1,7 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool,run_lua_code
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Dict, Any
from datetime import datetime
logger = get_module_logger("get_time_date") logger = get_module_logger("get_time_date")
@ -27,12 +26,13 @@ class GetCurrentDateTimeTool(BaseTool):
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
current_time = datetime.now().strftime("%H:%M:%S") lua_code = """
current_date = datetime.now().strftime("%Y-%m-%d") GetCurrentDateTime = function()
current_year = datetime.now().strftime("%Y") return ("当前时间: %s, 日期: %s, 年份: %s, 星期: %s"):format(os.date("%H:%M:%S"), os.date("%Y-%m-%d"), os.date("%Y"), os.date("%A"))
current_weekday = datetime.now().strftime("%A") end
"""
GetCurrentDateTime = run_lua_code(lua_code).GetCurrentDateTime
return { return {
"name": "get_current_date_time", "name": "get_current_date_time",
"content": f"当前时间: {current_time}, 日期: {current_date}, 年份: {current_year}, 星期: {current_weekday}", "content": GetCurrentDateTime(),
} }

View File

@ -0,0 +1,56 @@
import re
from src.do_tool.tool_can_use.base_tool import BaseTool, run_lua_code
from src.config.config import global_config
from src.plugins.models.utils_model import LLMRequest
from src.common.logger import get_module_logger
from typing import Dict, Any
logger = get_module_logger("letter_count_tool")
class LetterCountTool(BaseTool):
"""数单词内某字母个数的工具"""
name = "word_letter_count"
description = "当有人询问你或者提到某个英文单词内有多少个某字母时,可以使用这个工具来数字母(如果传入的是中文,传入之前要将中文转为英文)"
parameters = {
"type": "object",
"properties": {
"word": {"type": "string", "description": "英文单词"},
"letter": {"type": "string", "description": "英文字母"},
},
"required": ["word", "letter"],
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""
执行数数该单词的某字母个数的函数
Args:
function_args: 工具参数
message_txt: 原始消息文本
Returns:
Dict: 工具执行结果
"""
try:
word = function_args.get("word")
letter = function_args.get("letter")
if re.match(r"^[a-zA-Z]+$", letter) is None:
raise ValueError("请输入英文字母")
lua_code = """
function LetterCount(inputStr, targetLetter)
local lower = (inputStr:gsub("[^"..targetLetter:lower().."]", "")):len()
local upper = (inputStr:gsub("[^"..targetLetter:upper().."]", "")):len()
return string.format("字母 %s 在字符串 %s 中出现的次数:%d个(小写), %d个(大写)", targetLetter, inputStr, lower, upper)
end
"""
LetterCount = run_lua_code(lua_code).LetterCount
return {"name": self.name, "content": LetterCount(word, letter)}
except Exception as e:
logger.error(f"数字母失败: {str(e)}")
return {"name": self.name, "content": f"数字母失败: {str(e)}"}
# 注册工具
# register_tool(LetterCountTool)

View File

@ -61,7 +61,7 @@ class ToolUser:
prompt += message_txt prompt += message_txt
# prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" # prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}是你的名字。根据之前的聊天记录补充问题信息,搜索时避开你的名字。\n" prompt += f"注意你就是{bot_name}{bot_name}是你的名字。根据之前的聊天记录补充问题信息,搜索时避开你的名字。\n"
prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么" prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么,如果是数字母的话,含中文词的部分要先转换成英文单词再数字母"
return prompt return prompt
@staticmethod @staticmethod