""" 分析expression库中situation和style的相似度 用法: python scripts/expression_similarity_analysis.py 或指定阈值: python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7 """ import sys import os import argparse from typing import List, Tuple from collections import defaultdict from difflib import SequenceMatcher from datetime import datetime # Add project root to Python path project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) # Import after setting up path (required for project imports) from src.common.database.database_model import Expression, ChatStreams # noqa: E402 from src.config.config import global_config # noqa: E402 from src.chat.message_receive.chat_stream import get_chat_manager # noqa: E402 class TeeOutput: """同时输出到控制台和文件的类""" def __init__(self, file_path: str): self.file = open(file_path, "w", encoding="utf-8") self.console = sys.stdout def write(self, text: str): """写入文本到控制台和文件""" self.console.write(text) self.file.write(text) self.file.flush() # 立即刷新到文件 def flush(self): """刷新输出""" self.console.flush() self.file.flush() def close(self): """关闭文件""" if self.file: self.file.close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() return False def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: """ 解析'platform:id:type'为chat_id,直接复用 ChatManager 的逻辑 """ try: parts = stream_config_str.split(":") if len(parts) != 3: return None platform = parts[0] id_str = parts[1] stream_type = parts[2] is_group = stream_type == "group" return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) except Exception: return None def build_chat_id_groups() -> dict[str, set[str]]: """ 根据expression_groups配置,构建chat_id到相关chat_id集合的映射 Returns: dict: {chat_id: set of related chat_ids (including itself)} """ groups = global_config.expression.expression_groups chat_id_groups: dict[str, set[str]] = {} # 检查是否存在全局共享组(包含"*"的组) global_group_exists = any("*" in group for group in groups) if global_group_exists: # 如果存在全局共享组,收集所有配置中的chat_id all_chat_ids = set() for group in groups: for stream_config_str in group: if stream_config_str == "*": continue if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str): all_chat_ids.add(chat_id_candidate) # 所有chat_id都互相相关 for chat_id in all_chat_ids: chat_id_groups[chat_id] = all_chat_ids.copy() else: # 处理普通组 for group in groups: group_chat_ids = set() for stream_config_str in group: if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str): group_chat_ids.add(chat_id_candidate) # 组内的所有chat_id都互相相关 for chat_id in group_chat_ids: if chat_id not in chat_id_groups: chat_id_groups[chat_id] = set() chat_id_groups[chat_id].update(group_chat_ids) # 确保每个chat_id至少包含自身 for chat_id in chat_id_groups: chat_id_groups[chat_id].add(chat_id) return chat_id_groups def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool: """ 判断两个chat_id是否相关(相同或同组) Args: chat_id1: 第一个chat_id chat_id2: 第二个chat_id chat_id_groups: chat_id到相关chat_id集合的映射 Returns: bool: 如果两个chat_id相同或同组,返回True """ if chat_id1 == chat_id2: return True # 如果chat_id1在映射中,检查chat_id2是否在其相关集合中 if chat_id1 in chat_id_groups: return chat_id2 in chat_id_groups[chat_id1] # 如果chat_id1不在映射中,说明它不在任何组中,只与自己相关 return False def get_chat_name(chat_id: str) -> str: """根据 chat_id 获取聊天名称""" try: chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) if chat_stream is None: return f"未知聊天 ({chat_id[:8]}...)" if chat_stream.group_name: return f"{chat_stream.group_name}" elif chat_stream.user_nickname: return f"{chat_stream.user_nickname}的私聊" else: return f"未知聊天 ({chat_id[:8]}...)" except Exception: return f"查询失败 ({chat_id[:8]}...)" def text_similarity(text1: str, text2: str) -> float: """ 计算两个文本的相似度 使用SequenceMatcher计算相似度,返回0-1之间的值 在计算前会移除"使用"和"句式"这两个词 """ if not text1 or not text2: return 0.0 # 移除"使用"和"句式"这两个词 def remove_ignored_words(text: str) -> str: """移除需要忽略的词""" text = text.replace("使用", "") text = text.replace("句式", "") return text.strip() cleaned_text1 = remove_ignored_words(text1) cleaned_text2 = remove_ignored_words(text2) # 如果清理后文本为空,返回0 if not cleaned_text1 or not cleaned_text2: return 0.0 return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio() def find_similar_pairs( expressions: List[Expression], field_name: str, threshold: float, max_pairs: int = None ) -> List[Tuple[int, int, float, str, str]]: """ 找出相似的expression对 Args: expressions: Expression对象列表 field_name: 要比较的字段名 ('situation' 或 'style') threshold: 相似度阈值 (0-1) max_pairs: 最多返回的对数,None表示返回所有 Returns: List of (index1, index2, similarity, text1, text2) tuples """ similar_pairs = [] n = len(expressions) print(f"正在分析 {field_name} 字段的相似度...") print(f"总共需要比较 {n * (n - 1) // 2} 对...") for i in range(n): if (i + 1) % 100 == 0: print(f" 已处理 {i + 1}/{n} 个项目...") expr1 = expressions[i] text1 = getattr(expr1, field_name, "") for j in range(i + 1, n): expr2 = expressions[j] text2 = getattr(expr2, field_name, "") similarity = text_similarity(text1, text2) if similarity >= threshold: similar_pairs.append((i, j, similarity, text1, text2)) # 按相似度降序排序 similar_pairs.sort(key=lambda x: x[2], reverse=True) if max_pairs: similar_pairs = similar_pairs[:max_pairs] return similar_pairs def group_similar_items( expressions: List[Expression], field_name: str, threshold: float, chat_id_groups: dict[str, set[str]] ) -> List[List[int]]: """ 将相似的expression分组(仅比较相同chat_id或同组的项目) Args: expressions: Expression对象列表 field_name: 要比较的字段名 ('situation' 或 'style') threshold: 相似度阈值 (0-1) chat_id_groups: chat_id到相关chat_id集合的映射 Returns: List of groups, each group is a list of indices """ n = len(expressions) # 使用并查集的思想来分组 parent = list(range(n)) def find(x): if parent[x] != x: parent[x] = find(parent[x]) return parent[x] def union(x, y): px, py = find(x), find(y) if px != py: parent[px] = py print(f"正在对 {field_name} 字段进行分组(仅比较相同chat_id或同组的项目)...") # 统计需要比较的对数 total_pairs = 0 for i in range(n): for j in range(i + 1, n): if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups): total_pairs += 1 print(f"总共需要比较 {total_pairs} 对(已过滤不同chat_id且不同组的项目)...") compared_pairs = 0 for i in range(n): if (i + 1) % 100 == 0: print(f" 已处理 {i + 1}/{n} 个项目...") expr1 = expressions[i] text1 = getattr(expr1, field_name, "") for j in range(i + 1, n): expr2 = expressions[j] # 只比较相同chat_id或同组的项目 if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups): continue compared_pairs += 1 text2 = getattr(expr2, field_name, "") similarity = text_similarity(text1, text2) if similarity >= threshold: union(i, j) # 收集分组 groups = defaultdict(list) for i in range(n): root = find(i) groups[root].append(i) # 只返回包含多个项目的组 result = [group for group in groups.values() if len(group) > 1] result.sort(key=len, reverse=True) return result def print_similarity_analysis( expressions: List[Expression], field_name: str, threshold: float, chat_id_groups: dict[str, set[str]], show_details: bool = True, max_groups: int = 20 ): """打印相似度分析结果""" print("\n" + "=" * 80) print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})") print("=" * 80) # 分组分析 groups = group_similar_items(expressions, field_name, threshold, chat_id_groups) total_items = len(expressions) similar_items_count = sum(len(group) for group in groups) unique_groups = len(groups) print("\n📊 统计信息:") print(f" 总项目数: {total_items}") print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)") print(f" 相似组数: {unique_groups}") print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0") if not groups: print(f"\n未找到相似度 >= {threshold} 的项目组") return print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):") print() for group_idx, group in enumerate(groups[:max_groups], 1): print(f"组 {group_idx} (共 {len(group)} 个项目):") if show_details: # 显示组内所有项目的详细信息 for idx in group: expr = expressions[idx] text = getattr(expr, field_name, "") chat_name = get_chat_name(expr.chat_id) # 截断过长的文本 display_text = text[:60] + "..." if len(text) > 60 else text print(f" [{expr.id}] {display_text}") print(f" 聊天: {chat_name}, Count: {expr.count}") # 计算组内平均相似度 if len(group) > 1: similarities = [] above_threshold_pairs = [] # 存储满足阈值的相似对 above_threshold_count = 0 for i in range(len(group)): for j in range(i + 1, len(group)): text1 = getattr(expressions[group[i]], field_name, "") text2 = getattr(expressions[group[j]], field_name, "") sim = text_similarity(text1, text2) similarities.append(sim) if sim >= threshold: above_threshold_count += 1 # 存储满足阈值的对的信息 expr1 = expressions[group[i]] expr2 = expressions[group[j]] display_text1 = text1[:40] + "..." if len(text1) > 40 else text1 display_text2 = text2[:40] + "..." if len(text2) > 40 else text2 above_threshold_pairs.append(( expr1.id, display_text1, expr2.id, display_text2, sim )) if similarities: avg_sim = sum(similarities) / len(similarities) min_sim = min(similarities) max_sim = max(similarities) above_threshold_ratio = above_threshold_count / len(similarities) * 100 print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})") print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})") # 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组) if above_threshold_pairs: print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):") # 按相似度降序排序 above_threshold_pairs.sort(key=lambda x: x[4], reverse=True) for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]: # 最多显示10对 print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}") print(f" \"{text1}\" ↔ \"{text2}\"") if len(above_threshold_pairs) > 10: print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值") else: print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接") else: # 只显示组内第一个项目作为示例 expr = expressions[group[0]] text = getattr(expr, field_name, "") display_text = text[:60] + "..." if len(text) > 60 else text print(f" 示例: {display_text}") print(f" ... 还有 {len(group) - 1} 个相似项目") print() if len(groups) > max_groups: print(f"... 还有 {len(groups) - max_groups} 组未显示") def main(): """主函数""" parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度") parser.add_argument( "--situation-threshold", type=float, default=0.7, help="situation相似度阈值 (0-1, 默认: 0.7)" ) parser.add_argument( "--style-threshold", type=float, default=0.7, help="style相似度阈值 (0-1, 默认: 0.7)" ) parser.add_argument( "--no-details", action="store_true", help="不显示详细信息,只显示统计" ) parser.add_argument( "--max-groups", type=int, default=20, help="最多显示的组数 (默认: 20)" ) parser.add_argument( "--output", type=str, default=None, help="输出文件路径 (默认: 自动生成带时间戳的文件)" ) args = parser.parse_args() # 验证阈值 if not 0 <= args.situation_threshold <= 1: print("错误: situation-threshold 必须在 0-1 之间") return if not 0 <= args.style_threshold <= 1: print("错误: style-threshold 必须在 0-1 之间") return # 确定输出文件路径 if args.output: output_file = args.output else: # 自动生成带时间戳的输出文件 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_dir = os.path.join(project_root, "data", "temp") os.makedirs(output_dir, exist_ok=True) output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt") # 使用TeeOutput同时输出到控制台和文件 with TeeOutput(output_file) as tee: # 临时替换sys.stdout original_stdout = sys.stdout sys.stdout = tee try: print("=" * 80) print("Expression 相似度分析工具") print("=" * 80) print(f"输出文件: {output_file}") print() _run_analysis(args) finally: # 恢复原始stdout sys.stdout = original_stdout print(f"\n✅ 分析结果已保存到: {output_file}") def _run_analysis(args): """执行分析的主逻辑""" # 查询所有Expression记录 print("正在从数据库加载Expression数据...") try: expressions = list(Expression.select()) except Exception as e: print(f"❌ 加载数据失败: {e}") return if not expressions: print("❌ 数据库中没有找到Expression记录") return print(f"✅ 成功加载 {len(expressions)} 条Expression记录") print() # 构建chat_id分组映射 print("正在构建chat_id分组映射(根据expression_groups配置)...") try: chat_id_groups = build_chat_id_groups() print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射") if chat_id_groups: # 统计分组信息 total_related = sum(len(related) for related in chat_id_groups.values()) avg_related = total_related / len(chat_id_groups) print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关(包括自身)") print() except Exception as e: print(f"⚠️ 构建chat_id分组映射失败: {e}") print(" 将使用默认行为:只比较相同chat_id的项目") chat_id_groups = {} # 分析situation相似度 print_similarity_analysis( expressions, "situation", args.situation_threshold, chat_id_groups, show_details=not args.no_details, max_groups=args.max_groups ) # 分析style相似度 print_similarity_analysis( expressions, "style", args.style_threshold, chat_id_groups, show_details=not args.no_details, max_groups=args.max_groups ) print("\n" + "=" * 80) print("分析完成!") print("=" * 80) if __name__ == "__main__": main()