MaiBot/scripts/compare_finish_search_token.py

508 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from typing import Dict, Any
from datetime import datetime
# 强制使用 utf-8避免控制台编码报错
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
logger = get_logger("compare_finish_search_token")
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
.where(
(LLMUsage.timestamp >= start_datetime)
& (
(LLMUsage.request_type.like("%memory%"))
| (LLMUsage.request_type == "memory.question")
| (LLMUsage.request_type == "memory.react")
| (LLMUsage.request_type == "memory.react.final")
)
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
model_usage = {} # 按模型统计
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
# 按模型统计
model_name = record.model_name or "unknown"
if model_name not in model_usage:
model_usage[model_name] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost": 0.0,
"request_count": 0,
}
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"total_cost": total_cost,
"request_count": request_count,
"model_usage": model_usage,
}
except Exception as e:
logger.error(f"获取token使用情况失败: {e}")
return {
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
"total_cost": 0.0,
"request_count": 0,
"model_usage": {},
}
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try:
# 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
)
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system':
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try:
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
try:
import src.chat.replyer.group_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
try:
import src.chat.replyer.private_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
)
except (ImportError, AttributeError) as e:
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
raise
def _init_tools_without_finish_search():
"""初始化工具但不注册 finish_search"""
from src.memory_system.retrieval_tools import (
register_query_chat_history,
register_query_person_info,
register_query_words,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.config.config import global_config
# 清空工具注册器
tool_registry = get_tool_registry()
tool_registry.tools.clear()
# 注册除 finish_search 外的所有工具
register_query_chat_history()
register_query_person_info()
register_query_words()
# 如果启用 LPMM agent 模式,也注册 LPMM 工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":
from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge
register_lpmm_knowledge()
logger.info("已初始化工具(不包含 finish_search")
def _init_tools_with_finish_search():
"""初始化工具并注册 finish_search"""
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.memory_system.retrieval_tools import init_all_tools
# 清空工具注册器
tool_registry = get_tool_registry()
tool_registry.tools.clear()
# 初始化所有工具(包括 finish_search
init_all_tools()
logger.info("已初始化工具(包含 finish_search")
async def get_prompt_tokens_for_tools(
question: str,
chat_id: str,
use_finish_search: bool,
) -> Dict[str, Any]:
"""获取使用不同工具配置时的prompt token消耗
Args:
question: 要查询的问题
chat_id: 聊天ID
use_finish_search: 是否使用 finish_search 工具
Returns:
包含prompt token信息的字典
"""
# 先初始化 prompt如果还未初始化
# 注意init_memory_retrieval_prompt 会调用 init_all_tools所以我们需要在它之后重新设置工具
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt, _ = _import_memory_retrieval()
init_memory_retrieval_prompt()
# 初始化工具(根据参数决定是否包含 finish_search
# 必须在 init_memory_retrieval_prompt 之后调用,因为它会调用 init_all_tools
if use_finish_search:
_init_tools_with_finish_search()
else:
_init_tools_without_finish_search()
# 获取工具注册器
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
tool_registry = get_tool_registry()
tool_definitions = tool_registry.get_tool_definitions()
# 验证工具列表(调试用)
tool_names = [tool["name"] for tool in tool_definitions]
if use_finish_search:
if "finish_search" not in tool_names:
logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
else:
if "finish_search" in tool_names:
logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
# 移除 finish_search 工具
tool_registry.tools.pop("finish_search", None)
tool_definitions = tool_registry.get_tool_definitions()
tool_names = [tool["name"] for tool in tool_definitions]
# 构建第一次调用的prompt模拟_react_agent_solve_question的第一次调用
from src.config.config import global_config
bot_name = global_config.bot.nickname
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 构建head_prompt
head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info="",
current_iteration=1,
remaining_iterations=global_config.memory.max_agent_iterations - 1,
max_iterations=global_config.memory.max_agent_iterations,
)
# 构建消息列表只包含system message模拟第一次调用
from src.llm_models.payload_content.message import MessageBuilder, RoleType
messages = []
system_builder = MessageBuilder()
system_builder.set_role(RoleType.System)
system_builder.add_text_content(head_prompt)
messages.append(system_builder.build())
# 调用LLM API来计算token只调用一次不实际执行
from src.llm_models.utils_model import LLMRequest, RequestType
from src.config.config import model_config
# 创建LLM请求对象
llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")
# 构建工具选项
tool_built = llm_request._build_tool_options(tool_definitions)
# 直接调用 _execute_request 以获取完整的响应对象(包含 usage
response, model_info = await llm_request._execute_request(
request_type=RequestType.RESPONSE,
message_factory=lambda _client, *, _messages=messages: _messages,
temperature=None,
max_tokens=None,
tool_options=tool_built,
)
# 从响应中获取token使用情况
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
if response and hasattr(response, 'usage') and response.usage:
prompt_tokens = response.usage.prompt_tokens or 0
completion_tokens = response.usage.completion_tokens or 0
total_tokens = response.usage.total_tokens or 0
return {
"use_finish_search": use_finish_search,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"tool_count": len(tool_definitions),
"tool_names": [tool["name"] for tool in tool_definitions],
}
async def compare_prompt_tokens(
question: str,
chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
"""对比使用 finish_search 工具与否的输入 token 差异
只运行一次,只计算输入 token 的差异,确保除了工具定义外其他内容一致
Args:
question: 要查询的问题
chat_id: 聊天ID
Returns:
包含对比结果的字典
"""
print("\n" + "=" * 80)
print("finish_search 工具 输入 Token 消耗对比测试")
print("=" * 80)
print(f"\n[测试问题] {question}")
print(f"[聊天ID] {chat_id}")
print("\n注意: 只对比第一次LLM调用的输入token差异不运行完整迭代流程")
# 第一次测试:不使用 finish_search
print("\n" + "-" * 80)
print("[测试 1/2] 不使用 finish_search 工具")
print("-" * 80)
result_without = await get_prompt_tokens_for_tools(
question=question,
chat_id=f"{chat_id}_without",
use_finish_search=False,
)
print(f"\n[结果]")
print(f" 工具数量: {result_without['tool_count']}")
print(f" 工具列表: {', '.join(result_without['tool_names'])}")
print(f" 输入 Prompt Tokens: {result_without['prompt_tokens']:,}")
# 等待一下,确保数据库记录已写入
await asyncio.sleep(1)
# 第二次测试:使用 finish_search
print("\n" + "-" * 80)
print("[测试 2/2] 使用 finish_search 工具")
print("-" * 80)
result_with = await get_prompt_tokens_for_tools(
question=question,
chat_id=f"{chat_id}_with",
use_finish_search=True,
)
print(f"\n[结果]")
print(f" 工具数量: {result_with['tool_count']}")
print(f" 工具列表: {', '.join(result_with['tool_names'])}")
print(f" 输入 Prompt Tokens: {result_with['prompt_tokens']:,}")
# 对比结果
print("\n" + "=" * 80)
print("[对比结果]")
print("=" * 80)
prompt_token_diff = result_with['prompt_tokens'] - result_without['prompt_tokens']
prompt_token_diff_percent = (prompt_token_diff / result_without['prompt_tokens'] * 100) if result_without['prompt_tokens'] > 0 else 0
tool_count_diff = result_with['tool_count'] - result_without['tool_count']
print(f"\n[输入 Prompt Token 对比]")
print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")
print(f"\n[工具数量对比]")
print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
print(f" 差异: {tool_count_diff:+d} 个工具")
print(f"\n[工具列表对比]")
without_tools = set(result_without['tool_names'])
with_tools = set(result_with['tool_names'])
only_with = with_tools - without_tools
only_without = without_tools - with_tools
if only_with:
print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
if only_without:
print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
if not only_with and not only_without:
print(f" 工具列表相同(除了 finish_search")
# 显示其他token信息
print(f"\n[其他 Token 信息]")
print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")
print("\n" + "=" * 80)
return {
"question": question,
"without_finish_search": result_without,
"with_finish_search": result_with,
"comparison": {
"prompt_token_diff": prompt_token_diff,
"prompt_token_diff_percent": prompt_token_diff_percent,
"tool_count_diff": tool_count_diff,
},
}
def main() -> None:
parser = argparse.ArgumentParser(
description="对比使用 finish_search 工具与否的 token 消耗差异"
)
parser.add_argument(
"--chat-id",
default="compare_finish_search",
help="测试用的聊天ID默认: compare_finish_search",
)
parser.add_argument(
"--output",
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False)
# 交互式输入问题
print("\n" + "=" * 80)
print("finish_search 工具 Token 消耗对比测试工具")
print("=" * 80)
question = input("\n请输入要查询的问题: ").strip()
if not question:
print("错误: 问题不能为空")
return
# 连接数据库
try:
db.connect(reuse_if_open=True)
except Exception as e:
logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}")
return
# 运行对比测试
try:
result = asyncio.run(
compare_prompt_tokens(
question=question,
chat_id=args.chat_id,
)
)
# 如果指定了输出文件,保存结果
if args.output:
# 将thinking_steps转换为可序列化的格式
output_result = result.copy()
with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试")
except Exception as e:
logger.error(f"测试失败: {e}", exc_info=True)
print(f"\n[错误] 测试失败: {e}")
finally:
try:
db.close()
except Exception:
pass
if __name__ == "__main__":
main()