mirror of https://github.com/Mai-with-u/MaiBot.git
568 lines
21 KiB
Python
568 lines
21 KiB
Python
"""
|
||
模拟 Expression 合并过程
|
||
|
||
用法:
|
||
python scripts/expression_merge_simulation.py
|
||
或指定 chat_id:
|
||
python scripts/expression_merge_simulation.py --chat-id <chat_id>
|
||
或指定相似度阈值:
|
||
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import json
|
||
import argparse
|
||
import asyncio
|
||
import random
|
||
from typing import List, Dict, Tuple, Optional
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
|
||
# Make the repository root (the parent of the directory containing this file)
# importable, so the project-local ``src.*`` imports resolve when the script is
# executed directly rather than as part of a package.
_script_path = os.path.abspath(__file__)
project_root = os.path.split(os.path.split(_script_path)[0])[0]
del _script_path
sys.path.insert(0, project_root)
|
||
|
||
# Import after setting up path (required for project imports)
|
||
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
|
||
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
|
||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||
from src.config.config import model_config # noqa: E402
|
||
|
||
|
||
def get_chat_name(chat_id: str) -> str:
    """Return a human-readable display name for ``chat_id``.

    Group chats are shown by their group name and private chats as
    "<nickname>的私聊". If the stream record is missing or has neither name, a
    placeholder built from the first 8 characters of ``chat_id`` is returned.
    Any exception raised during the lookup is swallowed and a "query failed"
    placeholder is returned instead.
    """
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is not None:
            if stream.group_name:
                return f"{stream.group_name}"
            if stream.user_nickname:
                return f"{stream.user_nickname}的私聊"
        # Either no record at all, or a record without any usable name.
        return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        return f"查询失败 ({chat_id[:8]}...)"
|
||
|
||
|
||
def parse_content_list(stored_list: Optional[str]) -> List[str]:
    """Decode a JSON-encoded ``content_list`` column into a list of strings.

    Returns an empty list when the input is empty/None, is not valid JSON, or
    does not decode to a JSON array. Non-string array elements are dropped.
    """
    if stored_list:
        try:
            decoded = json.loads(stored_list)
        except json.JSONDecodeError:
            decoded = None
        if isinstance(decoded, list):
            return [str(entry) for entry in decoded if isinstance(entry, str)]
    return []
|
||
|
||
|
||
def parse_style_list(stored_list: Optional[str]) -> List[str]:
    """Decode a JSON-encoded ``style_list`` column into a list of strings.

    Empty/None input, malformed JSON, or JSON that is not an array all yield an
    empty list; only string elements of the array are kept.
    """
    if not stored_list:
        return []
    try:
        decoded = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    if not isinstance(decoded, list):
        return []
    return list(map(str, filter(lambda entry: isinstance(entry, str), decoded)))
|
||
|
||
|
||
def find_exact_style_match(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    exclude_ids: set
) -> Optional[Expression]:
    """Return the first expression of ``chat_id`` whose style equals ``target_style``.

    A record matches when its ``style`` field, or any entry of its decoded
    ``style_list``, is exactly equal to ``target_style``. Records from another
    chat, or whose id is in ``exclude_ids``, are ignored. Returns None when no
    record matches.
    """
    candidates = (
        candidate
        for candidate in expressions
        if candidate.chat_id == chat_id and candidate.id not in exclude_ids
    )
    for candidate in candidates:
        # The cheap field comparison runs first; the style_list is only decoded if it fails.
        if candidate.style == target_style or target_style in parse_style_list(candidate.style_list):
            return candidate
    return None
|
||
|
||
|
||
def find_similar_style_expression(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    similarity_threshold: float,
    exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
    """Find the expression of ``chat_id`` whose style is most similar to ``target_style``.

    For every eligible record (same chat, id not in ``exclude_ids``), the
    similarity against its ``style`` field and against each entry of its decoded
    ``style_list`` is computed with ``calculate_style_similarity``. Only scores
    that reach ``similarity_threshold`` are considered, and the earliest record
    holding the strictly highest score wins.

    Returns:
        ``(expression, similarity)`` for the best match, or None if nothing
        reaches the threshold.
    """
    best_match = None
    best_score = 0.0

    for candidate in expressions:
        if candidate.chat_id == chat_id and candidate.id not in exclude_ids:
            scores = [calculate_style_similarity(target_style, candidate.style)]
            scores += [
                calculate_style_similarity(target_style, listed_style)
                for listed_style in parse_style_list(candidate.style_list)
            ]
            for score in scores:
                if score >= similarity_threshold and score > best_score:
                    best_match, best_score = candidate, score

    return (best_match, best_score) if best_match else None
|
||
|
||
|
||
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
    """Combine several situation descriptions into a single text.

    Whitespace-only entries are dropped and the rest are stripped. With no
    entries left the result is ""; with exactly one it is that entry. Otherwise
    the last (up to) 10 entries are sent to ``summary_model`` for a short
    summary. If the call raises or returns only whitespace, the entries are
    joined with "/" instead.
    """
    cleaned = []
    for item in content_list:
        if item.strip():
            cleaned.append(item.strip())

    if len(cleaned) <= 1:
        return cleaned[0] if cleaned else ""

    # Only the most recent 10 descriptions are included to keep the prompt short.
    bullet_lines = "\n".join(f"- {entry}" for entry in cleaned[-10:])
    prompt = (
        "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{bullet_lines}\n只输出概括内容。"
    )

    try:
        response, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        response = response.strip()
        if response:
            return response
    except Exception as e:
        print(f" ⚠️ LLM 总结 situation 失败: {e}")

    return "/".join(cleaned)
|
||
|
||
|
||
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
    """Combine several style / expression descriptions into a single text.

    Whitespace-only entries are dropped and the rest are stripped. With no
    entries left the result is ""; with exactly one it is that entry. Otherwise
    the last (up to) 10 entries are sent to ``summary_model`` for a short
    summary. If the call raises or returns only whitespace, the first entry is
    returned unchanged.
    """
    sanitized = [s.strip() for s in style_list if s.strip()]
    if not sanitized:
        return ""

    if len(sanitized) == 1:
        return sanitized[0]

    # Only the most recent 10 styles are included to keep the prompt short.
    bullet_lines = "\n".join(f"- {s}" for s in sanitized[-10:])
    prompt = (
        "请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{bullet_lines}\n只输出概括内容。"
    )

    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 style 失败: {e}")

    # Summary failed or came back empty: fall back to the first style.
    return sanitized[0]
|
||
|
||
|
||
def _simulate_merged_lists(target, source_style, source_situation) -> Tuple[List[str], List[str]]:
    """Return the (content_list, style_list) ``target`` would hold after absorbing a source record."""
    content_list = parse_content_list(target.content_list)
    content_list.append(source_situation)

    style_list = parse_style_list(target.style_list)
    # The target's own style field is folded into its style_list if missing.
    if target.style and target.style not in style_list:
        style_list.append(target.style)
    if source_style not in style_list:
        style_list.append(source_style)
    return content_list, style_list


def _join_non_empty(items: List[str]) -> str:
    """Join the stripped, non-blank entries of ``items`` with "/", skipping None entries."""
    return "/".join(item.strip() for item in items if item and item.strip())


async def simulate_merge(
    expressions: List[Expression],
    similarity_threshold: float = 0.75,
    use_llm: bool = False,
    max_samples: int = 10,
) -> Dict:
    """
    Simulate the Expression merge process without writing to the database.

    Each expression is treated in turn as if it were newly learned. It is checked
    against the other expressions of the same chat: first for an exact style match,
    then for a style whose similarity reaches ``similarity_threshold``. An
    expression that has been merged into another is considered gone and is no
    longer offered as a merge target, so two matching records count as one merge
    rather than two. Merged lists and texts are derived from each target's stored
    values; earlier simulated merges into the same target are not accumulated.

    Args:
        expressions: Raw Expression records read from the database.
        similarity_threshold: Minimum style similarity for a "similar" merge.
        use_llm: Whether to call the LLM to summarise merged situation/style text.
        max_samples: Maximum number of expressions to sample at random. 0, None or
            a negative value means no limit.

    Returns:
        A dict containing:
        - overall counters: ``total_expressions``, ``total_chats``,
          ``exact_matches``, ``similar_matches``, ``new_records``;
        - the flat ``merge_details`` list;
        - per-chat ``chat_stats`` keyed by chat_id;
        - ``use_llm`` (False when the model failed to initialise).
    """
    # Sample a subset when there are many records to keep the runtime bounded.
    if max_samples and max_samples > 0 and len(expressions) > max_samples:
        expressions = random.sample(expressions, max_samples)

    expressions_by_chat = defaultdict(list)
    for expr in expressions:
        expressions_by_chat[expr.chat_id].append(expr)

    summary_model = None
    if use_llm:
        try:
            summary_model = LLMRequest(
                model_set=model_config.model_task_config.utils_small,
                request_type="expression.summary"
            )
            print("✅ LLM 模型已初始化,将进行实际总结")
        except Exception as e:
            print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
            use_llm = False

    merge_stats = {
        "total_expressions": len(expressions),
        "total_chats": len(expressions_by_chat),
        "exact_matches": 0,
        "similar_matches": 0,
        "new_records": 0,
        "merge_details": [],
        "chat_stats": {},
        "use_llm": use_llm
    }

    for chat_id, chat_expressions in expressions_by_chat.items():
        chat_stat = {
            "chat_id": chat_id,
            "chat_name": get_chat_name(chat_id),
            "total": len(chat_expressions),
            "exact_matches": 0,
            "similar_matches": 0,
            "new_records": 0,
            "merges": []
        }

        processed_ids = set()
        # Ids already merged into another record. Once merged they no longer
        # exist, so they must not be chosen as targets again (this avoids
        # counting A->B and later B->A).
        merged_away_ids = set()

        for expr in chat_expressions:
            if expr.id in processed_ids:
                continue
            processed_ids.add(expr.id)

            source_style = expr.style
            source_situation = expr.situation
            exclude_ids = merged_away_ids | {expr.id}

            # Tier 1: exact style match (no LLM summary; target keeps its text).
            exact_match = find_exact_style_match(chat_expressions, source_style, chat_id, exclude_ids)
            if exact_match:
                content_list, style_list = _simulate_merged_lists(exact_match, source_style, source_situation)
                merge_info = {
                    "type": "exact",
                    "source_id": expr.id,
                    "target_id": exact_match.id,
                    "source_style": source_style,
                    "target_style": exact_match.style,
                    "source_situation": source_situation,
                    "target_situation": exact_match.situation,
                    "similarity": 1.0,
                    "merged_content_list": content_list,
                    "merged_style_list": style_list,
                    "merged_situation": exact_match.situation,
                    "merged_style": exact_match.style
                }
                chat_stat["exact_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["exact_matches"] += 1
                merged_away_ids.add(expr.id)
                continue

            # Tier 2: similar style match (optionally summarised by the LLM).
            similar_match = find_similar_style_expression(
                chat_expressions,
                source_style,
                chat_id,
                similarity_threshold,
                exclude_ids
            )
            if similar_match:
                match_expr, similarity = similar_match
                content_list, style_list = _simulate_merged_lists(match_expr, source_style, source_situation)

                # Plain-concatenation result, used when the LLM is disabled or fails.
                fallback_situation = _join_non_empty(content_list) or match_expr.situation
                fallback_style = style_list[0] if style_list else (match_expr.style or source_style)
                merged_situation, merged_style = fallback_situation, fallback_style

                llm_used = use_llm and summary_model is not None
                if llm_used:
                    try:
                        merged_situation = await compose_situation_text(content_list, summary_model)
                        merged_style = await compose_style_text(style_list, summary_model)
                    except Exception as e:
                        print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
                        merged_situation, merged_style = fallback_situation, fallback_style

                merge_info = {
                    "type": "similar",
                    "source_id": expr.id,
                    "target_id": match_expr.id,
                    "source_style": source_style,
                    "target_style": match_expr.style,
                    "source_situation": source_situation,
                    "target_situation": match_expr.situation,
                    "similarity": similarity,
                    "merged_content_list": content_list,
                    "merged_style_list": style_list,
                    "merged_situation": merged_situation,
                    "merged_style": merged_style,
                    "llm_used": llm_used
                }
                chat_stat["similar_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["similar_matches"] += 1
                merged_away_ids.add(expr.id)
                continue

            # No match: the record would be kept as a new, stand-alone expression.
            chat_stat["new_records"] += 1
            merge_stats["new_records"] += 1

        merge_stats["chat_stats"][chat_id] = chat_stat
        merge_stats["merge_details"].extend(chat_stat["merges"])

    return merge_stats
|
||
|
||
|
||
def _preview(value, limit: Optional[int] = 50) -> str:
    """Return ``value`` as text truncated to ``limit`` characters (no limit when None); None becomes ""."""
    text = "" if value is None else str(value)
    return text if limit is None else text[:limit]


def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
    """Print the report produced by ``simulate_merge`` to stdout.

    Args:
        stats: The dict returned by ``simulate_merge``.
        show_details: Whether to print the per-merge details section.
        max_details: Maximum number of merge details to print. Negative values
            are treated as 0.

    Missing (None) styles, situations or list entries are printed as blank text
    rather than raising ``TypeError``.
    """
    print("\n" + "=" * 80)
    print("Expression 合并模拟结果")
    print("=" * 80)

    print("\n📊 总体统计:")
    print(f" 总 Expression 数: {stats['total_expressions']}")
    print(f" 总聊天数: {stats['total_chats']}")
    print(f" 完全匹配合并: {stats['exact_matches']}")
    print(f" 相似匹配合并: {stats['similar_matches']}")
    print(f" 新记录(无匹配): {stats['new_records']}")
    if stats.get('use_llm'):
        print(" LLM 总结: 已启用")
    else:
        print(" LLM 总结: 未启用(仅模拟)")

    total_merges = stats['exact_matches'] + stats['similar_matches']
    if stats['total_expressions'] > 0:
        merge_ratio = (total_merges / stats['total_expressions']) * 100
        print(f" 合并比例: {merge_ratio:.1f}%")

    print("\n📋 按聊天分组统计:")
    for chat_id, chat_stat in stats['chat_stats'].items():
        print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
        print(f" 总数: {chat_stat['total']}")
        print(f" 完全匹配: {chat_stat['exact_matches']}")
        print(f" 相似匹配: {chat_stat['similar_matches']}")
        print(f" 新记录: {chat_stat['new_records']}")

    details = stats['merge_details']
    if not (show_details and details):
        return

    max_details = max(0, max_details)
    shown = details[:max_details]
    print(f"\n📝 合并详情 (显示前 {len(shown)} 条):")
    print()

    for idx, merge in enumerate(shown, 1):
        if merge['type'] == 'exact':
            merge_type = "完全匹配"
        else:
            merge_type = f"相似匹配 (相似度: {merge['similarity']:.3f})"
        print(f" {idx}. {merge_type}")
        print(f" 源记录 ID: {merge['source_id']}")
        print(f" 目标记录 ID: {merge['target_id']}")
        print(f" 源 Style: {_preview(merge['source_style'])}")
        print(f" 目标 Style: {_preview(merge['target_style'])}")
        print(f" 源 Situation: {_preview(merge['source_situation'])}")
        print(f" 目标 Situation: {_preview(merge['target_situation'])}")

        # Merged result
        if 'merged_situation' in merge:
            print(f" → 合并后 Situation: {_preview(merge['merged_situation'])}")
        if 'merged_style' in merge:
            print(f" → 合并后 Style: {_preview(merge['merged_style'])}")
        if merge.get('llm_used'):
            print(" → LLM 总结: 已使用")
        elif merge['type'] == 'similar':
            print(" → LLM 总结: 未使用(模拟模式)")

        # Merged lists: show the first 3 entries when there is more than one.
        for label, key in (("Content List", 'merged_content_list'), ("Style List", 'merged_style_list')):
            items = merge.get(key)
            if items and len(items) > 1:
                head = ', '.join(_preview(item, None) for item in items[:3])
                print(f" → {label} ({len(items)} 项): {head}")
                if len(items) > 3:
                    print(f" ... 还有 {len(items) - 3} 项")
        print()

    if len(details) > max_details:
        print(f" ... 还有 {len(details) - max_details} 条合并记录未显示")
|
||
|
||
|
||
def main():
    """Command-line entry point.

    Parses the arguments and loads Expression records (optionally for a single
    chat_id). It then runs ``simulate_merge`` and prints the report both to the
    console and to a text file. A failure to write the file is reported but does
    not suppress the console output.
    """
    from contextlib import redirect_stdout

    parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
    parser.add_argument(
        "--chat-id",
        type=str,
        default=None,
        help="指定要分析的 chat_id(不指定则分析所有)"
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=0.75,
        help="相似度阈值 (0-1, 默认: 0.75)"
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计"
    )
    parser.add_argument(
        "--max-details",
        type=int,
        default=50,
        help="最多显示的合并详情数 (默认: 50)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)"
    )
    parser.add_argument(
        "--use-llm",
        action="store_true",
        help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM)"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=10,
        help="最多随机抽取的 Expression 数量 (默认: 10,设置为 0 表示不限制)"
    )

    args = parser.parse_args()

    # Validate arguments
    if not 0 <= args.similarity_threshold <= 1:
        print("错误: similarity-threshold 必须在 0-1 之间")
        return
    if args.max_samples < 0:
        print("错误: max-samples 不能为负数")
        return
    if args.max_details < 0:
        print("错误: max-details 不能为负数")
        return

    # Resolve the output file path; its directory is created when the report is written.
    if args.output:
        output_file = args.output
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")

    print("正在从数据库加载Expression数据...")
    try:
        if args.chat_id:
            expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
        else:
            expressions = list(Expression.select())
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return

    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return

    print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples})...")
    if args.use_llm:
        print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
    else:
        print("ℹ️ 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")

    stats = asyncio.run(
        simulate_merge(
            expressions,
            similarity_threshold=args.similarity_threshold,
            use_llm=args.use_llm,
            max_samples=args.max_samples,
        )
    )

    report_kwargs = {"show_details": not args.no_details, "max_details": args.max_details}

    # Write the report to the file. redirect_stdout restores sys.stdout even on error.
    write_error = None
    try:
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f, redirect_stdout(f):
            print_merge_results(stats, **report_kwargs)
    except OSError as e:
        write_error = e

    # Always show the report on the console, even if writing the file failed.
    print_merge_results(stats, **report_kwargs)

    if write_error is not None:
        print(f"❌ 写入文件失败: {write_error}")
        return

    print(f"\n✅ 模拟结果已保存到: {output_file}")


if __name__ == "__main__":
    main()
|
||
|