mirror of https://github.com/Mai-with-u/MaiBot.git
feat:优化记忆查询,添加时间信息
parent
ce9e17df25
commit
b296f0683f
|
|
@ -0,0 +1,507 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
# Best-effort switch of both standard streams to UTF-8 so that printing
# non-ASCII text does not fail on consoles with another default encoding.
# Any failure is deliberately ignored: the script can still run without it.
try:
    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8")
except Exception:
    pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
|
||||
logger = get_logger("compare_finish_search_token")
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Aggregate memory-related LLM token usage recorded since ``start_time``.

    Args:
        start_time: Unix timestamp in seconds. It is converted with
            ``datetime.fromtimestamp`` and compared against
            ``LLMUsage.timestamp``.

    Returns:
        A dict with the keys ``total_prompt_tokens``,
        ``total_completion_tokens``, ``total_tokens``, ``total_cost``,
        ``request_count`` and ``model_usage``. ``model_usage`` maps each model
        name (``"unknown"`` when the record has none) to a dict with
        ``prompt_tokens``, ``completion_tokens``, ``total_tokens``, ``cost``
        and ``request_count``. If anything fails, the error is logged and a
        result with all counters at zero and an empty ``model_usage`` is
        returned instead of raising.
    """

    def _zero_counters() -> Dict[str, Any]:
        # One counter set; used both for the grand total and per model.
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "request_count": 0,
        }

    def _summary(overall: Dict[str, Any], per_model: Dict[str, Any]) -> Dict[str, Any]:
        # Shape the counters into the public result layout.
        return {
            "total_prompt_tokens": overall["prompt_tokens"],
            "total_completion_tokens": overall["completion_tokens"],
            "total_tokens": overall["total_tokens"],
            "total_cost": overall["cost"],
            "request_count": overall["request_count"],
            "model_usage": per_model,
        }

    try:
        start_datetime = datetime.fromtimestamp(start_time)

        # LIKE "%memory%" already matches "memory.question", "memory.react"
        # and "memory.react.final", so no extra equality predicates are needed.
        records = (
            LLMUsage.select()
            .where((LLMUsage.timestamp >= start_datetime) & (LLMUsage.request_type.like("%memory%")))
            .order_by(LLMUsage.timestamp.asc())
        )

        overall = _zero_counters()
        model_usage: Dict[str, Dict[str, Any]] = {}

        for record in records:
            # Missing (None) values count as zero.
            prompt_tokens = record.prompt_tokens or 0
            completion_tokens = record.completion_tokens or 0
            total_tokens = record.total_tokens or 0
            cost = record.cost or 0.0

            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = _zero_counters()

            # Update the grand total and the per-model bucket in lockstep.
            for bucket in (overall, model_usage[model_name]):
                bucket["prompt_tokens"] += prompt_tokens
                bucket["completion_tokens"] += completion_tokens
                bucket["total_tokens"] += total_tokens
                bucket["cost"] += cost
                bucket["request_count"] += 1

        return _summary(overall, model_usage)
    except Exception as e:
        # Best-effort statistics: log with traceback and report zero usage.
        logger.error(f"获取token使用情况失败: {e}", exc_info=True)
        return _summary(_zero_counters(), {})
|
||||
|
||||
|
||||
def _import_memory_retrieval():
    """Load ``src.memory_system.memory_retrieval`` at call time via importlib.

    The module is imported lazily (instead of at the top of the file) to avoid
    circular imports.

    Returns:
        The tuple ``(init_memory_retrieval_prompt, _react_agent_solve_question)``
        taken from the loaded module.

    Raises:
        ImportError, AttributeError: logged with traceback, then re-raised.
    """
    module_name = "src.memory_system.memory_retrieval"
    try:
        # Import prompt_builder first so we can tell whether the memory
        # retrieval prompts were already registered through another path.
        from src.chat.utils.prompt_builder import global_prompt_manager

        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts

        if module_name in sys.modules:
            cached_module = sys.modules[module_name]
            if hasattr(cached_module, "init_memory_retrieval_prompt"):
                # Fully loaded module: reuse it only when its prompts are
                # registered too, otherwise fall through to the import below.
                if prompt_already_init:
                    return (
                        cached_module.init_memory_retrieval_prompt,
                        cached_module._react_agent_solve_question,
                    )
            else:
                # Partially initialised module: drop it, together with every
                # cached "src.memory_system.*" submodule, and import again.
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # The trailing dot in the prefix already excludes the
                # "src.memory_system" package itself.
                stale_names = [name for name in sys.modules if name.startswith("src.memory_system.")]
                for name in stale_names:
                    sys.modules.pop(name, None)

        # Pre-load modules that may themselves import memory_retrieval at
        # module level, so that they are fully initialised beforehand.
        try:
            for required in ("src.config.config", "src.chat.utils.prompt_builder"):
                importlib.import_module(required)
            for optional in (
                "src.chat.replyer.group_generator",
                "src.chat.replyer.private_generator",
            ):
                try:
                    importlib.import_module(optional)
                except (ImportError, AttributeError):
                    pass  # optional pre-load; carry on if it fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # If this still hits a circular import, some other module imports
        # memory_retrieval at module level.
        memory_retrieval_module = importlib.import_module(module_name)
        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def _init_tools_without_finish_search():
    """Rebuild the retrieval tool registry without registering ``finish_search``."""
    from src.memory_system.retrieval_tools import (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    )
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.config.config import global_config

    # Start from an empty registry so that only the tools below are present.
    get_tool_registry().tools.clear()

    # Register every query tool; finish_search is intentionally left out.
    for register in (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    ):
        register()

    # The LPMM knowledge tool is added only when LPMM runs in agent mode.
    if global_config.lpmm_knowledge.lpmm_mode == "agent":
        from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge

        register_lpmm_knowledge()

    logger.info("已初始化工具(不包含 finish_search)")
|
||||
|
||||
|
||||
def _init_tools_with_finish_search():
    """Rebuild the retrieval tool registry through ``init_all_tools``.

    According to the original author's notes, ``init_all_tools`` registers
    every tool, ``finish_search`` included.
    """
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.memory_system.retrieval_tools import init_all_tools

    # Empty the registry first so no tool from an earlier setup survives.
    get_tool_registry().tools.clear()

    init_all_tools()
    logger.info("已初始化工具(包含 finish_search)")
|
||||
|
||||
|
||||
async def get_prompt_tokens_for_tools(
    question: str,
    chat_id: str,
    use_finish_search: bool,
) -> Dict[str, Any]:
    """Measure the token usage of one first-iteration LLM call for a tool set.

    The function builds the ReAct head prompt for ``question``, sends it as a
    single system message together with the registered tool definitions, and
    reads the token counts from the response's ``usage`` attribute.

    Args:
        question: The question placed into the prompt.
        chat_id: Chat identifier. It is accepted for interface symmetry but is
            not read anywhere in this function.
        use_finish_search: Whether the ``finish_search`` tool is registered.

    Returns:
        A dict with ``use_finish_search``, ``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``, ``tool_count`` and
        ``tool_names``. Token counts are 0 when the response has no usage data.
    """
    # Make sure the memory retrieval prompts exist. Per the original note,
    # init_memory_retrieval_prompt also calls init_all_tools, which is why
    # the tool set is (re)built only afterwards.
    from src.chat.utils.prompt_builder import global_prompt_manager

    if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
        init_prompt, _unused = _import_memory_retrieval()
        init_prompt()

    setup_tools = _init_tools_with_finish_search if use_finish_search else _init_tools_without_finish_search
    setup_tools()

    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry

    registry = get_tool_registry()
    tool_definitions = registry.get_tool_definitions()
    tool_names = [tool["name"] for tool in tool_definitions]

    # Sanity-check the registry against the requested configuration.
    finish_search_present = "finish_search" in tool_names
    if use_finish_search and not finish_search_present:
        logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
    elif not use_finish_search and finish_search_present:
        logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
        # Drop the unwanted tool and refresh the derived lists.
        registry.tools.pop("finish_search", None)
        tool_definitions = registry.get_tool_definitions()
        tool_names = [tool["name"] for tool in tool_definitions]

    # Build the prompt of the first ReAct iteration.
    from src.config.config import global_config

    max_iterations = global_config.memory.max_agent_iterations
    head_prompt = await global_prompt_manager.format_prompt(
        "memory_retrieval_react_prompt_head",
        bot_name=global_config.bot.nickname,
        time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        question=question,
        collected_info="",
        current_iteration=1,
        remaining_iterations=max_iterations - 1,
        max_iterations=max_iterations,
    )

    # The first call carries a single system message only.
    from src.llm_models.payload_content.message import MessageBuilder, RoleType

    system_builder = MessageBuilder()
    system_builder.set_role(RoleType.System)
    system_builder.add_text_content(head_prompt)
    messages = [system_builder.build()]

    from src.llm_models.utils_model import LLMRequest, RequestType
    from src.config.config import model_config

    llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")
    tool_options = llm_request._build_tool_options(tool_definitions)

    # _execute_request is called directly to get the full response object,
    # which is where the usage figures are read from.
    response, _model_info = await llm_request._execute_request(
        request_type=RequestType.RESPONSE,
        message_factory=lambda _client, *, _messages=messages: _messages,
        temperature=None,
        max_tokens=None,
        tool_options=tool_options,
    )

    prompt_tokens = completion_tokens = total_tokens = 0
    usage = response.usage if response and hasattr(response, "usage") else None
    if usage:
        prompt_tokens = usage.prompt_tokens or 0
        completion_tokens = usage.completion_tokens or 0
        total_tokens = usage.total_tokens or 0

    return {
        "use_finish_search": use_finish_search,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "tool_count": len(tool_definitions),
        "tool_names": tool_names,
    }
|
||||
|
||||
|
||||
def _print_single_run(result: Dict[str, Any]) -> None:
    """Print tool count, tool names and prompt tokens of one measurement."""
    print("\n[结果]")
    print(f" 工具数量: {result['tool_count']}")
    print(f" 工具列表: {', '.join(result['tool_names'])}")
    print(f" 输入 Prompt Tokens: {result['prompt_tokens']:,}")


async def compare_prompt_tokens(
    question: str,
    chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
    """Compare input-token usage with and without the ``finish_search`` tool.

    Two single LLM calls are measured through ``get_prompt_tokens_for_tools``
    (first without, then with ``finish_search``) and a report is printed.

    Args:
        question: The question to build the prompt from.
        chat_id: Base chat id; ``"_without"`` / ``"_with"`` is appended for
            the two runs.

    Returns:
        A dict with ``question``, ``without_finish_search``,
        ``with_finish_search`` (the two raw measurements) and ``comparison``
        (``prompt_token_diff``, ``prompt_token_diff_percent``,
        ``tool_count_diff``).
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print("\n" + heavy_rule)
    print("finish_search 工具 输入 Token 消耗对比测试")
    print(heavy_rule)
    print(f"\n[测试问题] {question}")
    print(f"[聊天ID] {chat_id}")
    print("\n注意: 只对比第一次LLM调用的输入token差异,不运行完整迭代流程")

    # Run 1: finish_search is not registered.
    print("\n" + light_rule)
    print("[测试 1/2] 不使用 finish_search 工具")
    print(light_rule)
    result_without = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_without",
        use_finish_search=False,
    )
    _print_single_run(result_without)

    # Short pause between the two runs; the original note says it lets the
    # usage record of the first call reach the database.
    await asyncio.sleep(1)

    # Run 2: finish_search is registered.
    print("\n" + light_rule)
    print("[测试 2/2] 使用 finish_search 工具")
    print(light_rule)
    result_with = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_with",
        use_finish_search=True,
    )
    _print_single_run(result_with)

    print("\n" + heavy_rule)
    print("[对比结果]")
    print(heavy_rule)

    baseline_tokens = result_without["prompt_tokens"]
    prompt_token_diff = result_with["prompt_tokens"] - baseline_tokens
    # Avoid dividing by zero when the baseline reported no prompt tokens.
    prompt_token_diff_percent = (prompt_token_diff / baseline_tokens * 100) if baseline_tokens > 0 else 0
    tool_count_diff = result_with["tool_count"] - result_without["tool_count"]

    print("\n[输入 Prompt Token 对比]")
    print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
    print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
    print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")

    print("\n[工具数量对比]")
    print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
    print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
    print(f" 差异: {tool_count_diff:+d} 个工具")

    print("\n[工具列表对比]")
    without_tools = set(result_without["tool_names"])
    with_tools = set(result_with["tool_names"])
    # Sets have no stable iteration order, so sort the names before printing
    # to keep the report identical between runs.
    only_with = sorted(with_tools - without_tools)
    only_without = sorted(without_tools - with_tools)

    if only_with:
        print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
    if only_without:
        print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
    if not only_with and not only_without:
        print(" 工具列表相同(除了 finish_search)")

    print("\n[其他 Token 信息]")
    print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
    print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
    print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
    print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")

    print("\n" + heavy_rule)

    return {
        "question": question,
        "without_finish_search": result_without,
        "with_finish_search": result_with,
        "comparison": {
            "prompt_token_diff": prompt_token_diff,
            "prompt_token_diff_percent": prompt_token_diff_percent,
            "tool_count_diff": tool_count_diff,
        },
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point of the finish_search token comparison.

    Parses ``--chat-id`` and ``--output``/``-o``, asks for the question on
    stdin, runs ``compare_prompt_tokens`` and, when an output path was given,
    writes the result to that file as JSON. The database connection is closed
    again in every case.
    """
    parser = argparse.ArgumentParser(description="对比使用 finish_search 工具与否的 token 消耗差异")
    parser.add_argument(
        "--chat-id",
        default="compare_finish_search",
        help="测试用的聊天ID(默认: compare_finish_search)",
    )
    parser.add_argument("--output", "-o", help="将结果保存到JSON文件(可选)")
    args = parser.parse_args()

    # Non-verbose logging keeps the console report readable.
    initialize_logging(verbose=False)

    banner = "=" * 80
    print("\n" + banner)
    print("finish_search 工具 Token 消耗对比测试工具")
    print(banner)

    # The question is always read interactively.
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    try:
        comparison = asyncio.run(compare_prompt_tokens(question=question, chat_id=args.chat_id))

        if args.output:
            # Dump a shallow copy of the result dict as JSON.
            payload = dict(comparison)
            with open(args.output, "w", encoding="utf-8") as out_file:
                json.dump(payload, out_file, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Closing is best-effort; a failure here must not mask the outcome.
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()
|
||||
|
||||
|
|
@ -0,0 +1,447 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
# Best-effort switch of both standard streams to UTF-8 so that printing
# non-ASCII text does not fail on consoles with another default encoding.
# Any failure is deliberately ignored: the script can still run without it.
try:
    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8")
except Exception:
    pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
    """Load ``src.memory_system.memory_retrieval`` at call time via importlib.

    The module is imported lazily (instead of at the top of the file) to avoid
    circular imports.

    Returns:
        The tuple ``(init_memory_retrieval_prompt, _react_agent_solve_question,
        _process_single_question)`` taken from the loaded module.

    Raises:
        ImportError, AttributeError: logged with traceback, then re-raised.
    """
    module_name = "src.memory_system.memory_retrieval"
    try:
        # Import prompt_builder first so we can tell whether the memory
        # retrieval prompts were already registered through another path.
        from src.chat.utils.prompt_builder import global_prompt_manager

        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts

        if module_name in sys.modules:
            cached_module = sys.modules[module_name]
            if hasattr(cached_module, "init_memory_retrieval_prompt"):
                # Fully loaded module: reuse it only when its prompts are
                # registered too, otherwise fall through to the import below.
                if prompt_already_init:
                    return (
                        cached_module.init_memory_retrieval_prompt,
                        cached_module._react_agent_solve_question,
                        cached_module._process_single_question,
                    )
            else:
                # Partially initialised module: drop it, together with every
                # cached "src.memory_system.*" submodule, and import again.
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # The trailing dot in the prefix already excludes the
                # "src.memory_system" package itself.
                stale_names = [name for name in sys.modules if name.startswith("src.memory_system.")]
                for name in stale_names:
                    sys.modules.pop(name, None)

        # Pre-load modules that may themselves import memory_retrieval at
        # module level, so that they are fully initialised beforehand.
        try:
            for required in ("src.config.config", "src.chat.utils.prompt_builder"):
                importlib.import_module(required)
            for optional in (
                "src.chat.replyer.group_generator",
                "src.chat.replyer.private_generator",
            ):
                try:
                    importlib.import_module(optional)
                except (ImportError, AttributeError):
                    pass  # optional pre-load; carry on if it fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # If this still hits a circular import, some other module imports
        # memory_retrieval at module level.
        memory_retrieval_module = importlib.import_module(module_name)
        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
    """Build a ``ChatStream`` filled with fixed test user and group data.

    Args:
        chat_id: Value used as the stream's ``stream_id``.

    Returns:
        A ``ChatStream`` on platform ``"test"`` with a test user and a test
        group attached.
    """
    platform = "test"
    return ChatStream(
        stream_id=chat_id,
        platform=platform,
        user_info=UserInfo(
            platform=platform,
            user_id="test_user",
            user_nickname="测试用户",
        ),
        group_info=GroupInfo(
            platform=platform,
            group_id="test_group",
            group_name="测试群组",
        ),
    )
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Aggregate memory-related LLM token usage recorded since ``start_time``.

    Args:
        start_time: Unix timestamp in seconds. It is converted with
            ``datetime.fromtimestamp`` and compared against
            ``LLMUsage.timestamp``.

    Returns:
        A dict with the keys ``total_prompt_tokens``,
        ``total_completion_tokens``, ``total_tokens``, ``total_cost``,
        ``request_count`` and ``model_usage``. ``model_usage`` maps each model
        name (``"unknown"`` when the record has none) to a dict with
        ``prompt_tokens``, ``completion_tokens``, ``total_tokens``, ``cost``
        and ``request_count``. If anything fails, the error is logged and a
        result with all counters at zero and an empty ``model_usage`` is
        returned instead of raising.
    """

    def _zero_counters() -> Dict[str, Any]:
        # One counter set; used both for the grand total and per model.
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "request_count": 0,
        }

    def _summary(overall: Dict[str, Any], per_model: Dict[str, Any]) -> Dict[str, Any]:
        # Shape the counters into the public result layout.
        return {
            "total_prompt_tokens": overall["prompt_tokens"],
            "total_completion_tokens": overall["completion_tokens"],
            "total_tokens": overall["total_tokens"],
            "total_cost": overall["cost"],
            "request_count": overall["request_count"],
            "model_usage": per_model,
        }

    try:
        start_datetime = datetime.fromtimestamp(start_time)

        # LIKE "%memory%" already matches "memory.question", "memory.react"
        # and "memory.react.final", so no extra equality predicates are needed.
        records = (
            LLMUsage.select()
            .where((LLMUsage.timestamp >= start_datetime) & (LLMUsage.request_type.like("%memory%")))
            .order_by(LLMUsage.timestamp.asc())
        )

        overall = _zero_counters()
        model_usage: Dict[str, Dict[str, Any]] = {}

        for record in records:
            # Missing (None) values count as zero.
            prompt_tokens = record.prompt_tokens or 0
            completion_tokens = record.completion_tokens or 0
            total_tokens = record.total_tokens or 0
            cost = record.cost or 0.0

            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = _zero_counters()

            # Update the grand total and the per-model bucket in lockstep.
            for bucket in (overall, model_usage[model_name]):
                bucket["prompt_tokens"] += prompt_tokens
                bucket["completion_tokens"] += completion_tokens
                bucket["total_tokens"] += total_tokens
                bucket["cost"] += cost
                bucket["request_count"] += 1

        return _summary(overall, model_usage)
    except Exception as e:
        # Best-effort statistics: log with traceback and report zero usage.
        logger.error(f"获取token使用情况失败: {e}", exc_info=True)
        return _summary(_zero_counters(), {})
|
||||
|
||||
|
||||
def format_thinking_steps(thinking_steps: list) -> str:
    """Render agent thinking steps as a readable multi-line string.

    Each step is a dict that may contain ``iteration``, ``thought``,
    ``actions`` (dicts with ``action_type`` / ``action_params``) and
    ``observations``. Thoughts and observations longer than 200 characters
    are cut to 200 characters and marked with a trailing ``"..."``; shorter
    texts are shown unchanged.

    Args:
        thinking_steps: List of step dicts; may be empty or ``None``.

    Returns:
        The formatted text, or a fixed placeholder string when there are no
        steps.
    """
    if not thinking_steps:
        return "无思考步骤"

    max_chars = 200

    def _clip(text: str) -> str:
        # Add the ellipsis only when something was actually cut off.
        return text[:max_chars] + "..." if len(text) > max_chars else text

    lines = []
    for step in thinking_steps:
        iteration = step.get("iteration", "?")
        thought = step.get("thought", "")
        actions = step.get("actions", [])
        observations = step.get("observations", [])

        lines.append(f"\n--- 迭代 {iteration} ---")
        if thought:
            lines.append(f"思考: {_clip(str(thought))}")

        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")

        if observations:
            lines.append("观察:")
            lines.extend(f" - {_clip(str(obs))}" for obs in observations)

    return "\n".join(lines)
|
||||
|
||||
|
||||
async def test_memory_retrieval(
    question: str,
    chat_id: str = "test_memory_retrieval",
    context: str = "",
    max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
    """Run one memory retrieval for ``question`` and print a detailed report.

    Args:
        question: The question to look up.
        chat_id: Chat id passed to the retrieval agent.
        context: Context information. It is accepted but not read anywhere in
            this function.
        max_iterations: Iteration limit; ``None`` means use
            ``global_config.memory.max_agent_iterations``.

    Returns:
        A dict with ``question``, ``found_answer``, ``answer``, ``is_timeout``,
        ``elapsed_time``, ``thinking_steps``, ``iteration_count`` and
        ``token_usage``.

    Raises:
        Exception: any error while loading or initialising the memory
            retrieval module is logged and re-raised.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("[测试] 记忆检索测试")
    print(f"[问题] {question}")
    print(rule)

    # The clock starts before module initialisation, so that time is part of
    # the reported duration and of the token-usage query window.
    start_time = time.time()

    # Import lazily inside the function to avoid circular imports at module
    # level; the prompts are registered only if that has not happened yet.
    try:
        init_prompt, solve_question, _unused = _import_memory_retrieval()

        from src.chat.utils.prompt_builder import global_prompt_manager

        if "memory_retrieval_question_prompt" in global_prompt_manager._prompts:
            logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
        else:
            init_prompt()
    except Exception as e:
        logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
        raise

    from src.config.config import global_config

    if max_iterations is None:
        max_iterations = global_config.memory.max_agent_iterations
    timeout = global_config.memory.agent_timeout_seconds

    print("\n[配置]")
    print(f" 最大迭代次数: {max_iterations}")
    print(f" 超时时间: {timeout}秒")
    print(f" 聊天ID: {chat_id}")

    print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")

    # Call the agent function directly to get the per-iteration details.
    found_answer, answer, thinking_steps, is_timeout = await solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=timeout,
        initial_info="",
    )

    elapsed_time = time.time() - start_time
    token_usage = get_token_usage_since(start_time)

    result = {
        "question": question,
        "found_answer": found_answer,
        "answer": answer,
        "is_timeout": is_timeout,
        "elapsed_time": elapsed_time,
        "thinking_steps": thinking_steps,
        "iteration_count": len(thinking_steps),
        "token_usage": token_usage,
    }

    print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print("\n[结果]")
    print(f" 是否找到答案: {'是' if found_answer else '否'}")
    if found_answer and answer:
        print(f" 答案: {answer}")
    else:
        print(" 答案: (未找到答案)")
    print(f" 是否超时: {'是' if is_timeout else '否'}")
    print(f" 迭代次数: {len(thinking_steps)}")
    print(f" 总耗时: {elapsed_time:.2f}秒")

    print("\n[Token使用情况]")
    print(f" 总请求数: {token_usage['request_count']}")
    print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
    print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
    print(f" 总Tokens: {token_usage['total_tokens']:,}")
    print(f" 总成本: ${token_usage['total_cost']:.6f}")

    per_model = token_usage["model_usage"]
    if per_model:
        print("\n[按模型统计]")
        for model_name, usage in per_model.items():
            print(f" {model_name}:")
            print(f" 请求数: {usage['request_count']}")
            print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
            print(f" Completion Tokens: {usage['completion_tokens']:,}")
            print(f" 总Tokens: {usage['total_tokens']:,}")
            print(f" 成本: ${usage['cost']:.6f}")

    print("\n[迭代详情]")
    print(format_thinking_steps(thinking_steps))

    print("\n" + rule)

    return result
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point of the memory retrieval test tool.

    Parses ``--chat-id``, ``--context`` and ``--output``/``-o``, reads the
    question and an optional iteration limit from stdin, runs
    ``test_memory_retrieval`` and, when an output path was given, writes the
    result to that file as JSON. The database connection is closed again in
    every case.
    """
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID(默认: test_memory_retrieval)",
    )
    parser.add_argument("--context", default="", help="上下文信息(可选)")
    parser.add_argument("--output", "-o", help="将结果保存到JSON文件(可选)")
    args = parser.parse_args()

    # Non-verbose logging keeps the console report readable.
    initialize_logging(verbose=False)

    banner = "=" * 80
    print("\n" + banner)
    print("记忆检索测试工具")
    print(banner)

    # The question is always read interactively.
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    # Optional iteration limit; empty, invalid or non-positive input falls
    # back to the configured default (signalled by None).
    max_iterations = None
    raw_limit = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    if raw_limit:
        try:
            max_iterations = int(raw_limit)
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
            max_iterations = None
        else:
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0,将使用配置默认值")
                max_iterations = None

    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    try:
        outcome = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )

        if args.output:
            # Dump a shallow copy of the result dict as JSON.
            payload = dict(outcome)
            with open(args.output, "w", encoding="utf-8") as out_file:
                json.dump(payload, out_file, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Closing is best-effort; a failure here must not mask the outcome.
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()
|
||||
|
||||
|
|
@ -531,11 +531,15 @@ class HeartFChatting:
|
|||
quote_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
# 根据 llm_quote 配置决定是否使用 quote_message 参数
|
||||
if global_config.chat.llm_quote and quote_message is not None:
|
||||
if global_config.chat.llm_quote:
|
||||
# 如果配置为 true,使用 llm_quote 参数决定是否引用回复
|
||||
need_reply = quote_message
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
|
||||
if quote_message is None:
|
||||
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
|
||||
need_reply = False
|
||||
else:
|
||||
need_reply = quote_message
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
|
||||
else:
|
||||
# 如果配置为 false,使用原来的模式
|
||||
new_message_count = message_api.count_new_messages(
|
||||
|
|
@ -663,7 +667,7 @@ class HeartFChatting:
|
|||
unknown_words = cleaned_uw
|
||||
|
||||
# 从 Planner 的 action_data 中提取 quote_message 参数
|
||||
qm = action_planner_info.action_data.get("quote_message")
|
||||
qm = action_planner_info.action_data.get("quote")
|
||||
if qm is not None:
|
||||
# 支持多种格式:true/false, "true"/"false", 1/0
|
||||
if isinstance(qm, bool):
|
||||
|
|
@ -672,6 +676,8 @@ class HeartFChatting:
|
|||
quote_message = qm.lower() in ("true", "1", "yes")
|
||||
elif isinstance(qm, (int, float)):
|
||||
quote_message = bool(qm)
|
||||
|
||||
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
|
|
|
|||
|
|
@ -531,7 +531,7 @@ class ActionPlanner:
|
|||
'"question":"需要查询的问题"'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote_message":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
else:
|
||||
reply_action_example = (
|
||||
|
|
@ -546,7 +546,7 @@ class ActionPlanner:
|
|||
'"question":"需要查询的问题"'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote_message":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
|
|
|
|||
|
|
@ -82,21 +82,51 @@ def _is_chat_id_in_blacklist(chat_id: str) -> bool:
|
|||
return chat_id in blacklist_chat_ids
|
||||
|
||||
|
||||
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
|
||||
async def search_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
participant: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
) -> str:
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配)
|
||||
participant: 参与人昵称(可选)
|
||||
start_time: 开始时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供start_time,查询该时间点之后的记录
|
||||
end_time: 结束时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含记忆id、theme和keywords
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not participant:
|
||||
return "未指定查询参数(需要提供keyword或participant之一)"
|
||||
if not keyword and not participant and not start_time and not end_time:
|
||||
return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)"
|
||||
|
||||
# 解析时间参数
|
||||
start_timestamp = None
|
||||
end_timestamp = None
|
||||
|
||||
if start_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||
except ValueError as e:
|
||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||
except ValueError as e:
|
||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
# 验证时间范围
|
||||
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
|
||||
return "开始时间不能晚于结束时间"
|
||||
|
||||
# 构建查询条件
|
||||
# 检查当前chat_id是否在黑名单中
|
||||
|
|
@ -128,6 +158,40 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
|
|||
f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 添加时间过滤条件
|
||||
if start_timestamp is not None and end_timestamp is not None:
|
||||
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
|
||||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||
query = query.where(
|
||||
(
|
||||
(ChatHistory.start_time >= start_timestamp)
|
||||
& (ChatHistory.start_time <= end_timestamp)
|
||||
) # 记录开始时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.end_time >= start_timestamp)
|
||||
& (ChatHistory.end_time <= end_timestamp)
|
||||
) # 记录结束时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.start_time <= start_timestamp)
|
||||
& (ChatHistory.end_time >= end_timestamp)
|
||||
) # 记录完全包含查询时间段
|
||||
)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加时间范围过滤: {start_timestamp} - {end_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
elif start_timestamp is not None:
|
||||
# 只提供开始时间,查询该时间点之后的记录(记录的开始时间或结束时间在该时间点之后)
|
||||
query = query.where(ChatHistory.end_time >= start_timestamp)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加开始时间过滤: >= {start_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
elif end_timestamp is not None:
|
||||
# 只提供结束时间,查询该时间点之前的记录(记录的开始时间或结束时间在该时间点之前)
|
||||
query = query.where(ChatHistory.start_time <= end_timestamp)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加结束时间过滤: <= {end_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
|
@ -217,21 +281,31 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
|
|||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
if keyword and participant:
|
||||
keywords_str = "、".join(parse_keywords_string(keyword) if keyword else [])
|
||||
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
|
||||
elif keyword:
|
||||
# 构建查询条件描述
|
||||
conditions = []
|
||||
if keyword:
|
||||
keywords_str = "、".join(parse_keywords_string(keyword))
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return (
|
||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
)
|
||||
else:
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
return f"未找到参与人包含'{participant}'的聊天记录"
|
||||
conditions.append(f"关键词'{keywords_str}'")
|
||||
if participant:
|
||||
conditions.append(f"参与人'{participant}'")
|
||||
if start_timestamp or end_timestamp:
|
||||
time_desc = ""
|
||||
if start_timestamp and end_timestamp:
|
||||
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
time_desc = f"时间范围'{start_str}' 至 '{end_str}'"
|
||||
elif start_timestamp:
|
||||
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
time_desc = f"时间>='{start_str}'"
|
||||
elif end_timestamp:
|
||||
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
time_desc = f"时间<='{end_str}'"
|
||||
if time_desc:
|
||||
conditions.append(time_desc)
|
||||
|
||||
if conditions:
|
||||
conditions_str = "且".join(conditions)
|
||||
return f"未找到满足条件({conditions_str})的聊天记录"
|
||||
else:
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
|
|
@ -419,7 +493,7 @@ def register_tool():
|
|||
# 注册工具1:搜索记忆
|
||||
register_memory_retrieval_tool(
|
||||
name="search_chat_history",
|
||||
description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。",
|
||||
description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。支持按时间点或时间段进行查询。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
|
|
@ -433,6 +507,18 @@ def register_tool():
|
|||
"description": "参与人昵称(可选),用于查询包含该参与人的记忆",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "start_time",
|
||||
"type": "string",
|
||||
"description": "开始时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供start_time,查询该时间点之后的记录。如果同时提供start_time和end_time,查询该时间段内的记录",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "end_time",
|
||||
"type": "string",
|
||||
"description": "结束时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=search_chat_history,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue