mirror of https://github.com/Mai-with-u/MaiBot.git
feat:优化记忆检索和停止
parent
c5276ce629
commit
57b92ca124
|
|
@ -1,507 +0,0 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
|
||||
logger = get_logger("compare_finish_search_token")
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Sum up memory-related LLM token usage recorded at or after ``start_time``.

    Args:
        start_time: Start of the window as a Unix timestamp; it is converted
            with ``datetime.fromtimestamp`` before being compared to
            ``LLMUsage.timestamp``.

    Returns:
        A dict with ``total_prompt_tokens``, ``total_completion_tokens``,
        ``total_tokens``, ``total_cost``, ``request_count`` and ``model_usage``
        (the same five counters keyed by model name). If anything raises, the
        error is logged and an all-zero result with an empty ``model_usage``
        is returned instead.
    """

    def _zeroed() -> Dict[str, Any]:
        # One fresh counter set; used for the grand total and for each model.
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "request_count": 0,
        }

    try:
        since = datetime.fromtimestamp(start_time)

        # Request types considered "memory" traffic: anything containing
        # "memory", plus the three explicitly named request types.
        memory_request = (
            (LLMUsage.request_type.like("%memory%"))
            | (LLMUsage.request_type == "memory.question")
            | (LLMUsage.request_type == "memory.react")
            | (LLMUsage.request_type == "memory.react.final")
        )
        usage_rows = (
            LLMUsage.select()
            .where((LLMUsage.timestamp >= since) & memory_request)
            .order_by(LLMUsage.timestamp.asc())
        )

        overall = _zeroed()
        per_model: Dict[str, Dict[str, Any]] = {}

        for row in usage_rows:
            # Rows without a model name are grouped under "unknown".
            bucket = per_model.setdefault(row.model_name or "unknown", _zeroed())
            # Missing (None) column values count as zero.
            for counters in (overall, bucket):
                counters["prompt_tokens"] += row.prompt_tokens or 0
                counters["completion_tokens"] += row.completion_tokens or 0
                counters["total_tokens"] += row.total_tokens or 0
                counters["cost"] += row.cost or 0.0
                counters["request_count"] += 1

        return {
            "total_prompt_tokens": overall["prompt_tokens"],
            "total_completion_tokens": overall["completion_tokens"],
            "total_tokens": overall["total_tokens"],
            "total_cost": overall["cost"],
            "request_count": overall["request_count"],
            "model_usage": per_model,
        }
    except Exception as e:
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }
|
||||
|
||||
|
||||
def _import_memory_retrieval():
    """Import ``src.memory_system.memory_retrieval`` dynamically and return its entry points.

    The import goes through ``importlib`` (after some ``sys.modules``
    housekeeping) to work around circular imports.

    Returns:
        A ``(init_memory_retrieval_prompt, _react_agent_solve_question)`` tuple
        taken from the memory_retrieval module.

    Raises:
        ImportError, AttributeError: logged with traceback and re-raised when
            the module cannot be imported or lacks the expected attributes.
    """
    try:
        # Import prompt_builder first so we can tell whether the retrieval
        # prompts were already registered through some other import path.
        from src.chat.utils.prompt_builder import global_prompt_manager

        prompts_registered = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
        module_name = "src.memory_system.memory_retrieval"

        if module_name in sys.modules:
            cached = sys.modules[module_name]
            fully_loaded = hasattr(cached, "init_memory_retrieval_prompt")

            # Prompts registered and module complete: reuse the loaded module.
            if prompts_registered and fully_loaded:
                return (
                    cached.init_memory_retrieval_prompt,
                    cached._react_agent_solve_question,
                )

            if not fully_loaded:
                # A half-initialised module is cached; evict it together with
                # every cached src.memory_system.* submodule before retrying.
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                stale = [
                    name
                    for name in sys.modules.keys()
                    if name.startswith("src.memory_system.") and name != "src.memory_system"
                ]
                for name in stale:
                    sys.modules.pop(name, None)

        # Pre-load modules that may themselves import memory_retrieval at
        # module level, so they finish initialising before we import it.
        try:
            for required in ("src.config.config", "src.chat.utils.prompt_builder"):
                importlib.import_module(required)
            for optional in (
                "src.chat.replyer.group_generator",
                "src.chat.replyer.private_generator",
            ):
                try:
                    importlib.import_module(optional)
                except (ImportError, AttributeError):
                    pass  # optional pre-load; carry on if it fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # If this still hits a circular import, some other module imports
        # memory_retrieval at module level.
        loaded = importlib.import_module(module_name)
        return (
            loaded.init_memory_retrieval_prompt,
            loaded._react_agent_solve_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def _init_tools_without_finish_search():
    """Reset the tool registry and register every retrieval tool except ``finish_search``."""
    # Imports stay function-local and in their original order.
    from src.memory_system.retrieval_tools import (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    )
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.config.config import global_config

    # Start from an empty registry.
    registry = get_tool_registry()
    registry.tools.clear()

    # Register everything except finish_search.
    for register in (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    ):
        register()

    # The LPMM knowledge tool is only registered when LPMM runs in agent mode.
    if global_config.lpmm_knowledge.lpmm_mode == "agent":
        from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge

        register_lpmm_knowledge()

    logger.info("已初始化工具(不包含 finish_search)")
|
||||
|
||||
|
||||
def _init_tools_with_finish_search():
    """Reset the tool registry and register all tools via ``init_all_tools`` (``finish_search`` included)."""
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.memory_system.retrieval_tools import init_all_tools

    # Empty the registry, then let init_all_tools repopulate it.
    get_tool_registry().tools.clear()
    init_all_tools()

    logger.info("已初始化工具(包含 finish_search)")
|
||||
|
||||
|
||||
async def get_prompt_tokens_for_tools(
    question: str,
    chat_id: str,
    use_finish_search: bool,
) -> Dict[str, Any]:
    """Issue one LLM request with a given tool set and report its token usage.

    Args:
        question: The question inserted into the ReAct head prompt.
        chat_id: Chat identifier (accepted for interface symmetry; not read here).
        use_finish_search: Whether the ``finish_search`` tool is registered.

    Returns:
        A dict with ``use_finish_search``, ``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``, ``tool_count`` and
        ``tool_names``. Token counts are 0 when the response carries no usage.
    """
    from src.chat.utils.prompt_builder import global_prompt_manager

    # Register the retrieval prompts if nobody has done so yet. Per the
    # original author's note, init_memory_retrieval_prompt calls
    # init_all_tools, so the tool set must be (re)built after this step.
    if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
        init_prompt, _ = _import_memory_retrieval()
        init_prompt()

    configure_tools = _init_tools_with_finish_search if use_finish_search else _init_tools_without_finish_search
    configure_tools()

    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry

    registry = get_tool_registry()
    definitions = registry.get_tool_definitions()
    names = [definition["name"] for definition in definitions]

    # Sanity-check the registry against the requested configuration.
    finish_present = "finish_search" in names
    if use_finish_search:
        if not finish_present:
            logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
    elif finish_present:
        logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
        # Drop the unexpected tool and refresh both views of the registry.
        registry.tools.pop("finish_search", None)
        definitions = registry.get_tool_definitions()
        names = [definition["name"] for definition in definitions]

    # Build the prompt of the first ReAct iteration.
    from src.config.config import global_config

    iteration_cap = global_config.memory.max_agent_iterations
    head_prompt = await global_prompt_manager.format_prompt(
        "memory_retrieval_react_prompt_head",
        bot_name=global_config.bot.nickname,
        time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        question=question,
        collected_info="",
        current_iteration=1,
        remaining_iterations=iteration_cap - 1,
        max_iterations=iteration_cap,
    )

    # The message list holds only the system message, as in a first call.
    from src.llm_models.payload_content.message import MessageBuilder, RoleType

    system_message = MessageBuilder()
    system_message.set_role(RoleType.System)
    system_message.add_text_content(head_prompt)
    messages = [system_message.build()]

    from src.llm_models.utils_model import LLMRequest, RequestType
    from src.config.config import model_config

    requester = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")
    tool_options = requester._build_tool_options(definitions)

    # Call _execute_request directly to get the raw response object with usage.
    response, _model_info = await requester._execute_request(
        request_type=RequestType.RESPONSE,
        message_factory=lambda _client, *, _messages=messages: _messages,
        temperature=None,
        max_tokens=None,
        tool_options=tool_options,
    )

    token_counts = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    if response and hasattr(response, "usage") and response.usage:
        token_counts["prompt_tokens"] = response.usage.prompt_tokens or 0
        token_counts["completion_tokens"] = response.usage.completion_tokens or 0
        token_counts["total_tokens"] = response.usage.total_tokens or 0

    return {
        "use_finish_search": use_finish_search,
        "prompt_tokens": token_counts["prompt_tokens"],
        "completion_tokens": token_counts["completion_tokens"],
        "total_tokens": token_counts["total_tokens"],
        "tool_count": len(definitions),
        "tool_names": names,
    }
|
||||
|
||||
|
||||
async def compare_prompt_tokens(
    question: str,
    chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
    """Compare first-call input tokens with and without the ``finish_search`` tool.

    Runs ``get_prompt_tokens_for_tools`` twice (without the tool, then with
    it), prints a report to stdout and returns the raw numbers.

    Args:
        question: The question to ask.
        chat_id: Base chat id; ``_without`` / ``_with`` is appended per run.

    Returns:
        A dict with ``question``, ``without_finish_search``,
        ``with_finish_search`` and a ``comparison`` sub-dict holding
        ``prompt_token_diff``, ``prompt_token_diff_percent`` and
        ``tool_count_diff``.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print("\n" + heavy_rule)
    print("finish_search 工具 输入 Token 消耗对比测试")
    print(heavy_rule)
    print(f"\n[测试问题] {question}")
    print(f"[聊天ID] {chat_id}")
    print("\n注意: 只对比第一次LLM调用的输入token差异,不运行完整迭代流程")

    # (use_finish_search, headline, chat-id suffix) for the two runs, in order.
    variants = (
        (False, "不使用 finish_search 工具", "without"),
        (True, "使用 finish_search 工具", "with"),
    )
    outcomes: Dict[bool, Dict[str, Any]] = {}
    for step, (enabled, headline, suffix) in enumerate(variants, start=1):
        print("\n" + light_rule)
        print(f"[测试 {step}/2] {headline}")
        print(light_rule)
        outcome = await get_prompt_tokens_for_tools(
            question=question,
            chat_id=f"{chat_id}_{suffix}",
            use_finish_search=enabled,
        )
        outcomes[enabled] = outcome

        print("\n[结果]")
        print(f" 工具数量: {outcome['tool_count']}")
        print(f" 工具列表: {', '.join(outcome['tool_names'])}")
        print(f" 输入 Prompt Tokens: {outcome['prompt_tokens']:,}")

        if step == 1:
            # Short pause between the runs so the DB record can be written.
            await asyncio.sleep(1)

    result_without = outcomes[False]
    result_with = outcomes[True]

    print("\n" + heavy_rule)
    print("[对比结果]")
    print(heavy_rule)

    baseline_tokens = result_without["prompt_tokens"]
    prompt_token_diff = result_with["prompt_tokens"] - baseline_tokens
    if baseline_tokens > 0:
        prompt_token_diff_percent = prompt_token_diff / baseline_tokens * 100
    else:
        # No baseline to divide by.
        prompt_token_diff_percent = 0
    tool_count_diff = result_with["tool_count"] - result_without["tool_count"]

    print("\n[输入 Prompt Token 对比]")
    print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
    print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
    print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")

    print("\n[工具数量对比]")
    print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
    print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
    print(f" 差异: {tool_count_diff:+d} 个工具")

    print("\n[工具列表对比]")
    tools_without = set(result_without["tool_names"])
    tools_with = set(result_with["tool_names"])
    only_with = tools_with - tools_without
    only_without = tools_without - tools_with

    if only_with:
        print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
    if only_without:
        print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
    if not (only_with or only_without):
        print(" 工具列表相同(除了 finish_search)")

    # Remaining token figures, for reference.
    print("\n[其他 Token 信息]")
    print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
    print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
    print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
    print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")

    print("\n" + heavy_rule)

    return {
        "question": question,
        "without_finish_search": result_without,
        "with_finish_search": result_with,
        "comparison": {
            "prompt_token_diff": prompt_token_diff,
            "prompt_token_diff_percent": prompt_token_diff_percent,
            "tool_count_diff": tool_count_diff,
        },
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: ask for a question, run the comparison, optionally save JSON.

    Options: ``--chat-id`` (base chat id) and ``--output`` / ``-o`` (path of a
    JSON file that receives the comparison result).
    """
    cli = argparse.ArgumentParser(description="对比使用 finish_search 工具与否的 token 消耗差异")
    cli.add_argument(
        "--chat-id",
        default="compare_finish_search",
        help="测试用的聊天ID(默认: compare_finish_search)",
    )
    cli.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件(可选)",
    )
    options = cli.parse_args()

    # Keep logging quiet so the report stays readable.
    initialize_logging(verbose=False)

    # The question is read interactively.
    banner = "=" * 80
    print("\n" + banner)
    print("finish_search 工具 Token 消耗对比测试工具")
    print(banner)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    try:
        result = asyncio.run(
            compare_prompt_tokens(
                question=question,
                chat_id=options.chat_id,
            )
        )

        # Persist the result only when an output path was given.
        if options.output:
            snapshot = result.copy()
            with open(options.output, "w", encoding="utf-8") as f:
                json.dump(snapshot, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {options.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Closing is best-effort; a failure here must not mask the outcome.
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()
|
||||
|
||||
|
|
@ -1,476 +0,0 @@
|
|||
"""
|
||||
表达方式评估脚本
|
||||
|
||||
功能:
|
||||
1. 随机读取指定数量的表达方式,获取其situation和style
|
||||
2. 先进行人工评估(逐条手动评估)
|
||||
3. 然后使用LLM进行评估
|
||||
4. 对比人工评估和LLM评估的正确率、精确率、召回率、F1分数等指标(以人工评估为标准)
|
||||
5. 不真正修改数据库,只是做评估
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import db
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_evaluator_comparison")
|
||||
|
||||
|
||||
def get_random_expressions(count: int = 10) -> List[Expression]:
    """Return up to ``count`` randomly chosen ``Expression`` rows.

    Args:
        count: Number of rows wanted (default 10).

    Returns:
        A list of expressions. All rows are returned when the table holds
        ``count`` rows or fewer; an empty list is returned when the table is
        empty or when reading fails (the failure is logged with traceback).
    """
    try:
        # Every row is loaded, then sampled in memory.
        pool = list(Expression.select())
        available = len(pool)

        if available == 0:
            logger.warning("数据库中没有表达方式记录")
            return []

        if available <= count:
            # Not enough rows to sample from; hand back everything.
            logger.info(f"数据库中共有 {available} 条表达方式,全部返回")
            return pool

        picked = random.sample(pool, count)
        logger.info(f"从 {available} 条表达方式中随机选择了 {len(picked)} 条")
        return picked
    except Exception as e:
        logger.error(f"随机读取表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
|
||||
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict | None:
    """Ask the operator on stdin whether one expression is suitable.

    Args:
        expression: The expression to judge; its ``id``, ``situation`` and
            ``style`` attributes are read.
        index: 1-based position of this expression in the batch.
        total: Size of the batch.

    Returns:
        A dict with ``expression_id``, ``situation``, ``style``, ``suitable``
        (the operator's verdict), ``reason`` (always ``None``) and
        ``evaluator`` (``"manual"``); or ``None`` when the operator quits
        (``q`` / ``quit``) or stdin reaches end-of-file.
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y' 或 'yes' 或 '1' 表示合适(通过)")
    print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
    print(" 输入 'q' 或 'quit' 退出评估")

    while True:
        try:
            user_input = input("\n您的评估 (y/n/q): ").strip().lower()
        except EOFError:
            # stdin was closed or exhausted: no further answer can arrive, so
            # treat it as an explicit quit instead of crashing.
            print("退出评估")
            return None

        if user_input in ['q', 'quit']:
            print("退出评估")
            return None

        if user_input in ['y', 'yes', '1', '是', '通过']:
            suitable = True
            break
        elif user_input in ['n', 'no', '0', '否', '不通过']:
            suitable = False
            break
        else:
            print("输入无效,请重新输入 (y/n/q)")

    result = {
        "expression_id": expression.id,
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": None,
        "evaluator": "manual"
    }

    print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")

    return result
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM prompt that asks whether a situation/style pair is suitable.

    Args:
        situation: The situation text of the expression.
        style: The style text of the expression.

    Returns:
        The prompt, which asks for a JSON object with ``suitable`` and ``reason``.
    """
    # NOTE(review): leading whitespace of the JSON example lines is reproduced
    # as it appears in the reviewed text; confirm against the original file in
    # case indentation was lost.
    prompt_lines = [
        "请评估以下表达方式是否合适:",
        "",
        f"情境(situation):{situation}",
        f"风格(style):{style}",
        "",
        "请从以下方面进行评估:",
        "1. 情境描述是否清晰、准确",
        "2. 风格表达是否合理、自然",
        "3. 情境和风格是否匹配",
        "4. 允许部分语法错误出现",
        "5. 允许口头化或缺省表达",
        "6. 允许部分上下文缺失",
        "",
        "请以JSON格式输出评估结果:",
        "{",
        '"suitable": true/false,',
        '"reason": "评估理由(如果不合适,请说明原因)"',
        "}",
        "",
        "如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。",
        "请严格按照JSON格式输出,不要包含其他内容。",
    ]
    return "\n".join(prompt_lines)
|
||||
|
||||
|
||||
async def _single_llm_evaluation(expression: Expression, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of an expression and parse the JSON verdict.

    Args:
        expression: The expression to judge; ``id``, ``situation`` and
            ``style`` are read.
        llm: The LLM request object; ``generate_response_async`` is awaited.

    Returns:
        A ``(suitable, reason, error)`` tuple. On any failure ``suitable`` is
        ``False``, ``reason`` describes the failure and ``error`` holds the
        exception text; on success ``error`` is ``None``.
    """
    import re

    def _as_bool(value) -> bool:
        # Models sometimes emit "true"/"false" as strings. A plain truthiness
        # test would count the string "false" as a pass, so map textual
        # verdicts explicitly and fall back to bool() for everything else.
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "y", "1", "是", "通过")
        return bool(value)

    try:
        prompt = create_evaluation_prompt(expression.situation, expression.style)
        logger.debug(f"正在评估表达方式 ID: {expression.id}")

        response, (_reasoning, _model_name, _) = await llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the reply as JSON; if it is wrapped in extra text, fall back
        # to the first flat {...} object that mentions "suitable".
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError:
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果")

        # Valid JSON that is not an object (list, number, ...) carries no verdict.
        if not isinstance(evaluation, dict):
            raise ValueError("无法从响应中提取JSON格式的评估结果")

        suitable = _as_bool(evaluation.get("suitable", False))
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        logger.error(f"评估表达方式 ID: {expression.id} 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
||||
|
||||
async def evaluate_expression_llm(expression: Expression, llm: LLMRequest) -> Dict:
    """Evaluate one expression with the LLM and package the outcome as a dict.

    Args:
        expression: The expression to judge.
        llm: The LLM request object passed on to ``_single_llm_evaluation``.

    Returns:
        A dict with ``expression_id``, ``situation``, ``style``, ``suitable``,
        ``reason``, ``error`` and ``evaluator`` (``"llm"``). ``suitable`` is
        forced to ``False`` whenever an error was reported.
    """
    logger.info(f"开始评估表达方式 ID: {expression.id}")

    verdict, reason, error = await _single_llm_evaluation(expression, llm)

    # An evaluation that errored never counts as a pass.
    passed = False if error else verdict

    logger.info(f"评估完成: {'通过' if passed else '不通过'}")

    return {
        "expression_id": expression.id,
        "situation": expression.situation,
        "style": expression.style,
        "suitable": passed,
        "reason": reason,
        "error": error,
        "evaluator": "llm"
    }
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """Score LLM verdicts against manual verdicts, treating manual as ground truth.

    Args:
        manual_results: Manual evaluation dicts (``expression_id``, ``suitable``).
        llm_results: LLM evaluation dicts (``expression_id``, ``suitable``).
        method_name: Label stored under ``"method"`` in the result.

    Returns:
        A dict with ``method``, ``total`` (number of manual results),
        ``matched``, the four confusion-matrix counts, and ``accuracy``,
        ``precision``, ``recall``, ``f1_score``, ``specificity`` expressed as
        percentages, plus ``accuracy_above_random`` and
        ``accuracy_improvement_ratio`` relative to a 50% baseline. Manual
        results with no LLM counterpart are skipped in the counts but still
        included in ``total``.
    """

    def _percent(numerator, denominator):
        # Ratio as a percentage; 0 when the denominator is empty.
        return (numerator / denominator * 100) if denominator > 0 else 0

    llm_by_id = {entry["expression_id"]: entry for entry in llm_results}

    total = len(manual_results)
    matched = 0
    # Keyed by (manual verdict, LLM verdict), both reduced to truthiness.
    tally = {
        (True, True): 0,    # true positives
        (False, False): 0,  # true negatives
        (False, True): 0,   # false positives
        (True, False): 0,   # false negatives
    }

    for manual_entry in manual_results:
        llm_entry = llm_by_id.get(manual_entry["expression_id"])
        if llm_entry is None:
            continue

        manual_verdict = manual_entry["suitable"]
        llm_verdict = llm_entry["suitable"]

        if manual_verdict == llm_verdict:
            matched += 1
        tally[(bool(manual_verdict), bool(llm_verdict))] += 1

    true_positives = tally[(True, True)]
    true_negatives = tally[(False, False)]
    false_positives = tally[(False, True)]
    false_negatives = tally[(True, False)]

    accuracy = _percent(matched, total)
    precision = _percent(true_positives, true_positives + false_positives)
    recall = _percent(true_positives, true_positives + false_negatives)
    specificity = _percent(true_negatives, true_negatives + false_positives)
    if (precision + recall) > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0

    # Reference point: 50% accuracy.
    random_baseline = 50.0
    accuracy_above_random = accuracy - random_baseline
    accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0

    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy_above_random,
        "accuracy_improvement_ratio": accuracy_improvement_ratio,
        "true_positives": true_positives,
        "true_negatives": true_negatives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }
|
||||
|
||||
|
||||
|
||||
|
||||
async def _run_expression_evaluation() -> None:
    """Run the evaluation workflow: sample, manual pass, LLM pass, report, save.

    Expects the database connection to be open already; the caller owns
    opening and closing it. Returns early (after logging/printing) when there
    are no expressions, the operator quits, or the LLM object cannot be built.
    """
    # 1. Sample expressions.
    logger.info("\n步骤1: 随机读取表达方式")
    expressions = get_random_expressions(10)
    if not expressions:
        logger.error("没有可用的表达方式,退出")
        return
    logger.info(f"成功读取 {len(expressions)} 条表达方式")

    # 2. Manual evaluation, one expression at a time.
    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print(f"共需要评估 {len(expressions)} 条表达方式")
    print("请逐条进行评估...\n")

    manual_results = []
    for i, expression in enumerate(expressions, 1):
        manual_result = manual_evaluate_expression(expression, i, len(expressions))
        if manual_result is None:
            print("\n评估已中断")
            return
        manual_results.append(manual_result)

    print("\n" + "=" * 60)
    print("人工评估完成")
    print("=" * 60)

    # 3. Build the LLM request object and evaluate every expression with it.
    logger.info("\n步骤3: 创建LLM实例")
    try:
        llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="expression_evaluator_comparison"
        )
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return

    logger.info("\n步骤4: 开始LLM评估")
    llm_results = []
    for i, expression in enumerate(expressions, 1):
        logger.info(f"LLM评估进度: {i}/{len(expressions)}")
        llm_results.append(await evaluate_expression_llm(expression, llm))
        # Brief pause between consecutive LLM requests.
        await asyncio.sleep(0.3)

    # 4. Compare and print the report.
    comparison = compare_evaluations(manual_results, llm_results, "LLM评估")

    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)
    print("\n评估目标:")
    print(" 1. 核心能力:将不合适的项目正确提取出来(特定负类召回率)")
    print(" 2. 次要能力:尽可能少的误删合适的项目(召回率)")

    # Detailed figures, core metrics first.
    print("\n【详细对比】")
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']} 条")
    print()
    print(" 【核心能力指标】")
    print(f" ⭐ 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
    print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" ⭐ 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
    print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")

    # 5. List the false positives: rejected manually but accepted by the LLM.
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
    print("=" * 60)

    llm_dict = {r["expression_id"]: r for r in llm_results}

    fp_items = []
    for manual_result in manual_results:
        llm_result = llm_dict.get(manual_result["expression_id"])
        if llm_result is None:
            continue

        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append({
                "expression_id": manual_result["expression_id"],
                "situation": manual_result["situation"],
                "style": manual_result["style"],
                "manual_suitable": manual_result["suitable"],
                "llm_suitable": llm_result["suitable"],
                "llm_reason": llm_result.get("reason", "未提供理由"),
                "llm_error": llm_result.get("error")
            })

    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] 项目 ID: {item['expression_id']} ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get('llm_error'):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")

    # 6. Save everything to a JSON file under <project_root>/data.
    output_file = os.path.join(project_root, "data", "expression_evaluation_comparison.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump({
                "manual_results": manual_results,
                "llm_results": llm_results,
                "comparison": comparison
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        # Saving is a convenience; a failure here should not abort the run.
        logger.warning(f"保存结果到文件失败: {e}")

    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)


async def main():
    """Entry point: open the database, run the evaluation, always close the database.

    The connection is closed in a ``finally`` block so that early exits (no
    expressions, operator quit, LLM construction failure) and unexpected
    exceptions no longer leave it open.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式评估")
    logger.info("=" * 60)

    # Open the database connection; nothing else can run without it.
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    try:
        await _run_expression_evaluation()
    finally:
        # Close the connection on every exit path.
        try:
            db.close()
            logger.info("数据库连接已关闭")
        except Exception as e:
            logger.warning(f"关闭数据库连接时出错: {e}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
|
|
@ -108,7 +107,7 @@ def init_memory_retrieval_prompt():
|
|||
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
|
||||
- 先思考当前信息是否足够回答问题
|
||||
- 如果信息不足,则需要使用tool查询信息,你必须给出使用什么工具进行查询
|
||||
- 如果当前已收集的信息足够或信息不足确定无法找到答案,你必须调用finish_search工具结束查询
|
||||
- 如果当前已收集的信息足够或信息不足确定无法找到答案,你必须调用found_answer工具结束查询
|
||||
""",
|
||||
name="memory_retrieval_react_prompt_head",
|
||||
)
|
||||
|
|
@ -312,7 +311,7 @@ async def _react_agent_solve_question(
|
|||
|
||||
return None
|
||||
|
||||
# 正常迭代:使用head_prompt决定调用哪些工具(包含finish_search工具)
|
||||
# 正常迭代:使用head_prompt决定调用哪些工具(包含found_answer工具)
|
||||
tool_definitions = tool_registry.get_tool_definitions()
|
||||
# tool_names = [tool_def["name"] for tool_def in tool_definitions]
|
||||
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
|
||||
|
|
@ -373,7 +372,7 @@ async def _react_agent_solve_question(
|
|||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||
break
|
||||
|
||||
# 注意:这里会检查finish_search工具调用,如果检测到finish_search工具,会根据found_answer参数决定返回答案或退出查询
|
||||
# 注意:这里会检查found_answer工具调用,如果检测到found_answer工具,会根据answer参数决定返回答案或退出查询
|
||||
|
||||
assistant_message: Optional[Message] = None
|
||||
if tool_calls:
|
||||
|
|
@ -403,81 +402,117 @@ async def _react_agent_solve_question(
|
|||
|
||||
# 处理工具调用
|
||||
if not tool_calls:
|
||||
# 如果没有工具调用,检查响应文本中是否包含finish_search函数调用格式
|
||||
# 如果没有工具调用,检查响应文本中是否包含found_answer函数调用格式或JSON格式
|
||||
if response and response.strip():
|
||||
# 尝试从文本中解析finish_search函数调用
|
||||
def parse_finish_search_from_text(text: str):
|
||||
"""从文本中解析finish_search函数调用,返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
|
||||
# 首先尝试解析JSON格式的found_answer
|
||||
def parse_json_found_answer(text: str):
|
||||
"""从文本中解析JSON格式的found_answer,返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
|
||||
if not text:
|
||||
return None, None
|
||||
|
||||
# 查找finish_search函数调用位置(不区分大小写)
|
||||
func_pattern = "finish_search"
|
||||
try:
|
||||
# 尝试提取JSON对象(可能包含在代码块中或直接是JSON)
|
||||
json_text = text.strip()
|
||||
|
||||
# 如果包含代码块标记,提取JSON部分
|
||||
if "```json" in json_text:
|
||||
start = json_text.find("```json") + 7
|
||||
end = json_text.find("```", start)
|
||||
if end != -1:
|
||||
json_text = json_text[start:end].strip()
|
||||
elif "```" in json_text:
|
||||
start = json_text.find("```") + 3
|
||||
end = json_text.find("```", start)
|
||||
if end != -1:
|
||||
json_text = json_text[start:end].strip()
|
||||
|
||||
# 尝试解析JSON
|
||||
data = json.loads(json_text)
|
||||
|
||||
# 检查是否包含found_answer字段
|
||||
if isinstance(data, dict) and "found_answer" in data:
|
||||
found_answer = bool(data.get("found_answer", False))
|
||||
answer = data.get("answer", "")
|
||||
return found_answer, answer
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
# 如果JSON解析失败,尝试在文本中查找JSON对象
|
||||
try:
|
||||
# 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取)
|
||||
first_brace = text.find('{')
|
||||
if first_brace != -1:
|
||||
# 从第一个 { 开始,找到匹配的 }
|
||||
brace_count = 0
|
||||
json_end = -1
|
||||
for i in range(first_brace, len(text)):
|
||||
if text[i] == '{':
|
||||
brace_count += 1
|
||||
elif text[i] == '}':
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
json_end = i + 1
|
||||
break
|
||||
|
||||
if json_end != -1:
|
||||
json_text = text[first_brace:json_end]
|
||||
data = json.loads(json_text)
|
||||
if isinstance(data, dict) and "found_answer" in data:
|
||||
found_answer = bool(data.get("found_answer", False))
|
||||
answer = data.get("answer", "")
|
||||
return found_answer, answer
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return None, None
|
||||
|
||||
# 尝试从文本中解析found_answer函数调用
|
||||
def parse_found_answer_from_text(text: str):
|
||||
"""从文本中解析found_answer函数调用,返回answer字符串,如果未找到则返回None
|
||||
如果answer存在且非空,表示找到答案;如果answer为空或不存在,表示未找到答案"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 查找found_answer函数调用位置(不区分大小写)
|
||||
func_pattern = "found_answer"
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None, None
|
||||
|
||||
# 查找函数调用的开始和结束位置
|
||||
# 从func_pos开始向后查找左括号
|
||||
start_pos = text.find("(", func_pos)
|
||||
if start_pos == -1:
|
||||
return None, None
|
||||
|
||||
# 查找匹配的右括号(考虑嵌套)
|
||||
paren_count = 0
|
||||
end_pos = start_pos
|
||||
for i in range(start_pos, len(text)):
|
||||
if text[i] == "(":
|
||||
paren_count += 1
|
||||
elif text[i] == ")":
|
||||
paren_count -= 1
|
||||
if paren_count == 0:
|
||||
end_pos = i
|
||||
break
|
||||
else:
|
||||
# 没有找到匹配的右括号
|
||||
return None, None
|
||||
|
||||
# 提取函数参数部分
|
||||
params_text = text[start_pos + 1 : end_pos]
|
||||
|
||||
# 解析found_answer参数(布尔值,可能是true/false/True/False)
|
||||
found_answer = None
|
||||
found_answer_patterns = [
|
||||
r"found_answer\s*=\s*true",
|
||||
r"found_answer\s*=\s*True",
|
||||
r"found_answer\s*=\s*false",
|
||||
r"found_answer\s*=\s*False",
|
||||
]
|
||||
for pattern in found_answer_patterns:
|
||||
match = re.search(pattern, params_text, re.IGNORECASE)
|
||||
if match:
|
||||
found_answer = "true" in match.group(0).lower()
|
||||
break
|
||||
return None
|
||||
|
||||
# 解析answer参数(字符串,使用extract_quoted_content)
|
||||
answer = extract_quoted_content(text, "finish_search", "answer")
|
||||
answer = extract_quoted_content(text, "found_answer", "answer")
|
||||
|
||||
return found_answer, answer
|
||||
# 如果answer存在(即使是空字符串),也返回它(空字符串表示未找到答案)
|
||||
return answer
|
||||
|
||||
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
|
||||
# 首先尝试解析JSON格式
|
||||
parsed_found_answer_json, parsed_answer_json = parse_json_found_answer(response)
|
||||
is_json_format = parsed_found_answer_json is not None
|
||||
|
||||
if parsed_found_answer is not None:
|
||||
# 检测到finish_search函数调用格式
|
||||
if parsed_found_answer:
|
||||
# 如果JSON解析成功,使用JSON结果
|
||||
if is_json_format:
|
||||
parsed_answer = parsed_answer_json
|
||||
has_answer = parsed_found_answer_json and parsed_answer and parsed_answer.strip()
|
||||
else:
|
||||
# 如果JSON解析失败,尝试解析函数调用格式
|
||||
parsed_answer = parse_found_answer_from_text(response)
|
||||
# 如果answer存在且非空,表示找到答案;否则表示未找到答案
|
||||
has_answer = parsed_answer is not None and parsed_answer.strip() != ""
|
||||
|
||||
if parsed_answer is not None or is_json_format:
|
||||
# 检测到found_answer格式(可能是JSON格式或函数调用格式)
|
||||
format_type = "JSON格式" if is_json_format else "函数调用格式"
|
||||
if has_answer:
|
||||
# 找到了答案
|
||||
if parsed_answer:
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": parsed_answer},
|
||||
"action_type": "found_answer",
|
||||
"action_params": {"answer": parsed_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
||||
step["observations"] = [f"检测到found_answer{format_type}调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过found_answer{format_type}找到关于问题{question}的答案: {parsed_answer[:100]}..."
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
|
|
@ -487,31 +522,26 @@ async def _react_agent_solve_question(
|
|||
)
|
||||
|
||||
return True, parsed_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误,继续迭代
|
||||
logger.warning(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
|
||||
)
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append(
|
||||
{"action_type": "finish_search", "action_params": {"found_answer": False}}
|
||||
{"action_type": "found_answer", "action_params": {"answer": ""}}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
||||
step["observations"] = [f"检测到found_answer{format_type}调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案"
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过found_answer{format_type}判断未找到答案"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:通过finish_search文本格式判断未找到答案",
|
||||
final_status="未找到答案:通过found_answer文本格式判断未找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代
|
||||
# 如果没有检测到found_answer格式,记录思考过程,继续下一轮迭代
|
||||
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}"
|
||||
|
|
@ -525,62 +555,55 @@ async def _react_agent_solve_question(
|
|||
continue
|
||||
|
||||
# 处理工具调用
|
||||
# 首先检查是否有finish_search工具调用,如果有则立即返回,不再处理其他工具
|
||||
finish_search_found = None
|
||||
finish_search_answer = None
|
||||
# 首先检查是否有found_answer工具调用,如果有则立即返回,不再处理其他工具
|
||||
found_answer_answer = None
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
if tool_name == "finish_search":
|
||||
finish_search_found = tool_args.get("found_answer", False)
|
||||
finish_search_answer = tool_args.get("answer", "")
|
||||
if tool_name == "found_answer":
|
||||
found_answer_answer = tool_args.get("answer", "")
|
||||
|
||||
if finish_search_found:
|
||||
# 如果answer存在且非空,表示找到答案;否则表示未找到答案
|
||||
if found_answer_answer and found_answer_answer.strip():
|
||||
# 找到了答案
|
||||
if finish_search_answer:
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": finish_search_answer},
|
||||
"action_type": "found_answer",
|
||||
"action_params": {"answer": found_answer_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
||||
step["observations"] = ["检测到found_answer工具调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_answer}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{finish_search_answer}",
|
||||
final_status=f"找到答案:{found_answer_answer}",
|
||||
)
|
||||
|
||||
return True, finish_search_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误
|
||||
logger.warning(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
|
||||
)
|
||||
return True, found_answer_answer, thinking_steps, False
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||
step["observations"] = ["检测到finish_search工具调用,未找到答案"]
|
||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": ""}})
|
||||
step["observations"] = ["检测到found_answer工具调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案"
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过found_answer工具判断未找到答案"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:通过finish_search工具判断未找到答案",
|
||||
final_status="未找到答案:通过found_answer工具判断未找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有finish_search工具调用,继续处理其他工具
|
||||
# 如果没有found_answer工具调用,继续处理其他工具
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
tool_name = tool_call.func_name
|
||||
|
|
@ -590,8 +613,8 @@ async def _react_agent_solve_question(
|
|||
f"{react_log_prefix}第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 跳过finish_search工具调用(已经在上面处理过了)
|
||||
if tool_name == "finish_search":
|
||||
# 跳过found_answer工具调用(已经在上面处理过了)
|
||||
if tool_name == "found_answer":
|
||||
continue
|
||||
|
||||
# 记录最后一次使用的工具名称(用于判断是否需要额外迭代)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .query_chat_history import register_tool as register_query_chat_history
|
|||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .found_answer import register_tool as register_finish_search
|
||||
from .found_answer import register_tool as register_found_answer
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ def init_all_tools():
|
|||
register_query_chat_history()
|
||||
register_query_person_info()
|
||||
register_query_words() # 注册query_words工具
|
||||
register_finish_search() # 注册finish_search工具
|
||||
register_found_answer() # 注册found_answer工具
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
finish_search工具 - 用于在记忆检索过程中结束查询
|
||||
found_answer工具 - 用于在记忆检索过程中结束查询
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
|
@ -8,17 +8,16 @@ from .tool_registry import register_memory_retrieval_tool
|
|||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def finish_search(found_answer: bool, answer: str = "") -> str:
|
||||
async def found_answer(answer: str = "") -> str:
|
||||
"""结束查询
|
||||
|
||||
Args:
|
||||
found_answer: 是否找到了答案
|
||||
answer: 如果找到了答案,提供答案内容;如果未找到,可以为空
|
||||
answer: 如果找到了答案,提供答案内容;如果未找到答案,可以为空或不提供此参数
|
||||
|
||||
Returns:
|
||||
str: 确认信息
|
||||
"""
|
||||
if found_answer:
|
||||
if answer and answer.strip():
|
||||
logger.info(f"找到答案: {answer}")
|
||||
return f"已确认找到答案: {answer}"
|
||||
else:
|
||||
|
|
@ -27,23 +26,17 @@ async def finish_search(found_answer: bool, answer: str = "") -> str:
|
|||
|
||||
|
||||
def register_tool():
|
||||
"""注册finish_search工具"""
|
||||
"""注册found_answer工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="finish_search",
|
||||
description="当你决定结束查询时,调用此工具。如果找到了明确答案,设置found_answer为true并在answer中提供答案;如果未找到答案,设置found_answer为false。只有在检索到明确、具体的答案时才设置found_answer为true,不要编造信息。",
|
||||
name="found_answer",
|
||||
description="当你决定结束查询时,调用此工具。如果找到了明确答案,在answer参数中提供答案内容;如果未找到答案,可以不提供answer参数或提供空字符串。只有在检索到明确、具体的答案时才提供answer,不要编造信息。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "found_answer",
|
||||
"type": "boolean",
|
||||
"description": "是否找到了答案",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "answer",
|
||||
"type": "string",
|
||||
"description": "如果found_answer为true,提供找到的答案内容,必须基于已收集的信息,不要编造;如果found_answer为false,可以为空",
|
||||
"description": "如果找到了答案,提供找到的答案内容,必须基于已收集的信息,不要编造;如果未找到答案,可以不提供此参数或提供空字符串",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=finish_search,
|
||||
execute_func=found_answer,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue