feat:优化记忆检索和停止

pull/1481/head
SengokuCola 2025-12-31 19:34:33 +08:00
parent c5276ce629
commit 57b92ca124
5 changed files with 163 additions and 1130 deletions

View File

@ -1,507 +0,0 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from typing import Dict, Any
from datetime import datetime
# 强制使用 utf-8避免控制台编码报错
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
logger = get_logger("compare_finish_search_token")
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
.where(
(LLMUsage.timestamp >= start_datetime)
& (
(LLMUsage.request_type.like("%memory%"))
| (LLMUsage.request_type == "memory.question")
| (LLMUsage.request_type == "memory.react")
| (LLMUsage.request_type == "memory.react.final")
)
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
model_usage = {} # 按模型统计
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
# 按模型统计
model_name = record.model_name or "unknown"
if model_name not in model_usage:
model_usage[model_name] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost": 0.0,
"request_count": 0,
}
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"total_cost": total_cost,
"request_count": request_count,
"model_usage": model_usage,
}
except Exception as e:
logger.error(f"获取token使用情况失败: {e}")
return {
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
"total_cost": 0.0,
"request_count": 0,
"model_usage": {},
}
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try:
# 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
)
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system':
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try:
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
try:
import src.chat.replyer.group_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
try:
import src.chat.replyer.private_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
)
except (ImportError, AttributeError) as e:
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
raise
def _init_tools_without_finish_search():
"""初始化工具但不注册 finish_search"""
from src.memory_system.retrieval_tools import (
register_query_chat_history,
register_query_person_info,
register_query_words,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.config.config import global_config
# 清空工具注册器
tool_registry = get_tool_registry()
tool_registry.tools.clear()
# 注册除 finish_search 外的所有工具
register_query_chat_history()
register_query_person_info()
register_query_words()
# 如果启用 LPMM agent 模式,也注册 LPMM 工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":
from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge
register_lpmm_knowledge()
logger.info("已初始化工具(不包含 finish_search")
def _init_tools_with_finish_search():
"""初始化工具并注册 finish_search"""
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.memory_system.retrieval_tools import init_all_tools
# 清空工具注册器
tool_registry = get_tool_registry()
tool_registry.tools.clear()
# 初始化所有工具(包括 finish_search
init_all_tools()
logger.info("已初始化工具(包含 finish_search")
async def get_prompt_tokens_for_tools(
question: str,
chat_id: str,
use_finish_search: bool,
) -> Dict[str, Any]:
"""获取使用不同工具配置时的prompt token消耗
Args:
question: 要查询的问题
chat_id: 聊天ID
use_finish_search: 是否使用 finish_search 工具
Returns:
包含prompt token信息的字典
"""
# 先初始化 prompt如果还未初始化
# 注意init_memory_retrieval_prompt 会调用 init_all_tools所以我们需要在它之后重新设置工具
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt, _ = _import_memory_retrieval()
init_memory_retrieval_prompt()
# 初始化工具(根据参数决定是否包含 finish_search
# 必须在 init_memory_retrieval_prompt 之后调用,因为它会调用 init_all_tools
if use_finish_search:
_init_tools_with_finish_search()
else:
_init_tools_without_finish_search()
# 获取工具注册器
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
tool_registry = get_tool_registry()
tool_definitions = tool_registry.get_tool_definitions()
# 验证工具列表(调试用)
tool_names = [tool["name"] for tool in tool_definitions]
if use_finish_search:
if "finish_search" not in tool_names:
logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
else:
if "finish_search" in tool_names:
logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
# 移除 finish_search 工具
tool_registry.tools.pop("finish_search", None)
tool_definitions = tool_registry.get_tool_definitions()
tool_names = [tool["name"] for tool in tool_definitions]
# 构建第一次调用的prompt模拟_react_agent_solve_question的第一次调用
from src.config.config import global_config
bot_name = global_config.bot.nickname
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 构建head_prompt
head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info="",
current_iteration=1,
remaining_iterations=global_config.memory.max_agent_iterations - 1,
max_iterations=global_config.memory.max_agent_iterations,
)
# 构建消息列表只包含system message模拟第一次调用
from src.llm_models.payload_content.message import MessageBuilder, RoleType
messages = []
system_builder = MessageBuilder()
system_builder.set_role(RoleType.System)
system_builder.add_text_content(head_prompt)
messages.append(system_builder.build())
# 调用LLM API来计算token只调用一次不实际执行
from src.llm_models.utils_model import LLMRequest, RequestType
from src.config.config import model_config
# 创建LLM请求对象
llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")
# 构建工具选项
tool_built = llm_request._build_tool_options(tool_definitions)
# 直接调用 _execute_request 以获取完整的响应对象(包含 usage
response, model_info = await llm_request._execute_request(
request_type=RequestType.RESPONSE,
message_factory=lambda _client, *, _messages=messages: _messages,
temperature=None,
max_tokens=None,
tool_options=tool_built,
)
# 从响应中获取token使用情况
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
if response and hasattr(response, 'usage') and response.usage:
prompt_tokens = response.usage.prompt_tokens or 0
completion_tokens = response.usage.completion_tokens or 0
total_tokens = response.usage.total_tokens or 0
return {
"use_finish_search": use_finish_search,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"tool_count": len(tool_definitions),
"tool_names": [tool["name"] for tool in tool_definitions],
}
async def compare_prompt_tokens(
question: str,
chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
"""对比使用 finish_search 工具与否的输入 token 差异
只运行一次只计算输入 token 的差异确保除了工具定义外其他内容一致
Args:
question: 要查询的问题
chat_id: 聊天ID
Returns:
包含对比结果的字典
"""
print("\n" + "=" * 80)
print("finish_search 工具 输入 Token 消耗对比测试")
print("=" * 80)
print(f"\n[测试问题] {question}")
print(f"[聊天ID] {chat_id}")
print("\n注意: 只对比第一次LLM调用的输入token差异不运行完整迭代流程")
# 第一次测试:不使用 finish_search
print("\n" + "-" * 80)
print("[测试 1/2] 不使用 finish_search 工具")
print("-" * 80)
result_without = await get_prompt_tokens_for_tools(
question=question,
chat_id=f"{chat_id}_without",
use_finish_search=False,
)
print(f"\n[结果]")
print(f" 工具数量: {result_without['tool_count']}")
print(f" 工具列表: {', '.join(result_without['tool_names'])}")
print(f" 输入 Prompt Tokens: {result_without['prompt_tokens']:,}")
# 等待一下,确保数据库记录已写入
await asyncio.sleep(1)
# 第二次测试:使用 finish_search
print("\n" + "-" * 80)
print("[测试 2/2] 使用 finish_search 工具")
print("-" * 80)
result_with = await get_prompt_tokens_for_tools(
question=question,
chat_id=f"{chat_id}_with",
use_finish_search=True,
)
print(f"\n[结果]")
print(f" 工具数量: {result_with['tool_count']}")
print(f" 工具列表: {', '.join(result_with['tool_names'])}")
print(f" 输入 Prompt Tokens: {result_with['prompt_tokens']:,}")
# 对比结果
print("\n" + "=" * 80)
print("[对比结果]")
print("=" * 80)
prompt_token_diff = result_with['prompt_tokens'] - result_without['prompt_tokens']
prompt_token_diff_percent = (prompt_token_diff / result_without['prompt_tokens'] * 100) if result_without['prompt_tokens'] > 0 else 0
tool_count_diff = result_with['tool_count'] - result_without['tool_count']
print(f"\n[输入 Prompt Token 对比]")
print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")
print(f"\n[工具数量对比]")
print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
print(f" 差异: {tool_count_diff:+d} 个工具")
print(f"\n[工具列表对比]")
without_tools = set(result_without['tool_names'])
with_tools = set(result_with['tool_names'])
only_with = with_tools - without_tools
only_without = without_tools - with_tools
if only_with:
print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
if only_without:
print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
if not only_with and not only_without:
print(f" 工具列表相同(除了 finish_search")
# 显示其他token信息
print(f"\n[其他 Token 信息]")
print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")
print("\n" + "=" * 80)
return {
"question": question,
"without_finish_search": result_without,
"with_finish_search": result_with,
"comparison": {
"prompt_token_diff": prompt_token_diff,
"prompt_token_diff_percent": prompt_token_diff_percent,
"tool_count_diff": tool_count_diff,
},
}
def main() -> None:
parser = argparse.ArgumentParser(
description="对比使用 finish_search 工具与否的 token 消耗差异"
)
parser.add_argument(
"--chat-id",
default="compare_finish_search",
help="测试用的聊天ID默认: compare_finish_search",
)
parser.add_argument(
"--output",
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False)
# 交互式输入问题
print("\n" + "=" * 80)
print("finish_search 工具 Token 消耗对比测试工具")
print("=" * 80)
question = input("\n请输入要查询的问题: ").strip()
if not question:
print("错误: 问题不能为空")
return
# 连接数据库
try:
db.connect(reuse_if_open=True)
except Exception as e:
logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}")
return
# 运行对比测试
try:
result = asyncio.run(
compare_prompt_tokens(
question=question,
chat_id=args.chat_id,
)
)
# 如果指定了输出文件,保存结果
if args.output:
# 将thinking_steps转换为可序列化的格式
output_result = result.copy()
with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试")
except Exception as e:
logger.error(f"测试失败: {e}", exc_info=True)
print(f"\n[错误] 测试失败: {e}")
finally:
try:
db.close()
except Exception:
pass
if __name__ == "__main__":
main()

View File

@ -1,476 +0,0 @@
"""
表达方式评估脚本
功能
1. 随机读取指定数量的表达方式获取其situation和style
2. 先进行人工评估逐条手动评估
3. 然后使用LLM进行评估
4. 对比人工评估和LLM评估的正确率精确率召回率F1分数等指标以人工评估为标准
5. 不真正修改数据库只是做评估
"""
import asyncio
import random
import json
import sys
import os
from typing import List, Dict
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.logger import get_logger
logger = get_logger("expression_evaluator_comparison")
def get_random_expressions(count: int = 10) -> List[Expression]:
"""
随机读取指定数量的表达方式
Args:
count: 要读取的数量默认10条
Returns:
表达方式列表
"""
try:
# 查询所有表达方式
all_expressions = list(Expression.select())
if not all_expressions:
logger.warning("数据库中没有表达方式记录")
return []
# 如果总数少于请求数量,返回所有
if len(all_expressions) <= count:
logger.info(f"数据库中共有 {len(all_expressions)} 条表达方式,全部返回")
return all_expressions
# 随机选择指定数量
selected = random.sample(all_expressions, count)
logger.info(f"{len(all_expressions)} 条表达方式中随机选择了 {len(selected)}")
return selected
except Exception as e:
logger.error(f"随机读取表达方式失败: {e}")
import traceback
logger.error(traceback.format_exc())
return []
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
"""
人工评估单个表达方式
Args:
expression: 表达方式对象
index: 当前索引从1开始
total: 总数
Returns:
评估结果字典包含
- expression_id: 表达方式ID
- situation: 情境
- style: 风格
- suitable: 是否合适人工评估
- reason: 评估理由始终为None
"""
print("\n" + "=" * 60)
print(f"人工评估 [{index}/{total}]")
print("=" * 60)
print(f"Situation: {expression.situation}")
print(f"Style: {expression.style}")
print("\n请评估该表达方式是否合适:")
print(" 输入 'y''yes''1' 表示合适(通过)")
print(" 输入 'n''no''0' 表示不合适(不通过)")
print(" 输入 'q''quit' 退出评估")
while True:
user_input = input("\n您的评估 (y/n/q): ").strip().lower()
if user_input in ['q', 'quit']:
print("退出评估")
return None
if user_input in ['y', 'yes', '1', '', '通过']:
suitable = True
break
elif user_input in ['n', 'no', '0', '', '不通过']:
suitable = False
break
else:
print("输入无效,请重新输入 (y/n/q)")
result = {
"expression_id": expression.id,
"situation": expression.situation,
"style": expression.style,
"suitable": suitable,
"reason": None,
"evaluator": "manual"
}
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
return result
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
prompt = f"""请评估以下表达方式是否合适:
情境situation{situation}
风格style{style}
请从以下方面进行评估
1. 情境描述是否清晰准确
2. 风格表达是否合理自然
3. 情境和风格是否匹配
4. 允许部分语法错误出现
5. 允许口头化或缺省表达
6. 允许部分上下文缺失
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
return prompt
async def _single_llm_evaluation(expression: Expression, llm: LLMRequest) -> tuple[bool, str, str | None]:
"""
执行单次LLM评估
Args:
expression: 表达方式对象
llm: LLM请求实例
Returns:
(suitable, reason, error) 元组如果出错则 suitable Falseerror 包含错误信息
"""
try:
prompt = create_evaluation_prompt(expression.situation, expression.style)
logger.debug(f"正在评估表达方式 ID: {expression.id}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
# 解析JSON响应
try:
evaluation = json.loads(response)
except json.JSONDecodeError:
import re
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match:
evaluation = json.loads(json_match.group())
else:
raise ValueError("无法从响应中提取JSON格式的评估结果")
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
except Exception as e:
logger.error(f"评估表达方式 ID: {expression.id} 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e)
async def evaluate_expression_llm(expression: Expression, llm: LLMRequest) -> Dict:
"""
使用LLM评估单个表达方式
Args:
expression: 表达方式对象
llm: LLM请求实例
Returns:
评估结果字典
"""
logger.info(f"开始评估表达方式 ID: {expression.id}")
suitable, reason, error = await _single_llm_evaluation(expression, llm)
if error:
suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return {
"expression_id": expression.id,
"situation": expression.situation,
"style": expression.style,
"suitable": suitable,
"reason": reason,
"error": error,
"evaluator": "llm"
}
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
"""
对比人工评估和LLM评估的结果
Args:
manual_results: 人工评估结果列表
llm_results: LLM评估结果列表
method_name: 评估方法名称用于标识
Returns:
对比分析结果字典
"""
# 按expression_id建立映射
llm_dict = {r["expression_id"]: r for r in llm_results}
total = len(manual_results)
matched = 0
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0
for manual_result in manual_results:
llm_result = llm_dict.get(manual_result["expression_id"])
if llm_result is None:
continue
manual_suitable = manual_result["suitable"]
llm_suitable = llm_result["suitable"]
if manual_suitable == llm_suitable:
matched += 1
if manual_suitable and llm_suitable:
true_positives += 1
elif not manual_suitable and not llm_suitable:
true_negatives += 1
elif not manual_suitable and llm_suitable:
false_positives += 1
elif manual_suitable and not llm_suitable:
false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
random_baseline = 50.0
accuracy_above_random = accuracy - random_baseline
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
return {
"method": method_name,
"total": total,
"matched": matched,
"accuracy": accuracy,
"accuracy_above_random": accuracy_above_random,
"accuracy_improvement_ratio": accuracy_improvement_ratio,
"true_positives": true_positives,
"true_negatives": true_negatives,
"false_positives": false_positives,
"false_negatives": false_negatives,
"precision": precision,
"recall": recall,
"f1_score": f1_score,
"specificity": specificity
}
async def main():
"""主函数"""
logger.info("=" * 60)
logger.info("开始表达方式评估")
logger.info("=" * 60)
# 初始化数据库连接
try:
db.connect(reuse_if_open=True)
logger.info("数据库连接成功")
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
# 1. 随机读取表达方式
logger.info("\n步骤1: 随机读取表达方式")
expressions = get_random_expressions(10)
if not expressions:
logger.error("没有可用的表达方式,退出")
return
logger.info(f"成功读取 {len(expressions)} 条表达方式")
# 2. 人工评估
print("\n" + "=" * 60)
print("开始人工评估")
print("=" * 60)
print(f"共需要评估 {len(expressions)} 条表达方式")
print("请逐条进行评估...\n")
manual_results = []
for i, expression in enumerate(expressions, 1):
manual_result = manual_evaluate_expression(expression, i, len(expressions))
if manual_result is None:
print("\n评估已中断")
return
manual_results.append(manual_result)
print("\n" + "=" * 60)
print("人工评估完成")
print("=" * 60)
# 3. 创建LLM实例并评估
logger.info("\n步骤3: 创建LLM实例")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_comparison"
)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
return
logger.info("\n步骤4: 开始LLM评估")
llm_results = []
for i, expression in enumerate(expressions, 1):
logger.info(f"LLM评估进度: {i}/{len(expressions)}")
llm_results.append(await evaluate_expression_llm(expression, llm))
await asyncio.sleep(0.3)
# 4. 对比分析并输出结果
comparison = compare_evaluations(manual_results, llm_results, "LLM评估")
print("\n" + "=" * 60)
print("评估结果(以人工评估为标准)")
print("=" * 60)
print("\n评估目标:")
print(" 1. 核心能力:将不合适的项目正确提取出来(特定负类召回率)")
print(" 2. 次要能力:尽可能少的误删合适的项目(召回率)")
# 详细评估结果(核心指标优先)
print("\n【详细对比】")
print(f"\n--- {comparison['method']} ---")
print(f" 总数: {comparison['total']}")
print()
print(" 【核心能力指标】")
print(f" ⭐ 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}")
print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print()
print(f" ⭐ 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}")
print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print()
print(" 【其他指标】")
print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
print()
print(" 【分类统计】")
print(f" TP (正确识别为合适): {comparison['true_positives']}")
print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
# 5. 输出人工评估不通过但LLM误判为通过的详细信息
print("\n" + "=" * 60)
print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
print("=" * 60)
# 按expression_id建立映射
llm_dict = {r["expression_id"]: r for r in llm_results}
fp_items = []
for manual_result in manual_results:
llm_result = llm_dict.get(manual_result["expression_id"])
if llm_result is None:
continue
# 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({
"expression_id": manual_result["expression_id"],
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
for idx, item in enumerate(fp_items, 1):
print(f"--- [{idx}] 项目 ID: {item['expression_id']} ---")
print(f"Situation: {item['situation']}")
print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
else:
print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
# 6. 保存结果到JSON文件
output_file = os.path.join(project_root, "data", "expression_evaluation_comparison.json")
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump({
"manual_results": manual_results,
"llm_results": llm_results,
"comparison": comparison
}, f, ensure_ascii=False, indent=2)
logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e:
logger.warning(f"保存结果到文件失败: {e}")
print("\n" + "=" * 60)
print("评估完成")
print("=" * 60)
# 关闭数据库连接
try:
db.close()
logger.info("数据库连接已关闭")
except Exception as e:
logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,7 +1,6 @@
import time import time
import json import json
import asyncio import asyncio
import re
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@ -108,7 +107,7 @@ def init_memory_retrieval_prompt():
- 你可以对查询思路给出简短的思考思考要简短直接切入要点 - 你可以对查询思路给出简短的思考思考要简短直接切入要点
- 先思考当前信息是否足够回答问题 - 先思考当前信息是否足够回答问题
- 如果信息不足则需要使用tool查询信息你必须给出使用什么工具进行查询 - 如果信息不足则需要使用tool查询信息你必须给出使用什么工具进行查询
- 如果当前已收集的信息足够或信息不足确定无法找到答案你必须调用finish_search工具结束查询 - 如果当前已收集的信息足够或信息不足确定无法找到答案你必须调用found_answer工具结束查询
""", """,
name="memory_retrieval_react_prompt_head", name="memory_retrieval_react_prompt_head",
) )
@ -312,7 +311,7 @@ async def _react_agent_solve_question(
return None return None
# 正常迭代使用head_prompt决定调用哪些工具包含finish_search工具) # 正常迭代使用head_prompt决定调用哪些工具包含found_answer工具)
tool_definitions = tool_registry.get_tool_definitions() tool_definitions = tool_registry.get_tool_definitions()
# tool_names = [tool_def["name"] for tool_def in tool_definitions] # tool_names = [tool_def["name"] for tool_def in tool_definitions]
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)") # logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
@ -373,7 +372,7 @@ async def _react_agent_solve_question(
logger.error(f"ReAct Agent LLM调用失败: {response}") logger.error(f"ReAct Agent LLM调用失败: {response}")
break break
# 注意这里会检查finish_search工具调用如果检测到finish_search工具会根据found_answer参数决定返回答案或退出查询 # 注意这里会检查found_answer工具调用如果检测到found_answer工具会根据answer参数决定返回答案或退出查询
assistant_message: Optional[Message] = None assistant_message: Optional[Message] = None
if tool_calls: if tool_calls:
@ -403,115 +402,146 @@ async def _react_agent_solve_question(
# 处理工具调用 # 处理工具调用
if not tool_calls: if not tool_calls:
# 如果没有工具调用检查响应文本中是否包含finish_search函数调用格式 # 如果没有工具调用检查响应文本中是否包含found_answer函数调用格式或JSON格式
if response and response.strip(): if response and response.strip():
# 尝试从文本中解析finish_search函数调用 # 首先尝试解析JSON格式的found_answer
def parse_finish_search_from_text(text: str): def parse_json_found_answer(text: str):
"""从文本中解析finish_search函数调用,返回(found_answer, answer)元组,如果未找到则返回(None, None)""" """从文本中解析JSON格式的found_answer,返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
if not text: if not text:
return None, None return None, None
try:
# 尝试提取JSON对象可能包含在代码块中或直接是JSON
json_text = text.strip()
# 如果包含代码块标记提取JSON部分
if "```json" in json_text:
start = json_text.find("```json") + 7
end = json_text.find("```", start)
if end != -1:
json_text = json_text[start:end].strip()
elif "```" in json_text:
start = json_text.find("```") + 3
end = json_text.find("```", start)
if end != -1:
json_text = json_text[start:end].strip()
# 尝试解析JSON
data = json.loads(json_text)
# 检查是否包含found_answer字段
if isinstance(data, dict) and "found_answer" in data:
found_answer = bool(data.get("found_answer", False))
answer = data.get("answer", "")
return found_answer, answer
except (json.JSONDecodeError, ValueError, TypeError):
# 如果JSON解析失败尝试在文本中查找JSON对象
try:
# 查找第一个 { 和最后一个 } 之间的内容更健壮的JSON提取
first_brace = text.find('{')
if first_brace != -1:
# 从第一个 { 开始,找到匹配的 }
brace_count = 0
json_end = -1
for i in range(first_brace, len(text)):
if text[i] == '{':
brace_count += 1
elif text[i] == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
if json_end != -1:
json_text = text[first_brace:json_end]
data = json.loads(json_text)
if isinstance(data, dict) and "found_answer" in data:
found_answer = bool(data.get("found_answer", False))
answer = data.get("answer", "")
return found_answer, answer
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None, None
# 尝试从文本中解析found_answer函数调用
def parse_found_answer_from_text(text: str):
"""从文本中解析found_answer函数调用返回answer字符串如果未找到则返回None
如果answer存在且非空表示找到答案如果answer为空或不存在表示未找到答案"""
if not text:
return None
# 查找finish_search函数调用位置不区分大小写 # 查找found_answer函数调用位置(不区分大小写)
func_pattern = "finish_search" func_pattern = "found_answer"
text_lower = text.lower() text_lower = text.lower()
func_pos = text_lower.find(func_pattern) func_pos = text_lower.find(func_pattern)
if func_pos == -1: if func_pos == -1:
return None, None return None
# 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos)
if start_pos == -1:
return None, None
# 查找匹配的右括号(考虑嵌套)
paren_count = 0
end_pos = start_pos
for i in range(start_pos, len(text)):
if text[i] == "(":
paren_count += 1
elif text[i] == ")":
paren_count -= 1
if paren_count == 0:
end_pos = i
break
else:
# 没有找到匹配的右括号
return None, None
# 提取函数参数部分
params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None
found_answer_patterns = [
r"found_answer\s*=\s*true",
r"found_answer\s*=\s*True",
r"found_answer\s*=\s*false",
r"found_answer\s*=\s*False",
]
for pattern in found_answer_patterns:
match = re.search(pattern, params_text, re.IGNORECASE)
if match:
found_answer = "true" in match.group(0).lower()
break
# 解析answer参数字符串使用extract_quoted_content # 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer") answer = extract_quoted_content(text, "found_answer", "answer")
# 如果answer存在即使是空字符串也返回它空字符串表示未找到答案
return answer
return found_answer, answer # 首先尝试解析JSON格式
parsed_found_answer_json, parsed_answer_json = parse_json_found_answer(response)
is_json_format = parsed_found_answer_json is not None
# 如果JSON解析成功使用JSON结果
if is_json_format:
parsed_answer = parsed_answer_json
has_answer = parsed_found_answer_json and parsed_answer and parsed_answer.strip()
else:
# 如果JSON解析失败尝试解析函数调用格式
parsed_answer = parse_found_answer_from_text(response)
# 如果answer存在且非空表示找到答案否则表示未找到答案
has_answer = parsed_answer is not None and parsed_answer.strip() != ""
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response) if parsed_answer is not None or is_json_format:
# 检测到found_answer格式可能是JSON格式或函数调用格式
if parsed_found_answer is not None: format_type = "JSON格式" if is_json_format else "函数调用格式"
# 检测到finish_search函数调用格式 if has_answer:
if parsed_found_answer:
# 找到了答案 # 找到了答案
if parsed_answer:
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append( step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}} {
"action_type": "found_answer",
"action_params": {"answer": parsed_answer},
}
) )
step["observations"] = ["检测到finish_search文本格式调用未找到答案"] step["observations"] = [f"检测到found_answer{format_type}调用,找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info( logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案" f"{react_log_prefix}{iteration + 1} 次迭代 通过found_answer{format_type}找到关于问题{question}的答案: {parsed_answer[:100]}..."
) )
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案", final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# 未找到答案,直接退出查询
step["actions"].append(
{"action_type": "found_answer", "action_params": {"answer": ""}}
)
step["observations"] = [f"检测到found_answer{format_type}调用,未找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过found_answer{format_type}判断未找到答案"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过found_answer文本格式判断未找到答案",
) )
return False, "", thinking_steps, False return False, "", thinking_steps, False
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代 # 如果没有检测到found_answer格式,记录思考过程,继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"] step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info( logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 思考完成但未调用工具: {response}" f"{react_log_prefix}{iteration + 1} 次迭代 思考完成但未调用工具: {response}"
@ -525,62 +555,55 @@ async def _react_agent_solve_question(
continue continue
# 处理工具调用 # 处理工具调用
# 首先检查是否有finish_search工具调用如果有则立即返回不再处理其他工具 # 首先检查是否有found_answer工具调用如果有则立即返回不再处理其他工具
finish_search_found = None found_answer_answer = None
finish_search_answer = None
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = tool_call.func_name tool_name = tool_call.func_name
tool_args = tool_call.args or {} tool_args = tool_call.args or {}
if tool_name == "finish_search": if tool_name == "found_answer":
finish_search_found = tool_args.get("found_answer", False) found_answer_answer = tool_args.get("answer", "")
finish_search_answer = tool_args.get("answer", "")
if finish_search_found: # 如果answer存在且非空表示找到答案否则表示未找到答案
if found_answer_answer and found_answer_answer.strip():
# 找到了答案 # 找到了答案
if finish_search_answer: step["actions"].append(
step["actions"].append( {
{ "action_type": "found_answer",
"action_type": "finish_search", "action_params": {"answer": found_answer_answer},
"action_params": {"found_answer": True, "answer": finish_search_answer}, }
} )
) step["observations"] = ["检测到found_answer工具调用找到答案"]
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}",
)
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info( logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具判断未找到答案" f"{react_log_prefix}{iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_answer}"
) )
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search工具判断未找到答案", final_status=f"找到答案:{found_answer_answer}",
)
return True, found_answer_answer, thinking_steps, False
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": ""}})
step["observations"] = ["检测到found_answer工具调用未找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过found_answer工具判断未找到答案"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过found_answer工具判断未找到答案",
) )
return False, "", thinking_steps, False return False, "", thinking_steps, False
# 如果没有finish_search工具调用继续处理其他工具 # 如果没有found_answer工具调用,继续处理其他工具
tool_tasks = [] tool_tasks = []
for i, tool_call in enumerate(tool_calls): for i, tool_call in enumerate(tool_calls):
tool_name = tool_call.func_name tool_name = tool_call.func_name
@ -590,8 +613,8 @@ async def _react_agent_solve_question(
f"{react_log_prefix}{iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})" f"{react_log_prefix}{iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
) )
# 跳过finish_search工具调用(已经在上面处理过了) # 跳过found_answer工具调用(已经在上面处理过了)
if tool_name == "finish_search": if tool_name == "found_answer":
continue continue
# 记录最后一次使用的工具名称(用于判断是否需要额外迭代) # 记录最后一次使用的工具名称(用于判断是否需要额外迭代)

View File

@ -15,7 +15,7 @@ from .query_chat_history import register_tool as register_query_chat_history
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info from .query_person_info import register_tool as register_query_person_info
from .query_words import register_tool as register_query_words from .query_words import register_tool as register_query_words
from .found_answer import register_tool as register_finish_search from .found_answer import register_tool as register_found_answer
from src.config.config import global_config from src.config.config import global_config
@ -24,7 +24,7 @@ def init_all_tools():
register_query_chat_history() register_query_chat_history()
register_query_person_info() register_query_person_info()
register_query_words() # 注册query_words工具 register_query_words() # 注册query_words工具
register_finish_search() # 注册finish_search工具 register_found_answer() # 注册found_answer工具
if global_config.lpmm_knowledge.lpmm_mode == "agent": if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge() register_lpmm_knowledge()

View File

@ -1,5 +1,5 @@
""" """
finish_search工具 - 用于在记忆检索过程中结束查询 found_answer工具 - 用于在记忆检索过程中结束查询
""" """
from src.common.logger import get_logger from src.common.logger import get_logger
@ -8,17 +8,16 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def finish_search(found_answer: bool, answer: str = "") -> str: async def found_answer(answer: str = "") -> str:
"""结束查询 """结束查询
Args: Args:
found_answer: 是否找到了答案 answer: 如果找到了答案提供答案内容如果未找到答案可以为空或不提供此参数
answer: 如果找到了答案提供答案内容如果未找到可以为空
Returns: Returns:
str: 确认信息 str: 确认信息
""" """
if found_answer: if answer and answer.strip():
logger.info(f"找到答案: {answer}") logger.info(f"找到答案: {answer}")
return f"已确认找到答案: {answer}" return f"已确认找到答案: {answer}"
else: else:
@ -27,23 +26,17 @@ async def finish_search(found_answer: bool, answer: str = "") -> str:
def register_tool(): def register_tool():
"""注册finish_search工具""" """注册found_answer工具"""
register_memory_retrieval_tool( register_memory_retrieval_tool(
name="finish_search", name="found_answer",
description="当你决定结束查询时,调用此工具。如果找到了明确答案,设置found_answer为true并在answer中提供答案如果未找到答案设置found_answer为false。只有在检索到明确、具体的答案时才设置found_answer为true,不要编造信息。", description="当你决定结束查询时,调用此工具。如果找到了明确答案,在answer参数中提供答案内容如果未找到答案可以不提供answer参数或提供空字符串。只有在检索到明确、具体的答案时才提供answer,不要编造信息。",
parameters=[ parameters=[
{
"name": "found_answer",
"type": "boolean",
"description": "是否找到了答案",
"required": True,
},
{ {
"name": "answer", "name": "answer",
"type": "string", "type": "string",
"description": "如果found_answer为true提供找到的答案内容必须基于已收集的信息不要编造如果found_answer为false可以为空", "description": "如果找到了答案,提供找到的答案内容,必须基于已收集的信息,不要编造;如果未找到答案,可以不提供此参数或提供空字符串",
"required": False, "required": False,
}, },
], ],
execute_func=finish_search, execute_func=found_answer,
) )