def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Summarise memory-related LLM token usage recorded since ``start_time``.

    Args:
        start_time: Unix timestamp marking the start of the window.

    Returns:
        A dict with the keys ``total_prompt_tokens``, ``total_completion_tokens``,
        ``total_tokens``, ``total_cost``, ``request_count`` and ``model_usage``
        (a per-model breakdown with the same counters). If anything fails while
        querying or aggregating, the error is logged and an all-zero summary is
        returned instead.
    """
    try:
        since = datetime.fromtimestamp(start_time)

        # Select every usage row whose request type is memory related.
        request_type = LLMUsage.request_type
        is_memory_request = (
            (request_type.like("%memory%"))
            | (request_type == "memory.question")
            | (request_type == "memory.react")
            | (request_type == "memory.react.final")
        )
        usage_rows = (
            LLMUsage.select()
            .where((LLMUsage.timestamp >= since) & is_memory_request)
            .order_by(LLMUsage.timestamp.asc())
        )

        summary: Dict[str, Any] = {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }

        for row in usage_rows:
            # Missing (None) counters are treated as zero.
            prompt = row.prompt_tokens or 0
            completion = row.completion_tokens or 0
            combined = row.total_tokens or 0
            cost = row.cost or 0.0

            summary["total_prompt_tokens"] += prompt
            summary["total_completion_tokens"] += completion
            summary["total_tokens"] += combined
            summary["total_cost"] += cost
            summary["request_count"] += 1

            # Per-model breakdown; rows without a model name go to "unknown".
            per_model = summary["model_usage"].setdefault(
                row.model_name or "unknown",
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                },
            )
            per_model["prompt_tokens"] += prompt
            per_model["completion_tokens"] += completion
            per_model["total_tokens"] += combined
            per_model["cost"] += cost
            per_model["request_count"] += 1

        return summary
    except Exception as e:
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }
def _init_tools_without_finish_search():
    """Reset the retrieval tool registry and register every tool except finish_search."""
    from src.memory_system.retrieval_tools import (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    )
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.config.config import global_config

    # Start from an empty registry so previously registered tools do not leak in.
    get_tool_registry().tools.clear()

    # Register the query tools; finish_search is intentionally left out.
    for register in (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    ):
        register()

    # The LPMM knowledge tool is only registered when LPMM runs in agent mode.
    lpmm_agent_enabled = global_config.lpmm_knowledge.lpmm_mode == "agent"
    if lpmm_agent_enabled:
        from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge
        register_lpmm_knowledge()

    logger.info("已初始化工具(不包含 finish_search)")
import global_config + bot_name = global_config.bot.nickname + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + # 构建head_prompt + head_prompt = await global_prompt_manager.format_prompt( + "memory_retrieval_react_prompt_head", + bot_name=bot_name, + time_now=time_now, + question=question, + collected_info="", + current_iteration=1, + remaining_iterations=global_config.memory.max_agent_iterations - 1, + max_iterations=global_config.memory.max_agent_iterations, + ) + + # 构建消息列表(只包含system message,模拟第一次调用) + from src.llm_models.payload_content.message import MessageBuilder, RoleType + messages = [] + system_builder = MessageBuilder() + system_builder.set_role(RoleType.System) + system_builder.add_text_content(head_prompt) + messages.append(system_builder.build()) + + # 调用LLM API来计算token(只调用一次,不实际执行) + from src.llm_models.utils_model import LLMRequest, RequestType + from src.config.config import model_config + + # 创建LLM请求对象 + llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare") + + # 构建工具选项 + tool_built = llm_request._build_tool_options(tool_definitions) + + # 直接调用 _execute_request 以获取完整的响应对象(包含 usage) + response, model_info = await llm_request._execute_request( + request_type=RequestType.RESPONSE, + message_factory=lambda _client, *, _messages=messages: _messages, + temperature=None, + max_tokens=None, + tool_options=tool_built, + ) + + # 从响应中获取token使用情况 + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + if response and hasattr(response, 'usage') and response.usage: + prompt_tokens = response.usage.prompt_tokens or 0 + completion_tokens = response.usage.completion_tokens or 0 + total_tokens = response.usage.total_tokens or 0 + + return { + "use_finish_search": use_finish_search, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "tool_count": len(tool_definitions), + "tool_names": [tool["name"] for tool in 
async def _run_prompt_token_case(
    title: str,
    question: str,
    chat_id: str,
    use_finish_search: bool,
) -> Dict[str, Any]:
    """Run one tool configuration and print its tool list and prompt-token count."""
    print("\n" + "-" * 80)
    print(title)
    print("-" * 80)
    outcome = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=chat_id,
        use_finish_search=use_finish_search,
    )

    print("\n[结果]")
    print(f" 工具数量: {outcome['tool_count']}")
    print(f" 工具列表: {', '.join(outcome['tool_names'])}")
    print(f" 输入 Prompt Tokens: {outcome['prompt_tokens']:,}")
    return outcome


async def compare_prompt_tokens(
    question: str,
    chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
    """Compare input-token usage with and without the finish_search tool.

    Each configuration is run once through ``get_prompt_tokens_for_tools`` and
    only the first LLM call is measured; the full iteration loop is not run.
    A human-readable report is printed to stdout.

    Args:
        question: The question to query.
        chat_id: Base chat id; ``_without`` / ``_with`` is appended per run.

    Returns:
        A dict with ``question``, the two per-run results
        (``without_finish_search`` / ``with_finish_search``) and a
        ``comparison`` dict holding ``prompt_token_diff``,
        ``prompt_token_diff_percent`` and ``tool_count_diff``.
    """
    print("\n" + "=" * 80)
    print("finish_search 工具 输入 Token 消耗对比测试")
    print("=" * 80)
    print(f"\n[测试问题] {question}")
    print(f"[聊天ID] {chat_id}")
    print("\n注意: 只对比第一次LLM调用的输入token差异,不运行完整迭代流程")

    # Run 1: finish_search not registered.
    result_without = await _run_prompt_token_case(
        "[测试 1/2] 不使用 finish_search 工具",
        question,
        f"{chat_id}_without",
        False,
    )

    # Short pause between the two runs (kept from the original, which used it
    # to let the usage record be written before the next request).
    await asyncio.sleep(1)

    # Run 2: finish_search registered.
    result_with = await _run_prompt_token_case(
        "[测试 2/2] 使用 finish_search 工具",
        question,
        f"{chat_id}_with",
        True,
    )

    print("\n" + "=" * 80)
    print("[对比结果]")
    print("=" * 80)

    baseline_tokens = result_without['prompt_tokens']
    prompt_token_diff = result_with['prompt_tokens'] - baseline_tokens
    # Guard against a zero baseline (e.g. the response carried no usage data).
    prompt_token_diff_percent = (prompt_token_diff / baseline_tokens * 100) if baseline_tokens > 0 else 0

    tool_count_diff = result_with['tool_count'] - result_without['tool_count']

    print("\n[输入 Prompt Token 对比]")
    print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
    print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
    print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")

    print("\n[工具数量对比]")
    print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
    print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
    print(f" 差异: {tool_count_diff:+d} 个工具")

    print("\n[工具列表对比]")
    without_tools = set(result_without['tool_names'])
    with_tools = set(result_with['tool_names'])
    # Sets iterate in arbitrary order; sort so the report is stable across runs.
    only_with = sorted(with_tools - without_tools)
    only_without = sorted(without_tools - with_tools)

    if only_with:
        print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
    if only_without:
        print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
    if not only_with and not only_without:
        print(" 工具列表相同(除了 finish_search)")

    # Remaining token figures reported by the two responses.
    print("\n[其他 Token 信息]")
    print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
    print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
    print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
    print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")

    print("\n" + "=" * 80)

    return {
        "question": question,
        "without_finish_search": result_without,
        "with_finish_search": result_with,
        "comparison": {
            "prompt_token_diff": prompt_token_diff,
            "prompt_token_diff_percent": prompt_token_diff_percent,
            "tool_count_diff": tool_count_diff,
        },
    }
# Import via importlib at call time to avoid circular-import problems.
def _import_memory_retrieval():
    """Dynamically import ``src.memory_system.memory_retrieval`` and return its entry points.

    Returns:
        A 3-tuple ``(init_memory_retrieval_prompt, _react_agent_solve_question,
        _process_single_question)`` taken from the memory_retrieval module.

    Raises:
        ImportError, AttributeError: Logged with traceback and re-raised when the
            module cannot be imported or lacks one of the expected attributes.
    """
    try:
        # Import prompt_builder first so we can check whether the prompt is registered.
        from src.chat.utils.prompt_builder import global_prompt_manager

        # If the memory_retrieval prompt is already registered, the module has
        # probably been initialised through some other import path.
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts

        module_name = "src.memory_system.memory_retrieval"

        # Prompt registered and module loaded: reuse the already-loaded module.
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, 'init_memory_retrieval_prompt'):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                    existing_module._process_single_question,
                )

        # Module is in sys.modules but only partially initialised: drop it first.
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
                # Partially initialised module; remove it so it can be re-imported.
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Also drop every other src.memory_system.* submodule (the package
                # itself is kept). Keys are collected first because sys.modules
                # must not be mutated while it is being iterated.
                # NOTE(review): nesting of this cleanup under the `if not hasattr`
                # branch is reconstructed from a whitespace-collapsed diff — confirm.
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith('src.memory_system.') and key != 'src.memory_system':
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass

        # Before importing memory_retrieval, fully load the modules that might
        # themselves import it at module level, so they finish initialising first.
        try:
            import src.config.config
            import src.chat.utils.prompt_builder
            # These replyer modules may import memory_retrieval at module level;
            # importing them here is best-effort.
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # Now import memory_retrieval itself. If a circular import still occurs
        # here, some other module imports memory_retrieval at module level.
        memory_retrieval_module = importlib.import_module(module_name)

        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
def _clip_for_display(value: Any, max_chars: int) -> str:
    """Return ``str(value)`` cut to ``max_chars`` characters; "..." is appended only if text was dropped."""
    text = str(value)
    if len(text) > max_chars:
        return text[:max_chars] + "..."
    return text


def format_thinking_steps(thinking_steps: list, max_chars: int = 200) -> str:
    """Render the agent's thinking steps as a human-readable multi-line string.

    Args:
        thinking_steps: List of step dicts. Each may carry ``iteration``,
            ``thought``, ``actions`` (dicts with ``action_type`` and
            ``action_params``) and ``observations``.
        max_chars: Maximum number of characters shown for a thought or an
            observation; longer text is truncated and marked with "...".

    Returns:
        The formatted text, or a fixed placeholder string when there are no steps.
    """
    if not thinking_steps:
        return "无思考步骤"

    lines = []
    for step in thinking_steps:
        lines.append(f"\n--- 迭代 {step.get('iteration', '?')} ---")

        thought = step.get("thought", "")
        if thought:
            # Only mark the thought as truncated when it really was truncated,
            # matching how observations are handled below.
            lines.append(f"思考: {_clip_for_display(thought, max_chars)}")

        actions = step.get("actions", [])
        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                # default=str keeps this display helper from raising on values
                # that are not JSON-native (they are shown via str()).
                params_text = json.dumps(action_params, ensure_ascii=False, default=str)
                lines.append(f" - {action_type}: {params_text}")

        observations = step.get("observations", [])
        if observations:
            lines.append("观察:")
            for obs in observations:
                lines.append(f" - {_clip_for_display(obs, max_chars)}")

    return "\n".join(lines)
def main() -> None:
    """Command-line entry point for the memory retrieval test.

    Parses ``--chat-id``, ``--context`` and ``--output``, interactively asks for
    the question and an optional maximum iteration count, connects to the
    database, runs ``test_memory_retrieval`` and optionally writes the result
    dict to a JSON file. The database connection is closed on exit.
    """
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID(默认: test_memory_retrieval)",
    )
    parser.add_argument(
        "--context",
        default="",
        help="上下文信息(可选)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件(可选)",
    )

    args = parser.parse_args()

    # Initialise logging with low verbosity to keep the console output readable.
    initialize_logging(verbose=False)

    # Ask for the question interactively.
    print("\n" + "=" * 80)
    print("记忆检索测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    # Ask for the maximum iteration count; empty or invalid input falls back to
    # the configured default (signalled by None).
    max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
        else:
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0,将使用配置默认值")
                max_iterations = None

    # Connect to the database.
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    # Run the test.
    try:
        result = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )

        # Save the result when an output file was requested.
        if args.output:
            # thinking_steps comes straight from the agent and may hold values
            # that are not JSON-native; default=str stringifies those instead of
            # aborting half-way and leaving a truncated file behind.
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2, default=str)
            print(f"\n[结果已保存] {args.output}")

    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Best-effort close; a failure here must not mask the real outcome.
        try:
            db.close()
        except Exception as close_error:
            logger.debug(f"db.close() failed: {close_error}")
决定使用引用回复") else: # 如果配置为 false,使用原来的模式 new_message_count = message_api.count_new_messages( @@ -663,7 +667,7 @@ class HeartFChatting: unknown_words = cleaned_uw # 从 Planner 的 action_data 中提取 quote_message 参数 - qm = action_planner_info.action_data.get("quote_message") + qm = action_planner_info.action_data.get("quote") if qm is not None: # 支持多种格式:true/false, "true"/"false", 1/0 if isinstance(qm, bool): @@ -672,6 +676,8 @@ class HeartFChatting: quote_message = qm.lower() in ("true", "1", "yes") elif isinstance(qm, (int, float)): quote_message = bool(qm) + + logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}") success, llm_response = await generator_api.generate_reply( chat_stream=self.chat_stream, diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 92ff1d79..bdb33556 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -531,7 +531,7 @@ class ActionPlanner: '"question":"需要查询的问题"' ) if global_config.chat.llm_quote: - reply_action_example += ', "quote_message":"如果需要引用该message,设置为true"' + reply_action_example += ', "quote":"如果需要引用该message,设置为true"' reply_action_example += "}" else: reply_action_example = ( @@ -546,7 +546,7 @@ class ActionPlanner: '"question":"需要查询的问题"' ) if global_config.chat.llm_quote: - reply_action_example += ', "quote_message":"如果需要引用该message,设置为true"' + reply_action_example += ', "quote":"如果需要引用该message,设置为true"' reply_action_example += "}" planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 1abe44ef..fa467272 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -82,21 +82,51 @@ def _is_chat_id_in_blacklist(chat_id: str) -> bool: return chat_id in blacklist_chat_ids -async def search_chat_history(chat_id: str, 
keyword: Optional[str] = None, participant: Optional[str] = None) -> str: +async def search_chat_history( + chat_id: str, + keyword: Optional[str] = None, + participant: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, +) -> str: """根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords Args: chat_id: 聊天ID keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配) participant: 参与人昵称(可选) + start_time: 开始时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供start_time,查询该时间点之后的记录 + end_time: 结束时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录 Returns: str: 查询结果,包含记忆id、theme和keywords """ try: # 检查参数 - if not keyword and not participant: - return "未指定查询参数(需要提供keyword或participant之一)" + if not keyword and not participant and not start_time and not end_time: + return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)" + + # 解析时间参数 + start_timestamp = None + end_timestamp = None + + if start_time: + try: + from src.memory_system.memory_utils import parse_datetime_to_timestamp + start_timestamp = parse_datetime_to_timestamp(start_time) + except ValueError as e: + return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'" + + if end_time: + try: + from src.memory_system.memory_utils import parse_datetime_to_timestamp + end_timestamp = parse_datetime_to_timestamp(end_time) + except ValueError as e: + return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'" + + # 验证时间范围 + if start_timestamp and end_timestamp and start_timestamp > end_timestamp: + return "开始时间不能晚于结束时间" # 构建查询条件 # 检查当前chat_id是否在黑名单中 @@ -128,6 +158,40 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}" ) query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) + 
+ # 添加时间过滤条件 + if start_timestamp is not None and end_timestamp is not None: + # 查询指定时间段内的记录(记录的时间范围与查询时间段有交集) + # 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段 + query = query.where( + ( + (ChatHistory.start_time >= start_timestamp) + & (ChatHistory.start_time <= end_timestamp) + ) # 记录开始时间在查询时间段内 + | ( + (ChatHistory.end_time >= start_timestamp) + & (ChatHistory.end_time <= end_timestamp) + ) # 记录结束时间在查询时间段内 + | ( + (ChatHistory.start_time <= start_timestamp) + & (ChatHistory.end_time >= end_timestamp) + ) # 记录完全包含查询时间段 + ) + logger.debug( + f"search_chat_history 添加时间范围过滤: {start_timestamp} - {end_timestamp}, keyword={keyword}, participant={participant}" + ) + elif start_timestamp is not None: + # 只提供开始时间,查询该时间点之后的记录(记录的开始时间或结束时间在该时间点之后) + query = query.where(ChatHistory.end_time >= start_timestamp) + logger.debug( + f"search_chat_history 添加开始时间过滤: >= {start_timestamp}, keyword={keyword}, participant={participant}" + ) + elif end_timestamp is not None: + # 只提供结束时间,查询该时间点之前的记录(记录的开始时间或结束时间在该时间点之前) + query = query.where(ChatHistory.start_time <= end_timestamp) + logger.debug( + f"search_chat_history 添加结束时间过滤: <= {end_timestamp}, keyword={keyword}, participant={participant}" + ) # 执行查询 records = list(query.order_by(ChatHistory.start_time.desc()).limit(50)) @@ -217,21 +281,31 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti filtered_records.append(record) if not filtered_records: - if keyword and participant: - keywords_str = "、".join(parse_keywords_string(keyword) if keyword else []) - return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录" - elif keyword: + # 构建查询条件描述 + conditions = [] + if keyword: keywords_str = "、".join(parse_keywords_string(keyword)) - keywords_list = parse_keywords_string(keyword) - if len(keywords_list) > 2: - required_count = len(keywords_list) - 1 - return ( - f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" - ) - else: - return f"未找到包含所有关键词'{keywords_str}'的聊天记录" - elif 
participant: - return f"未找到参与人包含'{participant}'的聊天记录" + conditions.append(f"关键词'{keywords_str}'") + if participant: + conditions.append(f"参与人'{participant}'") + if start_timestamp or end_timestamp: + time_desc = "" + if start_timestamp and end_timestamp: + start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S") + end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S") + time_desc = f"时间范围'{start_str}' 至 '{end_str}'" + elif start_timestamp: + start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S") + time_desc = f"时间>='{start_str}'" + elif end_timestamp: + end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S") + time_desc = f"时间<='{end_str}'" + if time_desc: + conditions.append(time_desc) + + if conditions: + conditions_str = "且".join(conditions) + return f"未找到满足条件({conditions_str})的聊天记录" else: return "未找到相关聊天记录" @@ -419,7 +493,7 @@ def register_tool(): # 注册工具1:搜索记忆 register_memory_retrieval_tool( name="search_chat_history", - description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。", + description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。支持按时间点或时间段进行查询。", parameters=[ { "name": "keyword", @@ -433,6 +507,18 @@ def register_tool(): "description": "参与人昵称(可选),用于查询包含该参与人的记忆", "required": False, }, + { + "name": "start_time", + "type": "string", + "description": "开始时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供start_time,查询该时间点之后的记录。如果同时提供start_time和end_time,查询该时间段内的记录", + "required": False, + }, + { + "name": "end_time", + "type": "string", + "description": "结束时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录", + "required": False, + }, ], execute_func=search_chat_history, )