mirror of https://github.com/Mai-with-u/MaiBot.git
优化工具历史记录
parent
17d8bac504
commit
2622d3de3e
|
|
@ -26,7 +26,7 @@ def get_tool_history_prompt(message_id: Optional[str] = None) -> str:
|
|||
格式化的工具历史提示词
|
||||
"""
|
||||
return tool_history_manager.get_recent_history_prompt(
|
||||
session_id=message_id
|
||||
chat_id=message_id
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
import asyncio
|
||||
|
||||
from .logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("tool_history")
|
||||
|
||||
|
|
@ -52,8 +53,12 @@ class ToolHistoryManager:
|
|||
result: Any,
|
||||
execution_time: float,
|
||||
status: str,
|
||||
session_id: Optional[str] = None):
|
||||
chat_id: Optional[str] = None):
|
||||
"""记录工具调用"""
|
||||
# 检查是否启用历史记录
|
||||
if not global_config.tool.history.enable_history:
|
||||
return
|
||||
|
||||
try:
|
||||
# 创建记录
|
||||
record = {
|
||||
|
|
@ -63,7 +68,7 @@ class ToolHistoryManager:
|
|||
"result": self._sanitize_result(result),
|
||||
"execution_time": execution_time,
|
||||
"status": status,
|
||||
"session_id": session_id
|
||||
"chat_id": chat_id
|
||||
}
|
||||
|
||||
# 添加到内存中的历史记录
|
||||
|
|
@ -128,7 +133,7 @@ class ToolHistoryManager:
|
|||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -138,7 +143,7 @@ class ToolHistoryManager:
|
|||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
session_id: 会话ID,用于筛选特定会话的调用
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
|
|
@ -177,11 +182,11 @@ class ToolHistoryManager:
|
|||
if datetime.fromisoformat(record["timestamp"]) <= end_dt
|
||||
]
|
||||
|
||||
# 按会话ID筛选
|
||||
if session_id:
|
||||
# 按聊天ID筛选
|
||||
if chat_id:
|
||||
filtered_history = [
|
||||
record for record in filtered_history
|
||||
if record.get("session_id") == session_id
|
||||
if record.get("chat_id") == chat_id
|
||||
]
|
||||
|
||||
# 按状态筛选
|
||||
|
|
@ -198,35 +203,53 @@ class ToolHistoryManager:
|
|||
return filtered_history
|
||||
|
||||
def get_recent_history_prompt(self,
|
||||
limit: int = 5,
|
||||
session_id: Optional[str] = None) -> str:
|
||||
limit: Optional[int] = None,
|
||||
chat_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
获取最近工具调用历史的提示词
|
||||
|
||||
Args:
|
||||
limit: 返回的历史记录数量
|
||||
session_id: 会话ID,用于只获取当前会话的历史
|
||||
limit: 返回的历史记录数量,如果不提供则使用配置中的max_history
|
||||
chat_id: 会话ID,用于只获取当前会话的历史
|
||||
|
||||
Returns:
|
||||
格式化的历史记录提示词
|
||||
"""
|
||||
# 检查是否启用历史记录
|
||||
if not global_config.tool.history.enable_history:
|
||||
return ""
|
||||
|
||||
# 使用配置中的最大历史记录数
|
||||
if limit is None:
|
||||
limit = global_config.tool.history.max_history
|
||||
|
||||
recent_history = self.query_history(
|
||||
session_id=session_id,
|
||||
chat_id=chat_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if not recent_history:
|
||||
return ""
|
||||
|
||||
prompt = "\n最近的工具调用历史:\n"
|
||||
prompt = "\n工具执行历史:\n"
|
||||
for record in recent_history:
|
||||
status = "成功" if record["status"] == "completed" else "失败"
|
||||
timestamp = datetime.fromisoformat(record["timestamp"]).strftime("%H:%M:%S")
|
||||
prompt += (
|
||||
f"- [{timestamp}] {record['tool_name']} ({status})\n"
|
||||
f" 参数: {json.dumps(record['arguments'], ensure_ascii=False)}\n"
|
||||
f" 结果: {str(record['result'])[:200]}...\n"
|
||||
)
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容,去除多余空白和换行
|
||||
content = content.strip().replace('\n', ' ')
|
||||
|
||||
# 如果内容太长则截断
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
prompt += f"{name}: \n{content}\n\n"
|
||||
|
||||
return prompt
|
||||
|
||||
|
|
@ -254,11 +277,11 @@ def wrap_tool_executor():
|
|||
# 记录成功的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.arguments,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
session_id=getattr(self, 'session_id', None)
|
||||
chat_id=getattr(self, 'chat_id', None)
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -268,11 +291,11 @@ def wrap_tool_executor():
|
|||
# 记录失败的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.arguments,
|
||||
args=tool_call.args,
|
||||
result=str(e),
|
||||
execution_time=execution_time,
|
||||
status="error",
|
||||
session_id=getattr(self, 'session_id', None)
|
||||
chat_id=getattr(self, 'chat_id', None)
|
||||
)
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -283,6 +283,20 @@ class ExpressionConfig(ConfigBase):
|
|||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolHistoryConfig(ConfigBase):
|
||||
"""工具历史记录配置类"""
|
||||
|
||||
enable_history: bool = True
|
||||
"""是否启用工具历史记录"""
|
||||
|
||||
max_history: int = 100
|
||||
"""历史记录最大保存数量"""
|
||||
|
||||
data_dir: str = "data/tool_history"
|
||||
"""历史记录保存目录"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolConfig(ConfigBase):
|
||||
"""工具配置类"""
|
||||
|
|
@ -290,6 +304,9 @@ class ToolConfig(ConfigBase):
|
|||
enable_tool: bool = False
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
history: ToolHistoryConfig = field(default_factory=ToolHistoryConfig)
|
||||
"""工具历史记录配置"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from datetime import datetime
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.tool_history import ToolHistoryManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
|
@ -32,3 +33,110 @@ def get_llm_available_tool_definitions():
|
|||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
|
||||
|
||||
def get_tool_history(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用历史记录
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
List[Dict]: 工具调用记录列表,每条记录包含以下字段:
|
||||
- tool_name: 工具名称
|
||||
- timestamp: 调用时间
|
||||
- arguments: 调用参数
|
||||
- result: 调用结果
|
||||
- execution_time: 执行时间
|
||||
- status: 执行状态
|
||||
- chat_id: 会话ID
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
return history_manager.query_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
|
||||
def get_tool_history_text(
|
||||
tool_names: Optional[List[str]] = None,
|
||||
start_time: Optional[Union[datetime, str]] = None,
|
||||
end_time: Optional[Union[datetime, str]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
获取工具调用历史记录的文本格式
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表,为空则查询所有工具
|
||||
start_time: 开始时间,可以是datetime对象或ISO格式字符串
|
||||
end_time: 结束时间,可以是datetime对象或ISO格式字符串
|
||||
chat_id: 会话ID,用于筛选特定会话的调用
|
||||
limit: 返回记录数量限制
|
||||
status: 执行状态筛选("completed"或"error")
|
||||
|
||||
Returns:
|
||||
str: 格式化的工具调用历史记录文本
|
||||
"""
|
||||
history = get_tool_history(
|
||||
tool_names=tool_names,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
chat_id=chat_id,
|
||||
limit=limit,
|
||||
status=status
|
||||
)
|
||||
|
||||
if not history:
|
||||
return "没有找到工具调用记录"
|
||||
|
||||
text = "工具调用历史记录:\n"
|
||||
for record in history:
|
||||
# 提取结果中的name和content
|
||||
result = record['result']
|
||||
if isinstance(result, dict):
|
||||
name = result.get('name', record['tool_name'])
|
||||
content = result.get('content', str(result))
|
||||
else:
|
||||
name = record['tool_name']
|
||||
content = str(result)
|
||||
|
||||
# 格式化内容
|
||||
content = content.strip().replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
|
||||
# 格式化时间
|
||||
timestamp = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
text += f"[{timestamp}] {name}\n"
|
||||
text += f"结果: {content}\n\n"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def clear_tool_history() -> None:
|
||||
"""
|
||||
清除所有工具调用历史记录
|
||||
"""
|
||||
history_manager = ToolHistoryManager()
|
||||
history_manager.clear_history()
|
||||
|
|
|
|||
|
|
@ -151,9 +151,19 @@ class ToolExecutor:
|
|||
return [], []
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
func_names = []
|
||||
for call in tool_calls:
|
||||
try:
|
||||
if hasattr(call, 'func_name'):
|
||||
func_names.append(call.func_name)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
if func_names:
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未找到有效的工具调用")
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in tool_calls:
|
||||
|
|
@ -216,16 +226,19 @@ class ToolExecutor:
|
|||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
# 执行工具并记录日志
|
||||
logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}")
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}")
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
"content": result.get("content", "")
|
||||
}
|
||||
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue