# MaiBot/scripts/expression_merge_simulation.py

"""
Simulate the Expression merge process.

Usage:
    python scripts/expression_merge_simulation.py
Or specify a chat_id:
    python scripts/expression_merge_simulation.py --chat-id <chat_id>
Or specify a similarity threshold:
    python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""
import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
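# Note: the imports above are MaiBot-internal. Expression and ChatStreams appear to be
# Peewee-style models (select() / get_or_none() are used below), and LLMRequest wraps the
# configured summary model. Running this script therefore assumes a working MaiBot
# environment and database; none of that is verified here.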
def get_chat_name(chat_id: str) -> str:
    """Return a human-readable chat name for the given chat_id."""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def parse_content_list(stored_list: Optional[str]) -> List[str]:
    """Parse a content_list JSON string into a list of strings."""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def parse_style_list(stored_list: Optional[str]) -> List[str]:
    """Parse a style_list JSON string into a list of strings."""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def find_exact_style_match(
expressions: List[Expression],
target_style: str,
chat_id: str,
exclude_ids: set
) -> Optional[Expression]:
"""
    Find an Expression record whose style exactly matches target_style.
    Both the style field and every entry in style_list are checked.
"""
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
        # Check the style field
if expr.style == target_style:
return expr
        # Check each entry in style_list
style_list = parse_style_list(expr.style_list)
if target_style in style_list:
return expr
return None
def find_similar_style_expression(
expressions: List[Expression],
target_style: str,
chat_id: str,
similarity_threshold: float,
exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
"""
    Find an Expression record whose style is similar to target_style.
    Both the style field and every entry in style_list are checked.
    Returns:
        (Expression, similarity) or None
"""
best_match = None
best_similarity = 0.0
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
        # Check the style field
similarity = calculate_style_similarity(target_style, expr.style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
        # Check each entry in style_list
style_list = parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(target_style, existing_style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
return (best_match, best_similarity)
return None
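# Assumption: calculate_style_similarity returns a float in [0, 1], with 1.0 meaning
# identical styles, so the --similarity-threshold value is compared directly against it.
# When several candidates clear the threshold, the highest-scoring one wins; on equal
# scores the earlier record is kept (strict ">" comparison above).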
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
    """Combine situation texts, summarizing with the LLM where possible."""
sanitized = [c.strip() for c in content_list if c.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
    # Try to summarize with the LLM
    prompt = (
        "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 situation 失败: {e}")
    # If summarization fails, return the entries joined with "/"
return "/".join(sanitized)
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
    """Combine style texts, summarizing with the LLM where possible."""
sanitized = [s.strip() for s in style_list if s.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
    # Try to summarize with the LLM
    prompt = (
        "请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
print(f"Prompt:{prompt} Summary:{summary}")
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 style 失败: {e}")
    # If summarization fails, return the first entry
return sanitized[0]
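# Fallback behaviour when the LLM is disabled or fails: situations are joined with "/"
# so that no context is dropped, while styles fall back to the first entry only. The
# same fallback is duplicated in the non-LLM branch of simulate_merge below.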
async def simulate_merge(
expressions: List[Expression],
similarity_threshold: float = 0.75,
use_llm: bool = False,
max_samples: int = 10,
) -> Dict:
"""
    Simulate the merge process.
    Args:
        expressions: Expression records as read from the database
        similarity_threshold: style similarity threshold
        use_llm: whether to actually call the LLM for summarization
        max_samples: maximum number of Expressions to sample at random (0 or None means no limit)
    Returns:
        a dict with merge statistics
"""
    # If there are too many records, randomly sample a subset to keep the run time reasonable
if max_samples and len(expressions) > max_samples:
expressions = random.sample(expressions, max_samples)
    # Group by chat_id
expressions_by_chat = defaultdict(list)
for expr in expressions:
expressions_by_chat[expr.chat_id].append(expr)
    # Initialize the LLM model if needed
summary_model = None
if use_llm:
try:
summary_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.summary"
)
print("✅ LLM 模型已初始化,将进行实际总结")
except Exception as e:
print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
use_llm = False
merge_stats = {
"total_expressions": len(expressions),
"total_chats": len(expressions_by_chat),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merge_details": [],
"chat_stats": {},
"use_llm": use_llm
}
    # Simulate merging for each chat_id
for chat_id, chat_expressions in expressions_by_chat.items():
chat_name = get_chat_name(chat_id)
chat_stat = {
"chat_id": chat_id,
"chat_name": chat_name,
"total": len(chat_expressions),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merges": []
}
processed_ids = set()
for expr in chat_expressions:
if expr.id in processed_ids:
continue
target_style = expr.style
target_situation = expr.situation
            # Layer 1: check for an exact match
exact_match = find_exact_style_match(
chat_expressions,
target_style,
chat_id,
{expr.id}
)
if exact_match:
                # Exact match (no LLM summarization)
                # Simulate the merged content_list and style_list
target_content_list = parse_content_list(exact_match.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(exact_match.style_list)
if exact_match.style and exact_match.style not in target_style_list:
target_style_list.append(exact_match.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
                merge_info = {
                    "type": "exact",
                    "source_id": expr.id,
                    "target_id": exact_match.id,
                    "source_style": target_style,
                    "target_style": exact_match.style,
                    "source_situation": target_situation,
                    "target_situation": exact_match.situation,
                    "similarity": 1.0,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": exact_match.situation,  # exact match keeps the original situation
                    "merged_style": exact_match.style  # exact match keeps the original style
}
chat_stat["exact_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["exact_matches"] += 1
processed_ids.add(expr.id)
continue
            # Layer 2: check for a similar match
similar_match = find_similar_style_expression(
chat_expressions,
target_style,
chat_id,
similarity_threshold,
{expr.id}
)
if similar_match:
match_expr, similarity = similar_match
                # Similar match (LLM summarization applies)
                # Simulate the merged content_list and style_list
target_content_list = parse_content_list(match_expr.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(match_expr.style_list)
if match_expr.style and match_expr.style not in target_style_list:
target_style_list.append(match_expr.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
                # Summarize with the LLM (if enabled)
merged_situation = match_expr.situation
merged_style = match_expr.style or target_style
if use_llm and summary_model:
try:
merged_situation = await compose_situation_text(target_content_list, summary_model)
merged_style = await compose_style_text(target_style_list, summary_model)
except Exception as e:
print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
                        # If summarization fails, use the fallback
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
else:
                    # Without the LLM, use simple concatenation
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
merge_info = {
"type": "similar",
"source_id": expr.id,
"target_id": match_expr.id,
"source_style": target_style,
"target_style": match_expr.style,
"source_situation": target_situation,
"target_situation": match_expr.situation,
"similarity": similarity,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": merged_situation,
"merged_style": merged_style,
"llm_used": use_llm and summary_model is not None
}
chat_stat["similar_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["similar_matches"] += 1
processed_ids.add(expr.id)
continue
            # No match: count as a new record
chat_stat["new_records"] += 1
merge_stats["new_records"] += 1
processed_ids.add(expr.id)
merge_stats["chat_stats"][chat_id] = chat_stat
merge_stats["merge_details"].extend(chat_stat["merges"])
return merge_stats
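# The returned dict is consumed by print_merge_results below: top-level counters
# (exact_matches / similar_matches / new_records), per-chat breakdowns under
# "chat_stats", and the flat list of individual merges under "merge_details".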
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
    """Print the merge results."""
print("\n" + "=" * 80)
print("Expression 合并模拟结果")
print("=" * 80)
print("\n📊 总体统计:")
print(f" 总 Expression 数: {stats['total_expressions']}")
print(f" 总聊天数: {stats['total_chats']}")
print(f" 完全匹配合并: {stats['exact_matches']}")
print(f" 相似匹配合并: {stats['similar_matches']}")
print(f" 新记录(无匹配): {stats['new_records']}")
if stats.get('use_llm'):
print(" LLM 总结: 已启用")
else:
print(" LLM 总结: 未启用(仅模拟)")
total_merges = stats['exact_matches'] + stats['similar_matches']
if stats['total_expressions'] > 0:
merge_ratio = (total_merges / stats['total_expressions']) * 100
print(f" 合并比例: {merge_ratio:.1f}%")
    # Per-chat breakdown
print("\n📋 按聊天分组统计:")
for chat_id, chat_stat in stats['chat_stats'].items():
print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
print(f" 总数: {chat_stat['total']}")
print(f" 完全匹配: {chat_stat['exact_matches']}")
print(f" 相似匹配: {chat_stat['similar_matches']}")
print(f" 新记录: {chat_stat['new_records']}")
    # Show merge details
if show_details and stats['merge_details']:
print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
print()
for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
print(f" {idx}. {merge_type}")
print(f" 源记录 ID: {merge['source_id']}")
print(f" 目标记录 ID: {merge['target_id']}")
print(f" 源 Style: {merge['source_style'][:50]}")
print(f" 目标 Style: {merge['target_style'][:50]}")
print(f" 源 Situation: {merge['source_situation'][:50]}")
print(f" 目标 Situation: {merge['target_situation'][:50]}")
            # Show the merged result
if 'merged_situation' in merge:
print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
if 'merged_style' in merge:
print(f" → 合并后 Style: {merge['merged_style'][:50]}")
if merge.get('llm_used'):
print(" → LLM 总结: 已使用")
elif merge['type'] == 'similar':
print(" → LLM 总结: 未使用(模拟模式)")
            # Show the merged lists
if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
if len(merge['merged_content_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_content_list']) - 3} 项")
if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
if len(merge['merged_style_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_style_list']) - 3} 项")
print()
if len(stats['merge_details']) > max_details:
print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
def main():
    """Main entry point."""
parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
parser.add_argument(
"--chat-id",
type=str,
default=None,
        help="指定要分析的 chat_id,不指定则分析所有"
)
parser.add_argument(
"--similarity-threshold",
type=float,
default=0.75,
help="相似度阈值 (0-1, 默认: 0.75)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-details",
type=int,
default=50,
help="最多显示的合并详情数 (默认: 50)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
parser.add_argument(
"--use-llm",
action="store_true",
        help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM)"
)
parser.add_argument(
"--max-samples",
type=int,
default=10,
        help="最多随机抽取的 Expression 数量 (默认: 10,设置为 0 表示不限制)"
)
args = parser.parse_args()
    # Validate the threshold
if not 0 <= args.similarity_threshold <= 1:
print("错误: similarity-threshold 必须在 0-1 之间")
return
    # Determine the output file path
if args.output:
output_file = args.output
else:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
    # Query Expression records
print("正在从数据库加载Expression数据...")
try:
if args.chat_id:
expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
else:
expressions = list(Expression.select())
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
    # Run the merge simulation
    print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples})...")
if args.use_llm:
print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
else:
print(" 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
stats = asyncio.run(
simulate_merge(
expressions,
similarity_threshold=args.similarity_threshold,
use_llm=args.use_llm,
max_samples=args.max_samples,
)
)
    # Output the results
original_stdout = sys.stdout
try:
with open(output_file, "w", encoding="utf-8") as f:
sys.stdout = f
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
sys.stdout = original_stdout
        # Also print to the console
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
except Exception as e:
sys.stdout = original_stdout
print(f"❌ 写入文件失败: {e}")
return
print(f"\n✅ 模拟结果已保存到: {output_file}")
if __name__ == "__main__":
main()
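# Example invocation from the repository root (flag values are illustrative):
#   python scripts/expression_merge_simulation.py --similarity-threshold 0.8 \
#       --max-samples 50 --use-llm --output data/temp/merge_report.txt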