mirror of https://github.com/Mai-with-u/MaiBot.git
增加缓存检查
parent
8cdb0238f8
commit
2181de7a5f
|
|
@ -105,6 +105,34 @@ class ToolHistoryManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录工具调用时发生错误: {e}")
|
logger.error(f"记录工具调用时发生错误: {e}")
|
||||||
|
|
||||||
|
def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""查找匹配的缓存记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
args: 工具调用参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果,否则返回None
|
||||||
|
"""
|
||||||
|
# 检查是否启用历史记录
|
||||||
|
if not global_config.tool.history.enable_history:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 清理输入参数中的敏感信息以便比较
|
||||||
|
sanitized_input_args = self._sanitize_args(args)
|
||||||
|
|
||||||
|
# 按时间倒序遍历历史记录
|
||||||
|
for record in reversed(self._history):
|
||||||
|
if (record["tool_name"] == tool_name and
|
||||||
|
record["status"] == "completed" and
|
||||||
|
record["ttl_count"] < record.get("ttl", 5)):
|
||||||
|
# 比较参数是否匹配
|
||||||
|
if self._sanitize_args(record["arguments"]) == sanitized_input_args:
|
||||||
|
logger.info(f"工具 {tool_name} 命中缓存记录")
|
||||||
|
return record["result"]
|
||||||
|
return None
|
||||||
|
|
||||||
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
|
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""清理参数中的敏感信息"""
|
"""清理参数中的敏感信息"""
|
||||||
sensitive_keys = ['api_key', 'token', 'password', 'secret']
|
sensitive_keys = ['api_key', 'token', 'password', 'secret']
|
||||||
|
|
@ -297,6 +325,12 @@ def wrap_tool_executor():
|
||||||
|
|
||||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 首先检查缓存
|
||||||
|
if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args):
|
||||||
|
logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||||
|
return cached_result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await original_execute(self, tool_call, tool_instance)
|
result = await original_execute(self, tool_call, tool_instance)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue