mirror of https://github.com/Mai-with-u/MaiBot.git
fix: messages with an overly long interval are now handled specially in the replyer
parent 6acb1ff0ed
commit ba9b9d26a2
@@ -0,0 +1,322 @@
"""
Evaluation result statistics script

Features:
1. Scan every JSON file under the temp directory
2. Analyze the statistics of each file
3. Print a detailed statistics report
"""

import json
import os
import sys
import glob
from collections import Counter
from datetime import datetime
from typing import Dict, List, Set, Tuple

# Add the project root to sys.path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.common.logger import get_logger

logger = get_logger("evaluation_stats_analyzer")

# Path to the evaluation result files
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")


def parse_datetime(dt_str: str) -> datetime | None:
    """Parse an ISO-format datetime string."""
    try:
        return datetime.fromisoformat(dt_str)
    except Exception:
        return None


def analyze_single_file(file_path: str) -> Dict:
    """
    Analyze the statistics of a single JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        A dictionary of statistics
    """
    file_name = os.path.basename(file_path)
    stats = {
        "file_name": file_name,
        "file_path": file_path,
        "file_size": os.path.getsize(file_path),
        "error": None,
        "last_updated": None,
        "total_count": 0,
        "actual_count": 0,
        "suitable_count": 0,
        "unsuitable_count": 0,
        "suitable_rate": 0.0,
        "unique_pairs": 0,
        "evaluators": Counter(),
        "evaluation_dates": [],
        "date_range": None,
        "has_expression_id": False,
        "has_reason": False,
        "reason_count": 0,
    }

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Basic information
        stats["last_updated"] = data.get("last_updated")
        stats["total_count"] = data.get("total_count", 0)

        results = data.get("manual_results", [])
        stats["actual_count"] = len(results)

        if not results:
            return stats

        # Count pass/fail
        suitable_count = sum(1 for r in results if r.get("suitable") is True)
        unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
        stats["suitable_count"] = suitable_count
        stats["unsuitable_count"] = unsuitable_count
        stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0

        # Count unique (situation, style) pairs
        pairs: Set[Tuple[str, str]] = set()
        for r in results:
            if "situation" in r and "style" in r:
                pairs.add((r["situation"], r["style"]))
        stats["unique_pairs"] = len(pairs)

        # Count evaluators
        for r in results:
            evaluator = r.get("evaluator", "unknown")
            stats["evaluators"][evaluator] += 1

        # Collect evaluation timestamps
        evaluation_dates = []
        for r in results:
            evaluated_at = r.get("evaluated_at")
            if evaluated_at:
                dt = parse_datetime(evaluated_at)
                if dt:
                    evaluation_dates.append(dt)

        stats["evaluation_dates"] = evaluation_dates
        if evaluation_dates:
            min_date = min(evaluation_dates)
            max_date = max(evaluation_dates)
            stats["date_range"] = {
                "start": min_date.isoformat(),
                "end": max_date.isoformat(),
                "duration_days": (max_date - min_date).days + 1,
            }

        # Check field presence
        stats["has_expression_id"] = any("expression_id" in r for r in results)
        stats["has_reason"] = any(r.get("reason") for r in results)
        stats["reason_count"] = sum(1 for r in results if r.get("reason"))

    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"分析文件 {file_name} 时出错: {e}")

    return stats


def print_file_stats(stats: Dict, index: int | None = None):
    """Print the statistics of a single file."""
    prefix = f"[{index}] " if index is not None else ""
    print(f"\n{'=' * 80}")
    print(f"{prefix}文件: {stats['file_name']}")
    print(f"{'=' * 80}")

    if stats["error"]:
        print(f"✗ 错误: {stats['error']}")
        return

    print(f"文件路径: {stats['file_path']}")
    print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")

    if stats["last_updated"]:
        print(f"最后更新: {stats['last_updated']}")

    print("\n【记录统计】")
    print(f" 文件中的 total_count: {stats['total_count']}")
    print(f" 实际记录数: {stats['actual_count']}")

    if stats["total_count"] != stats["actual_count"]:
        diff = stats["total_count"] - stats["actual_count"]
        print(f" ⚠️ 数量不一致,差值: {diff:+d}")

    print("\n【评估结果统计】")
    print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
    print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")

    print("\n【唯一性统计】")
    print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
    if stats["actual_count"] > 0:
        duplicate_count = stats["actual_count"] - stats["unique_pairs"]
        duplicate_rate = duplicate_count / stats["actual_count"] * 100
        print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")

    print("\n【评估者统计】")
    if stats["evaluators"]:
        for evaluator, count in stats["evaluators"].most_common():
            rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
            print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
    else:
        print(" 无评估者信息")

    print("\n【时间统计】")
    if stats["date_range"]:
        print(f" 最早评估时间: {stats['date_range']['start']}")
        print(f" 最晚评估时间: {stats['date_range']['end']}")
        print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
    else:
        print(" 无时间信息")

    print("\n【字段统计】")
    print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
    print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
    if stats["has_reason"]:
        rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
        print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")


def print_summary(all_stats: List[Dict]):
    """Print the aggregate statistics."""
    print(f"\n{'=' * 80}")
    print("汇总统计")
    print(f"{'=' * 80}")

    total_files = len(all_stats)
    valid_files = [s for s in all_stats if not s.get("error")]
    error_files = [s for s in all_stats if s.get("error")]

    print("\n【文件统计】")
    print(f" 总文件数: {total_files}")
    print(f" 成功解析: {len(valid_files)}")
    print(f" 解析失败: {len(error_files)}")

    if error_files:
        print("\n 失败文件列表:")
        for stats in error_files:
            print(f" - {stats['file_name']}: {stats['error']}")

    if not valid_files:
        print("\n没有成功解析的文件")
        return

    # Aggregate record counts
    total_records = sum(s["actual_count"] for s in valid_files)
    total_suitable = sum(s["suitable_count"] for s in valid_files)
    total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
    total_unique_pairs = set()

    # Collect every unique (situation, style) pair across files
    for stats in valid_files:
        try:
            with open(stats["file_path"], "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])
            for r in results:
                if "situation" in r and "style" in r:
                    total_unique_pairs.add((r["situation"], r["style"]))
        except Exception:
            pass

    print("\n【记录汇总】")
    print(f" 总记录数: {total_records:,} 条")
    print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
    print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
    print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")

    if total_records > 0:
        duplicate_count = total_records - len(total_unique_pairs)
        duplicate_rate = duplicate_count / total_records * 100
        print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")

    # Aggregate evaluator counts
    all_evaluators = Counter()
    for stats in valid_files:
        all_evaluators.update(stats["evaluators"])

    print("\n【评估者汇总】")
    if all_evaluators:
        for evaluator, count in all_evaluators.most_common():
            rate = (count / total_records * 100) if total_records > 0 else 0
            print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
    else:
        print(" 无评估者信息")

    # Aggregate the time range
    all_dates = []
    for stats in valid_files:
        all_dates.extend(stats["evaluation_dates"])

    if all_dates:
        min_date = min(all_dates)
        max_date = max(all_dates)
        print("\n【时间汇总】")
        print(f" 最早评估时间: {min_date.isoformat()}")
        print(f" 最晚评估时间: {max_date.isoformat()}")
        print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")

    # Aggregate file sizes
    total_size = sum(s["file_size"] for s in valid_files)
    avg_size = total_size / len(valid_files) if valid_files else 0
    print("\n【文件大小汇总】")
    print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
    print(f" 平均大小: {avg_size:,.0f} 字节 ({avg_size / 1024:.2f} KB)")


def main():
    """Entry point."""
    logger.info("=" * 80)
    logger.info("开始分析评估结果统计信息")
    logger.info("=" * 80)

    if not os.path.exists(TEMP_DIR):
        print(f"\n✗ 错误:未找到temp目录: {TEMP_DIR}")
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        return

    # Find all JSON files
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))

    if not json_files:
        print(f"\n✗ 错误:temp目录下未找到JSON文件: {TEMP_DIR}")
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        return

    json_files.sort()  # sort by file name

    print(f"\n找到 {len(json_files)} 个JSON文件")
    print("=" * 80)

    # Analyze each file
    all_stats = []
    for i, json_file in enumerate(json_files, 1):
        stats = analyze_single_file(json_file)
        all_stats.append(stats)
        print_file_stats(stats, index=i)

    # Print the aggregate statistics
    print_summary(all_stats)

    print(f"\n{'=' * 80}")
    print("分析完成")
    print(f"{'=' * 80}")


if __name__ == "__main__":
    main()
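For orientation, here is a minimal sketch of the input shape the script above expects, inferred from the fields it reads (last_updated, total_count, manual_results, suitable, situation, style, evaluator, evaluated_at, expression_id, reason); the concrete values below are invented for illustration:

# Illustrative only: the structure analyze_single_file() reads.
example_file = {
    "last_updated": "2024-05-01T12:00:00",  # ISO timestamp (hypothetical value)
    "total_count": 2,                        # count recorded by the writer
    "manual_results": [
        {
            "situation": "朋友开玩笑时",       # (situation, style) is the uniqueness key
            "style": "用夸张的语气回应",
            "suitable": True,                 # True/False drives the pass/fail stats
            "evaluator": "alice",             # hypothetical evaluator name
            "evaluated_at": "2024-05-01T11:58:03",
            "expression_id": 42,              # optional field
            "reason": "符合语境",              # optional field
        },
    ],
}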
@@ -13,7 +13,8 @@ import json
 import random
 import sys
 import os
-from typing import List, Dict
+import glob
+from typing import List, Dict, Set, Tuple
 
 # Add the project root to sys.path
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -27,32 +28,66 @@ logger = get_logger("expression_evaluator_llm")
 
 # Path to the evaluation result files
 TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
-MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
 
 
 def load_manual_results() -> List[Dict]:
     """
-    Load the manual evaluation results
+    Load the manual evaluation results (reads and merges every JSON file under the temp directory)
 
     Returns:
-        List of manual evaluation results
+        List of manual evaluation results (deduplicated)
     """
-    if not os.path.exists(MANUAL_EVAL_FILE):
-        logger.error(f"未找到人工评估结果文件: {MANUAL_EVAL_FILE}")
-        print("\n✗ 错误:未找到人工评估结果文件")
+    if not os.path.exists(TEMP_DIR):
+        logger.error(f"未找到temp目录: {TEMP_DIR}")
+        print("\n✗ 错误:未找到temp目录")
         print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
         return []
 
-    try:
-        with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        results = data.get("manual_results", [])
-        logger.info(f"成功加载 {len(results)} 条人工评估结果")
-        return results
-    except Exception as e:
-        logger.error(f"加载人工评估结果失败: {e}")
-        print(f"\n✗ 加载人工评估结果失败: {e}")
+    # Find all JSON files
+    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
+
+    if not json_files:
+        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
+        print("\n✗ 错误:temp目录下未找到JSON文件")
+        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
+        return []
+
+    logger.info(f"找到 {len(json_files)} 个JSON文件")
+    print(f"\n找到 {len(json_files)} 个JSON文件:")
+    for json_file in json_files:
+        print(f" - {os.path.basename(json_file)}")
+
+    # Read and merge every JSON file
+    all_results = []
+    seen_pairs: Set[Tuple[str, str]] = set()  # used for deduplication
+
+    for json_file in json_files:
+        try:
+            with open(json_file, "r", encoding="utf-8") as f:
+                data = json.load(f)
+            results = data.get("manual_results", [])
+
+            # Deduplicate, keyed on (situation, style)
+            for result in results:
+                if "situation" not in result or "style" not in result:
+                    logger.warning(f"跳过无效数据(缺少必要字段): {result}")
+                    continue
+
+                pair = (result["situation"], result["style"])
+                if pair not in seen_pairs:
+                    seen_pairs.add(pair)
+                    all_results.append(result)
+
+            logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
+        except Exception as e:
+            logger.error(f"加载文件 {json_file} 失败: {e}")
+            print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
+            continue
+
+    logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
+    print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
+
+    return all_results
 
 
 def create_evaluation_prompt(situation: str, style: str) -> str:
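One property of the merged loader worth noting: deduplication keeps the first (situation, style) pair encountered, and glob.glob() returns files in an OS-dependent order, so which duplicate survives depends on file ordering. A tiny self-contained sketch of the first-wins behavior:

# Illustrative only: dedup keeps the first (situation, style) pair seen.
seen: set = set()
merged: list = []
for result in [
    {"situation": "s1", "style": "a", "suitable": True},   # kept
    {"situation": "s1", "style": "a", "suitable": False},  # dropped: same pair
    {"situation": "s1", "style": "b", "suitable": False},  # kept: different style
]:
    pair = (result["situation"], result["style"])
    if pair not in seen:
        seen.add(pair)
        merged.append(result)
assert [r["suitable"] for r in merged] == [True, False]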
@@ -1,567 +0,0 @@
"""
Simulate the Expression merge process.

Usage:
    python scripts/expression_merge_simulation.py
or with a specific chat_id:
    python scripts/expression_merge_simulation.py --chat-id <chat_id>
or with a similarity threshold:
    python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""

import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity  # noqa: E402
from src.llm_models.utils_model import LLMRequest  # noqa: E402
from src.config.config import model_config  # noqa: E402


def get_chat_name(chat_id: str) -> str:
    """Look up a chat's display name by chat_id."""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"

        if chat_stream.group_name:
            return f"{chat_stream.group_name}"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊"
        else:
            return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        return f"查询失败 ({chat_id[:8]}...)"


def parse_content_list(stored_list: Optional[str]) -> List[str]:
    """Parse a content_list JSON string into a list."""
    if not stored_list:
        return []
    try:
        data = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []


def parse_style_list(stored_list: Optional[str]) -> List[str]:
    """Parse a style_list JSON string into a list."""
    if not stored_list:
        return []
    try:
        data = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []


def find_exact_style_match(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    exclude_ids: set,
) -> Optional[Expression]:
    """
    Find an Expression record whose style matches exactly.
    Checks both the style field and every entry of style_list.
    """
    for expr in expressions:
        if expr.chat_id != chat_id or expr.id in exclude_ids:
            continue

        # Check the style field
        if expr.style == target_style:
            return expr

        # Check every entry in style_list
        style_list = parse_style_list(expr.style_list)
        if target_style in style_list:
            return expr

    return None


def find_similar_style_expression(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    similarity_threshold: float,
    exclude_ids: set,
) -> Optional[Tuple[Expression, float]]:
    """
    Find an Expression record with a similar style.
    Checks both the style field and every entry of style_list.

    Returns:
        (Expression, similarity) or None
    """
    best_match = None
    best_similarity = 0.0

    for expr in expressions:
        if expr.chat_id != chat_id or expr.id in exclude_ids:
            continue

        # Check the style field
        similarity = calculate_style_similarity(target_style, expr.style)
        if similarity >= similarity_threshold and similarity > best_similarity:
            best_similarity = similarity
            best_match = expr

        # Check every entry in style_list
        style_list = parse_style_list(expr.style_list)
        for existing_style in style_list:
            similarity = calculate_style_similarity(target_style, existing_style)
            if similarity >= similarity_threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = expr

    if best_match:
        return (best_match, best_similarity)
    return None
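The matching above is two-tiered: an exact style match wins outright, otherwise the best fuzzy match at or above the threshold is used. A minimal self-contained sketch of that decision flow, using plain dicts instead of Expression rows and difflib.SequenceMatcher as a stand-in for calculate_style_similarity (whose real implementation lives in src.bw_learner.learner_utils and may score differently):

# Sketch only: demonstrates the exact-then-similar lookup order.
from difflib import SequenceMatcher

def pick_merge_target(rows, target_style, threshold=0.75):
    # Tier 1: an exact style match wins outright (similarity 1.0).
    for row in rows:
        if row["style"] == target_style:
            return row, 1.0
    # Tier 2: best fuzzy match at or above the threshold.
    best, best_sim = None, 0.0
    for row in rows:
        sim = SequenceMatcher(None, target_style, row["style"]).ratio()
        if sim >= threshold and sim > best_sim:
            best, best_sim = row, sim
    return (best, best_sim) if best else (None, 0.0)

rows = [{"style": "用反问句回应"}, {"style": "用反问句回复"}]
print(pick_merge_target(rows, "用反问句回应"))  # exact hit in tier 1, similarity 1.0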
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
    """Compose the situation text, trying an LLM summary first."""
    sanitized = [c.strip() for c in content_list if c.strip()]
    if not sanitized:
        return ""

    if len(sanitized) == 1:
        return sanitized[0]

    # Try an LLM summary
    prompt = (
        "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
    )

    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 situation 失败: {e}")

    # If the summary fails, fall back to joining with "/"
    return "/".join(sanitized)


async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
    """Compose the style text, trying an LLM summary first."""
    sanitized = [s.strip() for s in style_list if s.strip()]
    if not sanitized:
        return ""

    if len(sanitized) == 1:
        return sanitized[0]

    # Try an LLM summary
    prompt = (
        "请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
    )

    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)

        print(f"Prompt:{prompt} Summary:{summary}")

        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 style 失败: {e}")

    # If the summary fails, fall back to the first entry
    return sanitized[0]


async def simulate_merge(
    expressions: List[Expression],
    similarity_threshold: float = 0.75,
    use_llm: bool = False,
    max_samples: int = 10,
) -> Dict:
    """
    Simulate the merge process.

    Args:
        expressions: list of Expression rows as read from the database
        similarity_threshold: style similarity threshold
        use_llm: whether to run real LLM summarization
        max_samples: maximum number of Expressions to sample randomly (0 or None means no limit)

    Returns:
        A dictionary of merge statistics
    """
    # If there are too many rows, sample a subset to keep the runtime short
    if max_samples and len(expressions) > max_samples:
        expressions = random.sample(expressions, max_samples)

    # Group by chat_id
    expressions_by_chat = defaultdict(list)
    for expr in expressions:
        expressions_by_chat[expr.chat_id].append(expr)

    # Initialize the LLM model if requested
    summary_model = None
    if use_llm:
        try:
            summary_model = LLMRequest(
                model_set=model_config.model_task_config.tool_use,
                request_type="expression.summary",
            )
            print("✅ LLM 模型已初始化,将进行实际总结")
        except Exception as e:
            print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
            use_llm = False

    merge_stats = {
        "total_expressions": len(expressions),
        "total_chats": len(expressions_by_chat),
        "exact_matches": 0,
        "similar_matches": 0,
        "new_records": 0,
        "merge_details": [],
        "chat_stats": {},
        "use_llm": use_llm,
    }

    # Simulate the merge per chat_id
    for chat_id, chat_expressions in expressions_by_chat.items():
        chat_name = get_chat_name(chat_id)
        chat_stat = {
            "chat_id": chat_id,
            "chat_name": chat_name,
            "total": len(chat_expressions),
            "exact_matches": 0,
            "similar_matches": 0,
            "new_records": 0,
            "merges": [],
        }

        processed_ids = set()

        for expr in chat_expressions:
            if expr.id in processed_ids:
                continue

            target_style = expr.style
            target_situation = expr.situation

            # Tier 1: check for an exact match
            exact_match = find_exact_style_match(
                chat_expressions,
                target_style,
                chat_id,
                {expr.id},
            )

            if exact_match:
                # Exact match (no LLM summary needed)
                # Simulate the merged content_list and style_list
                target_content_list = parse_content_list(exact_match.content_list)
                target_content_list.append(target_situation)

                target_style_list = parse_style_list(exact_match.style_list)
                if exact_match.style and exact_match.style not in target_style_list:
                    target_style_list.append(exact_match.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)

                merge_info = {
                    "type": "exact",
                    "source_id": expr.id,
                    "target_id": exact_match.id,
                    "source_style": target_style,
                    "target_style": exact_match.style,
                    "source_situation": target_situation,
                    "target_situation": exact_match.situation,
                    "similarity": 1.0,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": exact_match.situation,  # exact matches keep the original situation
                    "merged_style": exact_match.style,  # exact matches keep the original style
                }
                chat_stat["exact_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["exact_matches"] += 1
                processed_ids.add(expr.id)
                continue

            # Tier 2: check for a similar match
            similar_match = find_similar_style_expression(
                chat_expressions,
                target_style,
                chat_id,
                similarity_threshold,
                {expr.id},
            )

            if similar_match:
                match_expr, similarity = similar_match
                # Similar match (uses the LLM summary)
                # Simulate the merged content_list and style_list
                target_content_list = parse_content_list(match_expr.content_list)
                target_content_list.append(target_situation)

                target_style_list = parse_style_list(match_expr.style_list)
                if match_expr.style and match_expr.style not in target_style_list:
                    target_style_list.append(match_expr.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)

                # Summarize with the LLM if enabled
                merged_situation = match_expr.situation
                merged_style = match_expr.style or target_style

                if use_llm and summary_model:
                    try:
                        merged_situation = await compose_situation_text(target_content_list, summary_model)
                        merged_style = await compose_style_text(target_style_list, summary_model)
                    except Exception as e:
                        print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
                        # Fall back when the summary fails
                        merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                        merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
                else:
                    # Without the LLM, use simple concatenation
                    merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                    merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)

                merge_info = {
                    "type": "similar",
                    "source_id": expr.id,
                    "target_id": match_expr.id,
                    "source_style": target_style,
                    "target_style": match_expr.style,
                    "source_situation": target_situation,
                    "target_situation": match_expr.situation,
                    "similarity": similarity,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": merged_situation,
                    "merged_style": merged_style,
                    "llm_used": use_llm and summary_model is not None,
                }
                chat_stat["similar_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["similar_matches"] += 1
                processed_ids.add(expr.id)
                continue

            # No match: treat as a new record
            chat_stat["new_records"] += 1
            merge_stats["new_records"] += 1
            processed_ids.add(expr.id)

        merge_stats["chat_stats"][chat_id] = chat_stat
        merge_stats["merge_details"].extend(chat_stat["merges"])

    return merge_stats


def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
    """Print the merge results."""
    print("\n" + "=" * 80)
    print("Expression 合并模拟结果")
    print("=" * 80)

    print("\n📊 总体统计:")
    print(f" 总 Expression 数: {stats['total_expressions']}")
    print(f" 总聊天数: {stats['total_chats']}")
    print(f" 完全匹配合并: {stats['exact_matches']}")
    print(f" 相似匹配合并: {stats['similar_matches']}")
    print(f" 新记录(无匹配): {stats['new_records']}")
    if stats.get("use_llm"):
        print(" LLM 总结: 已启用")
    else:
        print(" LLM 总结: 未启用(仅模拟)")

    total_merges = stats["exact_matches"] + stats["similar_matches"]
    if stats["total_expressions"] > 0:
        merge_ratio = (total_merges / stats["total_expressions"]) * 100
        print(f" 合并比例: {merge_ratio:.1f}%")

    # Per-chat breakdown
    print("\n📋 按聊天分组统计:")
    for chat_id, chat_stat in stats["chat_stats"].items():
        print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
        print(f" 总数: {chat_stat['total']}")
        print(f" 完全匹配: {chat_stat['exact_matches']}")
        print(f" 相似匹配: {chat_stat['similar_matches']}")
        print(f" 新记录: {chat_stat['new_records']}")

    # Merge details
    if show_details and stats["merge_details"]:
        print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
        print()

        for idx, merge in enumerate(stats["merge_details"][:max_details], 1):
            merge_type = "完全匹配" if merge["type"] == "exact" else f"相似匹配 (相似度: {merge['similarity']:.3f})"
            print(f" {idx}. {merge_type}")
            print(f" 源记录 ID: {merge['source_id']}")
            print(f" 目标记录 ID: {merge['target_id']}")
            print(f" 源 Style: {merge['source_style'][:50]}")
            print(f" 目标 Style: {merge['target_style'][:50]}")
            print(f" 源 Situation: {merge['source_situation'][:50]}")
            print(f" 目标 Situation: {merge['target_situation'][:50]}")

            # Show the merged result
            if "merged_situation" in merge:
                print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
            if "merged_style" in merge:
                print(f" → 合并后 Style: {merge['merged_style'][:50]}")
            if merge.get("llm_used"):
                print(" → LLM 总结: 已使用")
            elif merge["type"] == "similar":
                print(" → LLM 总结: 未使用(模拟模式)")

            # Show the merged lists
            if "merged_content_list" in merge and len(merge["merged_content_list"]) > 1:
                print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
                if len(merge["merged_content_list"]) > 3:
                    print(f" ... 还有 {len(merge['merged_content_list']) - 3} 项")
            if "merged_style_list" in merge and len(merge["merged_style_list"]) > 1:
                print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
                if len(merge["merged_style_list"]) > 3:
                    print(f" ... 还有 {len(merge['merged_style_list']) - 3} 项")
            print()

        if len(stats["merge_details"]) > max_details:
            print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
    parser.add_argument(
        "--chat-id",
        type=str,
        default=None,
        help="指定要分析的 chat_id(不指定则分析所有)",
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=0.75,
        help="相似度阈值 (0-1, 默认: 0.75)",
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计",
    )
    parser.add_argument(
        "--max-details",
        type=int,
        default=50,
        help="最多显示的合并详情数 (默认: 50)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)",
    )
    parser.add_argument(
        "--use-llm",
        action="store_true",
        help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=10,
        help="最多随机抽取的 Expression 数量 (默认: 10,设置为 0 表示不限制)",
    )

    args = parser.parse_args()

    # Validate the threshold
    if not 0 <= args.similarity_threshold <= 1:
        print("错误: similarity-threshold 必须在 0-1 之间")
        return

    # Determine the output file path
    if args.output:
        output_file = args.output
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")

    # Query the Expression records
    print("正在从数据库加载Expression数据...")
    try:
        if args.chat_id:
            expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
        else:
            expressions = list(Expression.select())
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return

    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return

    # Run the merge simulation
    print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples})...")
    if args.use_llm:
        print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
    else:
        print("ℹ️ 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")

    stats = asyncio.run(
        simulate_merge(
            expressions,
            similarity_threshold=args.similarity_threshold,
            use_llm=args.use_llm,
            max_samples=args.max_samples,
        )
    )

    # Write the results
    original_stdout = sys.stdout
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            sys.stdout = f
            print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
        sys.stdout = original_stdout

        # Also print to the console
        print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)

    except Exception as e:
        sys.stdout = original_stdout
        print(f"❌ 写入文件失败: {e}")
        return

    print(f"\n✅ 模拟结果已保存到: {output_file}")


if __name__ == "__main__":
    main()
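For reference, combining the flags defined above, the script was invoked along these lines before its removal (the paths follow the module docstring; the flag values are examples only):

python scripts/expression_merge_simulation.py --similarity-threshold 0.8 --max-samples 0
python scripts/expression_merge_simulation.py --chat-id <chat_id> --use-llm --no-details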
@@ -1,342 +0,0 @@
import sys
import os
from datetime import datetime
from typing import List, Tuple

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np

# Add project root to Python path (this must run before importing project modules)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression, ChatStreams  # noqa: E402

# Configure fonts with CJK support
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying ChatStreams table directly"""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id})"

        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
        else:
            return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"


def get_expression_data() -> List[Tuple[float, float, str, str]]:
    """Fetch the Expression table as a list of (create_date, count, chat_id, expression_type) tuples."""
    expressions = Expression.select()
    data = []

    for expr in expressions:
        # Skip records whose create_date is empty
        if expr.create_date is None:
            continue

        data.append((expr.create_date, expr.count, expr.chat_id, expr.type))

    return data


def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str | None = None):
    """Create the basic scatter plot."""
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Split the columns
    create_dates = [item[0] for item in data]
    counts = [item[1] for item in data]
    _chat_ids = [item[2] for item in data]
    _expression_types = [item[3] for item in data]

    # Convert timestamps to datetime objects
    dates = [datetime.fromtimestamp(ts) for ts in create_dates]

    # Pick the date display format automatically from the time span
    time_span = max(dates) - min(dates)
    if time_span.days > 30:  # over 30 days: monthly ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # over 7 days: daily ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    # Create the figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Draw the scatter plot
    scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")

    # Labels and title
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")

    # X-axis date formatting, adjusted to the time span
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    # Grid
    ax.grid(True, alpha=0.3)

    # Colorbar
    cbar = plt.colorbar(scatter)
    cbar.set_label("数据点顺序", fontsize=10)

    # Layout
    plt.tight_layout()

    # Print summary statistics
    print("\n=== 数据统计 ===")
    print(f"总数据点数量: {len(data)}")
    print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
    print(f"平均使用次数: {np.mean(counts):.2f}")
    print(f"中位数使用次数: {np.median(counts):.2f}")

    # Save the image
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n散点图已保存到: {save_path}")

    # Show the image
    plt.show()


def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str | None = None):
    """Create the scatter plot grouped by chat."""
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Group by chat_id
    chat_groups = {}
    for item in data:
        chat_id = item[2]
        if chat_id not in chat_groups:
            chat_groups[chat_id] = []
        chat_groups[chat_id].append(item)

    # Pick the date display format automatically from the time span
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # over 30 days: monthly ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # over 7 days: daily ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    # Create the figure
    fig, ax = plt.subplots(figsize=(14, 10))

    # Assign a distinct color per chat
    colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))

    for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
        create_dates = [item[0] for item in chat_data]
        counts = [item[1] for item in chat_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]

        chat_name = get_chat_name(chat_id)
        # Truncate over-long chat names
        display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name

        ax.scatter(
            dates,
            counts,
            alpha=0.7,
            s=40,
            c=[colors[i]],
            label=f"{display_name} ({len(chat_data)}个)",
            edgecolors="black",
            linewidth=0.5,
        )

    # Labels and title
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")

    # X-axis date formatting, adjusted to the time span
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    # Legend
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)

    # Grid
    ax.grid(True, alpha=0.3)

    # Layout
    plt.tight_layout()

    # Print summary statistics
    print("\n=== 分组统计 ===")
    print(f"总聊天数量: {len(chat_groups)}")
    for chat_id, chat_data in chat_groups.items():
        chat_name = get_chat_name(chat_id)
        counts = [item[1] for item in chat_data]
        print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")

    # Save the image
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n分组散点图已保存到: {save_path}")

    # Show the image
    plt.show()


def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str | None = None):
    """Create the scatter plot grouped by expression type."""
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Group by type
    type_groups = {}
    for item in data:
        expr_type = item[3]
        if expr_type not in type_groups:
            type_groups[expr_type] = []
        type_groups[expr_type].append(item)

    # Pick the date display format automatically from the time span
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # over 30 days: monthly ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # over 7 days: daily ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    # Create the figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Assign a distinct color per type
    colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))

    for i, (expr_type, type_data) in enumerate(type_groups.items()):
        create_dates = [item[0] for item in type_data]
        counts = [item[1] for item in type_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]

        ax.scatter(
            dates,
            counts,
            alpha=0.7,
            s=40,
            c=[colors[i]],
            label=f"{expr_type} ({len(type_data)}个)",
            edgecolors="black",
            linewidth=0.5,
        )

    # Labels and title
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")

    # X-axis date formatting, adjusted to the time span
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    # Legend
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # Grid
    ax.grid(True, alpha=0.3)

    # Layout
    plt.tight_layout()

    # Print summary statistics
    print("\n=== 类型统计 ===")
    for expr_type, type_data in type_groups.items():
        counts = [item[1] for item in type_data]
        print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")

    # Save the image
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n类型散点图已保存到: {save_path}")

    # Show the image
    plt.show()


def main():
    """Entry point."""
    print("开始分析表达式数据...")

    # Fetch the data
    data = get_expression_data()

    if not data:
        print("没有找到有效的表达式数据(create_date不为空的数据)")
        return

    print(f"找到 {len(data)} 条有效数据")

    # Create the output directory
    output_dir = os.path.join(project_root, "data", "temp")
    os.makedirs(output_dir, exist_ok=True)

    # Timestamp for the file names
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # 1. Basic scatter plot
    print("\n1. 创建基础散点图...")
    create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))

    # 2. Scatter plot grouped by chat
    print("\n2. 创建按聊天分组的散点图...")
    create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))

    # 3. Scatter plot grouped by type
    print("\n3. 创建按类型分组的散点图...")
    create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))

    print("\n分析完成!")


if __name__ == "__main__":
    main()
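The three plotting functions above each repeat the same time-span-to-tick-format selection. Were the script revived, that block could be factored into one helper; a sketch (not part of the original file):

# Sketch: factor the repeated axis-format selection into one helper.
import matplotlib.dates as mdates
from datetime import timedelta

def pick_date_axis(time_span: timedelta):
    """Return (date_format, major_locator, minor_locator) for a given span."""
    if time_span.days > 30:  # beyond 30 days: monthly major ticks
        return "%Y-%m-%d", mdates.MonthLocator(), mdates.DayLocator(interval=7)
    if time_span.days > 7:  # beyond 7 days: daily major ticks
        return "%Y-%m-%d", mdates.DayLocator(interval=1), mdates.HourLocator(interval=12)
    # within 7 days: hourly ticks
    return "%Y-%m-%d %H:%M", mdates.HourLocator(interval=6), mdates.HourLocator(interval=1)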
@@ -1,559 +0,0 @@
"""
Analyze the similarity of situations and styles in the expression library.

Usage:
    python scripts/expression_similarity_analysis.py
or with explicit thresholds:
    python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
"""

import sys
import os
import argparse
from typing import List, Tuple
from collections import defaultdict
from difflib import SequenceMatcher
from datetime import datetime

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
from src.config.config import global_config  # noqa: E402
from src.chat.message_receive.chat_stream import get_chat_manager  # noqa: E402


class TeeOutput:
    """Writes output to the console and a file at the same time."""

    def __init__(self, file_path: str):
        self.file = open(file_path, "w", encoding="utf-8")
        self.console = sys.stdout

    def write(self, text: str):
        """Write text to both the console and the file."""
        self.console.write(text)
        self.file.write(text)
        self.file.flush()  # flush to the file immediately

    def flush(self):
        """Flush both outputs."""
        self.console.flush()
        self.file.flush()

    def close(self):
        """Close the file."""
        if self.file:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False


def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
    """
    Parse 'platform:id:type' into a chat_id, reusing the ChatManager logic directly.
    """
    try:
        parts = stream_config_str.split(":")
        if len(parts) != 3:
            return None
        platform = parts[0]
        id_str = parts[1]
        stream_type = parts[2]
        is_group = stream_type == "group"
        return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
    except Exception:
        return None


def build_chat_id_groups() -> dict[str, set[str]]:
    """
    Build a chat_id-to-related-chat_ids mapping from the expression_groups config.

    Returns:
        dict: {chat_id: set of related chat_ids (including itself)}
    """
    groups = global_config.expression.expression_groups
    chat_id_groups: dict[str, set[str]] = {}

    # Check whether a globally shared group (one containing "*") exists
    global_group_exists = any("*" in group for group in groups)

    if global_group_exists:
        # With a global group, collect every chat_id in the config
        all_chat_ids = set()
        for group in groups:
            for stream_config_str in group:
                if stream_config_str == "*":
                    continue
                if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
                    all_chat_ids.add(chat_id_candidate)

        # Every chat_id is related to every other
        for chat_id in all_chat_ids:
            chat_id_groups[chat_id] = all_chat_ids.copy()
    else:
        # Handle ordinary groups
        for group in groups:
            group_chat_ids = set()
            for stream_config_str in group:
                if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.add(chat_id_candidate)

            # All chat_ids within a group are related to each other
            for chat_id in group_chat_ids:
                if chat_id not in chat_id_groups:
                    chat_id_groups[chat_id] = set()
                chat_id_groups[chat_id].update(group_chat_ids)

    # Make sure every chat_id at least contains itself
    for chat_id in chat_id_groups:
        chat_id_groups[chat_id].add(chat_id)

    return chat_id_groups


def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
    """
    Decide whether two chat_ids are related (identical or in the same group).

    Args:
        chat_id1: the first chat_id
        chat_id2: the second chat_id
        chat_id_groups: mapping from chat_id to its set of related chat_ids

    Returns:
        bool: True if the two chat_ids are identical or share a group
    """
    if chat_id1 == chat_id2:
        return True

    # If chat_id1 is in the mapping, check whether chat_id2 is in its related set
    if chat_id1 in chat_id_groups:
        return chat_id2 in chat_id_groups[chat_id1]

    # If chat_id1 is not in the mapping, it belongs to no group and is only related to itself
    return False


def get_chat_name(chat_id: str) -> str:
    """Look up a chat's display name by chat_id."""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"

        if chat_stream.group_name:
            return f"{chat_stream.group_name}"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊"
        else:
            return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        return f"查询失败 ({chat_id[:8]}...)"


def text_similarity(text1: str, text2: str) -> float:
    """
    Compute the similarity of two texts.
    Uses SequenceMatcher and returns a value between 0 and 1.
    The words "使用" and "句式" are removed before comparison.
    """
    if not text1 or not text2:
        return 0.0

    # Remove the words "使用" and "句式"
    def remove_ignored_words(text: str) -> str:
        """Strip the ignored words."""
        text = text.replace("使用", "")
        text = text.replace("句式", "")
        return text.strip()

    cleaned_text1 = remove_ignored_words(text1)
    cleaned_text2 = remove_ignored_words(text2)

    # If a text is empty after cleaning, return 0
    if not cleaned_text1 or not cleaned_text2:
        return 0.0

    return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio()
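A worked example of the word-stripping above (the values follow from SequenceMatcher's ratio = 2*M/T, where M is the number of matching characters and T the total length of both strings):

# "使用幽默的句式" → after stripping "使用" and "句式" → "幽默的" (3 chars)
# "幽默"          → unchanged                      → "幽默"   (2 chars)
# M = 2 matching chars, T = 5, so ratio = 2*2/5 = 0.8;
# without the stripping it would be 2*2/9 ≈ 0.44.
print(text_similarity("使用幽默的句式", "幽默"))  # 0.8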
|
||||
|
||||
|
||||
def find_similar_pairs(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
max_pairs: int = None
|
||||
) -> List[Tuple[int, int, float, str, str]]:
|
||||
"""
|
||||
找出相似的expression对
|
||||
|
||||
Args:
|
||||
expressions: Expression对象列表
|
||||
field_name: 要比较的字段名 ('situation' 或 'style')
|
||||
threshold: 相似度阈值 (0-1)
|
||||
max_pairs: 最多返回的对数,None表示返回所有
|
||||
|
||||
Returns:
|
||||
List of (index1, index2, similarity, text1, text2) tuples
|
||||
"""
|
||||
similar_pairs = []
|
||||
n = len(expressions)
|
||||
|
||||
print(f"正在分析 {field_name} 字段的相似度...")
|
||||
print(f"总共需要比较 {n * (n - 1) // 2} 对...")
|
||||
|
||||
for i in range(n):
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f" 已处理 {i + 1}/{n} 个项目...")
|
||||
|
||||
expr1 = expressions[i]
|
||||
text1 = getattr(expr1, field_name, "")
|
||||
|
||||
for j in range(i + 1, n):
|
||||
expr2 = expressions[j]
|
||||
text2 = getattr(expr2, field_name, "")
|
||||
|
||||
similarity = text_similarity(text1, text2)
|
||||
|
||||
if similarity >= threshold:
|
||||
similar_pairs.append((i, j, similarity, text1, text2))
|
||||
|
||||
# 按相似度降序排序
|
||||
similar_pairs.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
if max_pairs:
|
||||
similar_pairs = similar_pairs[:max_pairs]
|
||||
|
||||
return similar_pairs
|
||||
|
||||
|
||||
def group_similar_items(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
chat_id_groups: dict[str, set[str]]
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
将相似的expression分组(仅比较相同chat_id或同组的项目)
|
||||
|
||||
Args:
|
||||
expressions: Expression对象列表
|
||||
field_name: 要比较的字段名 ('situation' 或 'style')
|
||||
threshold: 相似度阈值 (0-1)
|
||||
chat_id_groups: chat_id到相关chat_id集合的映射
|
||||
|
||||
Returns:
|
||||
List of groups, each group is a list of indices
|
||||
"""
|
||||
n = len(expressions)
|
||||
# 使用并查集的思想来分组
|
||||
parent = list(range(n))
|
||||
|
||||
def find(x):
|
||||
if parent[x] != x:
|
||||
parent[x] = find(parent[x])
|
||||
return parent[x]
|
||||
|
||||
def union(x, y):
|
||||
px, py = find(x), find(y)
|
||||
if px != py:
|
||||
parent[px] = py
|
||||
|
||||
print(f"正在对 {field_name} 字段进行分组(仅比较相同chat_id或同组的项目)...")
|
||||
|
||||
# 统计需要比较的对数
|
||||
total_pairs = 0
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
|
||||
total_pairs += 1
|
||||
|
||||
print(f"总共需要比较 {total_pairs} 对(已过滤不同chat_id且不同组的项目)...")
|
||||
|
||||
compared_pairs = 0
|
||||
for i in range(n):
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f" 已处理 {i + 1}/{n} 个项目...")
|
||||
|
||||
expr1 = expressions[i]
|
||||
text1 = getattr(expr1, field_name, "")
|
||||
|
||||
for j in range(i + 1, n):
|
||||
expr2 = expressions[j]
|
||||
|
||||
# 只比较相同chat_id或同组的项目
|
||||
if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
|
||||
continue
|
||||
|
||||
compared_pairs += 1
|
||||
text2 = getattr(expr2, field_name, "")
|
||||
|
||||
similarity = text_similarity(text1, text2)
|
||||
|
||||
if similarity >= threshold:
|
||||
union(i, j)
|
||||
|
||||
# 收集分组
|
||||
groups = defaultdict(list)
|
||||
for i in range(n):
|
||||
root = find(i)
|
||||
groups[root].append(i)
|
||||
|
||||
# 只返回包含多个项目的组
|
||||
result = [group for group in groups.values() if len(group) > 1]
|
||||
result.sort(key=len, reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def print_similarity_analysis(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
chat_id_groups: dict[str, set[str]],
|
||||
show_details: bool = True,
|
||||
max_groups: int = 20
|
||||
):
|
||||
"""打印相似度分析结果"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
|
||||
print("=" * 80)
|
||||
|
||||
# 分组分析
|
||||
groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)
|
||||
|
||||
total_items = len(expressions)
|
||||
similar_items_count = sum(len(group) for group in groups)
|
||||
unique_groups = len(groups)
|
||||
|
||||
print("\n📊 统计信息:")
|
||||
print(f" 总项目数: {total_items}")
|
||||
print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
|
||||
print(f" 相似组数: {unique_groups}")
|
||||
print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0")
|
||||
|
||||
if not groups:
|
||||
print(f"\n未找到相似度 >= {threshold} 的项目组")
|
||||
return
|
||||
|
||||
print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
|
||||
print()
|
||||
|
||||
for group_idx, group in enumerate(groups[:max_groups], 1):
|
||||
print(f"组 {group_idx} (共 {len(group)} 个项目):")
|
||||
|
||||
if show_details:
|
||||
# 显示组内所有项目的详细信息
|
||||
for idx in group:
|
||||
expr = expressions[idx]
|
||||
text = getattr(expr, field_name, "")
|
||||
chat_name = get_chat_name(expr.chat_id)
|
||||
|
||||
# 截断过长的文本
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
|
||||
print(f" [{expr.id}] {display_text}")
|
||||
print(f" 聊天: {chat_name}, Count: {expr.count}")
|
||||
|
||||
# 计算组内平均相似度
|
||||
if len(group) > 1:
|
||||
similarities = []
|
||||
above_threshold_pairs = [] # 存储满足阈值的相似对
|
||||
above_threshold_count = 0
|
||||
for i in range(len(group)):
|
||||
for j in range(i + 1, len(group)):
|
||||
text1 = getattr(expressions[group[i]], field_name, "")
|
||||
text2 = getattr(expressions[group[j]], field_name, "")
|
||||
sim = text_similarity(text1, text2)
|
||||
similarities.append(sim)
|
||||
if sim >= threshold:
|
||||
above_threshold_count += 1
|
||||
# 存储满足阈值的对的信息
|
||||
expr1 = expressions[group[i]]
|
||||
expr2 = expressions[group[j]]
|
||||
display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
|
||||
display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
|
||||
above_threshold_pairs.append((
|
||||
expr1.id, display_text1,
|
||||
expr2.id, display_text2,
|
||||
sim
|
||||
))
|
||||
|
||||
if similarities:
|
||||
avg_sim = sum(similarities) / len(similarities)
|
||||
min_sim = min(similarities)
|
||||
max_sim = max(similarities)
|
||||
above_threshold_ratio = above_threshold_count / len(similarities) * 100
|
||||
print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
|
||||
print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")
|
||||
|
||||
# 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组)
|
||||
if above_threshold_pairs:
|
||||
print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
|
||||
# 按相似度降序排序
|
||||
above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
|
||||
for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]: # 最多显示10对
|
||||
print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}")
|
||||
print(f" \"{text1}\" ↔ \"{text2}\"")
|
||||
if len(above_threshold_pairs) > 10:
|
||||
print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
|
||||
else:
|
||||
print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
|
||||
else:
|
||||
# 只显示组内第一个项目作为示例
|
||||
expr = expressions[group[0]]
|
||||
text = getattr(expr, field_name, "")
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
print(f" 示例: {display_text}")
|
||||
print(f" ... 还有 {len(group) - 1} 个相似项目")
|
||||
|
||||
print()
|
||||
|
||||
if len(groups) > max_groups:
|
||||
print(f"... 还有 {len(groups) - max_groups} 组未显示")
|
||||
|
||||
|
||||
def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
    parser.add_argument(
        "--situation-threshold",
        type=float,
        default=0.7,
        help="situation相似度阈值 (0-1, 默认: 0.7)"
    )
    parser.add_argument(
        "--style-threshold",
        type=float,
        default=0.7,
        help="style相似度阈值 (0-1, 默认: 0.7)"
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计"
    )
    parser.add_argument(
        "--max-groups",
        type=int,
        default=20,
        help="最多显示的组数 (默认: 20)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)"
    )

    args = parser.parse_args()

    # Validate the thresholds
    if not 0 <= args.situation_threshold <= 1:
        print("错误: situation-threshold 必须在 0-1 之间")
        return
    if not 0 <= args.style_threshold <= 1:
        print("错误: style-threshold 必须在 0-1 之间")
        return

    # Determine the output file path
    if args.output:
        output_file = args.output
    else:
        # Auto-generate a timestamped output file
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")

    # Use TeeOutput to write to the console and the file at the same time
    with TeeOutput(output_file) as tee:
        # Temporarily replace sys.stdout
        original_stdout = sys.stdout
        sys.stdout = tee

        try:
            print("=" * 80)
            print("Expression 相似度分析工具")
            print("=" * 80)
            print(f"输出文件: {output_file}")
            print()

            _run_analysis(args)

        finally:
            # Restore the original stdout
            sys.stdout = original_stdout

    print(f"\n✅ 分析结果已保存到: {output_file}")


def _run_analysis(args):
    """Run the main analysis logic."""

    # Load all Expression records from the database
    print("正在从数据库加载Expression数据...")
    try:
        expressions = list(Expression.select())
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return

    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return

    print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    print()

    # Build the chat_id group mapping (based on the expression_groups config)
    print("正在构建chat_id分组映射(根据expression_groups配置)...")
    try:
        chat_id_groups = build_chat_id_groups()
        print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
        if chat_id_groups:
            # Summarize the grouping
            total_related = sum(len(related) for related in chat_id_groups.values())
            avg_related = total_related / len(chat_id_groups)
            print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关(包括自身)")
        print()
    except Exception as e:
        print(f"⚠️ 构建chat_id分组映射失败: {e}")
        print(" 将使用默认行为:只比较相同chat_id的项目")
        chat_id_groups = {}

    # Analyze situation similarity
    print_similarity_analysis(
        expressions,
        "situation",
        args.situation_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups
    )

    # Analyze style similarity
    print_similarity_analysis(
        expressions,
        "style",
        args.style_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups
    )

    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)


if __name__ == "__main__":
    main()

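main() above routes all prints through a TeeOutput helper that is defined elsewhere in the script and not visible in this excerpt. A minimal sketch of what such a tee-style writer could look like, assuming it only needs write/flush plus context-manager support (an illustration, not the project's actual implementation):

import sys


class TeeOutput:
    """Duplicate writes to both the real console and a log file (sketch)."""

    def __init__(self, file_path: str):
        self.file = open(file_path, "w", encoding="utf-8")

    def write(self, text: str) -> None:
        sys.__stdout__.write(text)  # reach the real console even while sys.stdout is replaced
        self.file.write(text)

    def flush(self) -> None:
        sys.__stdout__.flush()
        self.file.flush()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.file.close()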
@@ -1,196 +0,0 @@
import time
import sys
import os
from typing import Dict, List

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression, ChatStreams


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying ChatStreams table directly"""
    try:
        # Query the ChatStreams table directly
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id})"

        # If group info exists, show the group name
        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        # For a private chat, show the user's nickname
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
        else:
            return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"


def calculate_time_distribution(expressions) -> Dict[str, int]:
    """Calculate distribution of last active time in days"""
    now = time.time()
    distribution = {
        "0-1天": 0,
        "1-3天": 0,
        "3-7天": 0,
        "7-14天": 0,
        "14-30天": 0,
        "30-60天": 0,
        "60-90天": 0,
        "90+天": 0,
    }
    for expr in expressions:
        diff_days = (now - expr.last_active_time) / (24 * 3600)
        if diff_days < 1:
            distribution["0-1天"] += 1
        elif diff_days < 3:
            distribution["1-3天"] += 1
        elif diff_days < 7:
            distribution["3-7天"] += 1
        elif diff_days < 14:
            distribution["7-14天"] += 1
        elif diff_days < 30:
            distribution["14-30天"] += 1
        elif diff_days < 60:
            distribution["30-60天"] += 1
        elif diff_days < 90:
            distribution["60-90天"] += 1
        else:
            distribution["90+天"] += 1
    return distribution


def calculate_count_distribution(expressions) -> Dict[str, int]:
    """Calculate distribution of count values"""
    distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
    for expr in expressions:
        cnt = expr.count
        if cnt < 1:
            distribution["0-1"] += 1
        elif cnt < 2:
            distribution["1-2"] += 1
        elif cnt < 3:
            distribution["2-3"] += 1
        elif cnt < 4:
            distribution["3-4"] += 1
        elif cnt < 5:
            distribution["4-5"] += 1
        elif cnt < 10:
            distribution["5-10"] += 1
        else:
            distribution["10+"] += 1
    return distribution


def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
    """Get top N most used expressions for a specific chat_id"""
    return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)


def show_overall_statistics(expressions, total: int) -> None:
    """Show overall statistics"""
    time_dist = calculate_time_distribution(expressions)
    count_dist = calculate_count_distribution(expressions)

    print("\n=== 总体统计 ===")
    print(f"总表达式数量: {total}")

    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        print(f"{period}: {count} ({count / total * 100:.2f}%)")

    print("\ncount分布:")
    for range_, count in count_dist.items():
        print(f"{range_}: {count} ({count / total * 100:.2f}%)")


def show_chat_statistics(chat_id: str, chat_name: str) -> None:
    """Show statistics for a specific chat"""
    chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
    chat_total = len(chat_exprs)

    print(f"\n=== {chat_name} ===")
    print(f"表达式数量: {chat_total}")

    if chat_total == 0:
        print("该聊天没有表达式数据")
        return

    # Time distribution for this chat
    time_dist = calculate_time_distribution(chat_exprs)
    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        if count > 0:
            print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")

    # Count distribution for this chat
    count_dist = calculate_count_distribution(chat_exprs)
    print("\ncount分布:")
    for range_, count in count_dist.items():
        if count > 0:
            print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")

    # Top expressions
    print("\nTop 10使用最多的表达式:")
    top_exprs = get_top_expressions_by_chat(chat_id, 10)
    for i, expr in enumerate(top_exprs, 1):
        print(f"{i}. [{expr.type}] Count: {expr.count}")
        print(f" Situation: {expr.situation}")
        print(f" Style: {expr.style}")
        print()


def interactive_menu() -> None:
    """Interactive menu for expression statistics"""
    # Get all expressions
    expressions = list(Expression.select())
    if not expressions:
        print("数据库中没有找到表达式")
        return

    total = len(expressions)

    # Get unique chat_ids and their names
    chat_ids = list(set(expr.chat_id for expr in expressions))
    chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
    chat_info.sort(key=lambda x: x[1])  # Sort by chat name

    while True:
        print("\n" + "=" * 50)
        print("表达式统计分析")
        print("=" * 50)
        print("0. 显示总体统计")

        for i, (chat_id, chat_name) in enumerate(chat_info, 1):
            chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
            print(f"{i}. {chat_name} ({chat_count}个表达式)")

        print("q. 退出")

        choice = input("\n请选择要查看的统计 (输入序号): ").strip()

        if choice.lower() == "q":
            print("再见!")
            break

        try:
            choice_num = int(choice)
            if choice_num == 0:
                show_overall_statistics(expressions, total)
            elif 1 <= choice_num <= len(chat_info):
                chat_id, chat_name = chat_info[choice_num - 1]
                show_chat_statistics(chat_id, chat_name)
            else:
                print("无效的选择,请重新输入")
        except ValueError:
            print("请输入有效的数字")

        input("\n按回车键继续...")


if __name__ == "__main__":
    interactive_menu()
File diff suppressed because it is too large
@@ -616,107 +616,6 @@ class DefaultReplyer:
            logger.error(f"上下文黑话解释失败: {e}")
            return ""

    def build_chat_history_prompts(
        self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
    ) -> str:
        """Build the background dialogue prompt from recent history.

        Args:
            message_list_before_now: the message history
            target_user_id: the target user's ID (the current conversation partner); currently unused
            sender: the sender's display name; currently unused

        Returns:
            str: the background dialogue prompt
        """
        # Build the background dialogue prompt
        all_dialogue_prompt = ""
        if message_list_before_now:
            latest_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
            all_dialogue_prompt = build_readable_messages(
                latest_msgs,
                replace_bot_name=True,
                timestamp_mode="normal_no_YMD",
                truncate=True,
            )

        return all_dialogue_prompt

    def core_background_build_chat_history_prompts(
        self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
    ) -> Tuple[str, str]:
        """Build the core and background dialogue prompts.

        Args:
            message_list_before_now: the message history
            target_user_id: the target user's ID (the current conversation partner)
            sender: the sender's display name

        Returns:
            Tuple[str, str]: (core dialogue prompt, background dialogue prompt)
        """
        core_dialogue_list: List[DatabaseMessages] = []
        bot_id = str(global_config.bot.qq_account)

        # Filter messages: separate the bot/target-user dialogue from other users' messages
        for msg in message_list_before_now:
            try:
                msg_user_id = str(msg.user_info.user_id)
                reply_to = msg.reply_to
                _platform, reply_to_user_id = self._parse_reply_target(reply_to)
                if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
                    # Dialogue between the bot and the target user
                    core_dialogue_list.append(msg)
            except Exception as e:
                logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")

        # Build the core dialogue prompt
        core_dialogue_prompt = ""
        if core_dialogue_list:
            # Check whether the latest five messages contain a message from the bot itself
            latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
            has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages)

            # logger.info(f"最新五条消息:{latest_5_messages}")
            # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")

            # If the latest five messages contain no bot message, leave the core prompt empty
            if not has_bot_message:
                core_dialogue_prompt = ""
            else:
                core_dialogue_list = core_dialogue_list[
                    -int(global_config.chat.max_context_size * 0.6) :
                ]  # limit the number of messages

                core_dialogue_prompt_str = build_readable_messages(
                    core_dialogue_list,
                    replace_bot_name=True,
                    timestamp_mode="normal_no_YMD",
                    read_mark=0.0,
                    truncate=True,
                    show_actions=True,
                )
                core_dialogue_prompt = f"""--------------------------------
这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解:
{core_dialogue_prompt_str}
--------------------------------
"""

        # Build the background dialogue prompt
        all_dialogue_prompt = ""
        if message_list_before_now:
            latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
            all_dialogue_prompt_str = build_readable_messages(
                latest_25_msgs,
                replace_bot_name=True,
                timestamp_mode="normal_no_YMD",
                truncate=True,
            )
            if core_dialogue_prompt:
                all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
            else:
                all_dialogue_prompt = f"{all_dialogue_prompt_str}"

        return core_dialogue_prompt, all_dialogue_prompt

    async def build_actions_prompt(
        self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
    ) -> str:
@@ -940,6 +839,7 @@ class DefaultReplyer:
            timestamp_mode="relative",
            read_mark=0.0,
            show_actions=True,
            long_time_notice=True,
        )

        # Unified slang-explanation building: choose context mode or Planner mode based on the config
@@ -1047,8 +947,16 @@ class DefaultReplyer:
        else:
            reply_target_block = ""

        # Build the separated dialogue prompt
        dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)

        if message_list_before_now_long:
            latest_msgs = message_list_before_now_long[-int(global_config.chat.max_context_size) :]
            dialogue_prompt = build_readable_messages(
                latest_msgs,
                replace_bot_name=True,
                timestamp_mode="normal_no_YMD",
                truncate=True,
                long_time_notice=True,
            )

        # Fetch the matching extra prompt
        chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
@@ -667,6 +667,7 @@ class PrivateReplyer:
            timestamp_mode="relative",
            read_mark=0.0,
            show_actions=True,
            long_time_notice=True
        )

        message_list_before_short = get_raw_msg_before_timestamp_with_chat(
@@ -370,6 +370,7 @@ def _build_readable_messages_internal(
    show_pic: bool = True,
    message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
    pic_single: bool = False,
    long_time_notice: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
    # sourcery skip: use-getitem-for-re-match-groups
    """
@@ -523,7 +524,30 @@ def _build_readable_messages_internal(
    # 3: format into a string
    output_lines: List[str] = []

    prev_timestamp: Optional[float] = None
    for timestamp, name, content, is_action in detailed_message:
        # Check whether a long-gap notice needs to be inserted
        if long_time_notice and prev_timestamp is not None:
            time_diff = timestamp - prev_timestamp
            time_diff_hours = time_diff / 3600

            # Check whether the gap crosses a day boundary
            prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
            current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
            is_cross_day = prev_date != current_date

            # Insert a notice when the gap exceeds 8 hours or crosses a day
            if time_diff_hours > 8 or is_cross_day:
                # Format the date in Chinese style: xxxx年xx月xx日 (no leading zeros)
                current_time_struct = time.localtime(timestamp)
                year = current_time_struct.tm_year
                month = current_time_struct.tm_mon
                day = current_time_struct.tm_mday
                date_str = f"{year}年{month}月{day}日"
                hours_str = f"{int(time_diff_hours)}h"
                notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
                output_lines.append(notice)

        readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)

        # Look up the message id (if any) and build id_prefix
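Taken on its own, the gap check added above reduces to a small pure function: given the previous and current timestamps, return the notice line when more than 8 hours have passed or the calendar day changed, and None otherwise. A self-contained sketch of the same rule, for illustration only:

import time
from typing import Optional


def long_gap_notice(prev_ts: float, curr_ts: float) -> Optional[str]:
    """Return the long-gap notice line, or None when no notice is needed."""
    time_diff_hours = (curr_ts - prev_ts) / 3600
    is_cross_day = time.strftime("%Y-%m-%d", time.localtime(prev_ts)) != time.strftime(
        "%Y-%m-%d", time.localtime(curr_ts)
    )
    if time_diff_hours <= 8 and not is_cross_day:
        return None
    t = time.localtime(curr_ts)
    date_str = f"{t.tm_year}年{t.tm_mon}月{t.tm_mday}日"  # Chinese date, no leading zeros
    return f"以下聊天开始时间:{date_str}。距离上一条消息过去了{int(time_diff_hours)}h\n"


# A 9-hour gap on the same day triggers the notice; a 10-minute gap does not.
base = time.mktime((2024, 3, 5, 8, 0, 0, 0, 0, -1))
print(long_gap_notice(base, base + 9 * 3600))  # prints the notice line
print(long_gap_notice(base, base + 600))       # prints None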
@@ -536,6 +560,8 @@ def _build_readable_messages_internal(
        else:
            output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
        output_lines.append("\n")  # add a newline after each message block to keep the output readable

        prev_timestamp = timestamp

    formatted_string = "".join(output_lines).strip()
@@ -651,6 +677,7 @@ async def build_readable_messages_with_list(
        show_pic=True,
        message_id_list=None,
        pic_single=pic_single,
        long_time_notice=False,
    )

    if not pic_single:
@@ -704,6 +731,7 @@ def build_readable_messages(
    message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
    remove_emoji_stickers: bool = False,
    pic_single: bool = False,
    long_time_notice: bool = False,
) -> str:  # sourcery skip: extract-method
    """
    Convert a list of messages into a readable text format.
@@ -719,6 +747,7 @@ def build_readable_messages(
        truncate: whether to truncate long messages
        show_actions: whether to show action records
        remove_emoji_stickers: whether to remove emoji stickers and filter out empty messages
        long_time_notice: whether to insert a time notice when messages are separated by a long gap (>8 hours) or cross a day boundary
    """
    # WIP HERE and BELOW ----------------------------------------------
    # Create a deep copy of messages to avoid modifying the original list
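Call sites opt in per invocation, as the replyer hunks above show. A hypothetical usage sketch (msgs stands in for a List[DatabaseMessages] loaded elsewhere; only the keyword arguments shown are taken from this diff):

# Hypothetical call site; msgs is assumed to be a List[DatabaseMessages].
readable = build_readable_messages(
    msgs,
    replace_bot_name=True,
    timestamp_mode="normal_no_YMD",
    truncate=True,
    long_time_notice=True,  # insert "以下聊天开始时间:..." notices at long gaps
)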
@@ -812,6 +841,7 @@ def build_readable_messages(
        show_pic=show_pic,
        message_id_list=message_id_list,
        pic_single=pic_single,
        long_time_notice=long_time_notice,
    )

    if not pic_single:
@@ -839,6 +869,7 @@ def build_readable_messages(
        show_pic=show_pic,
        message_id_list=message_id_list,
        pic_single=pic_single,
        long_time_notice=long_time_notice,
    )
    formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
        messages_after_mark,
@@ -850,6 +881,7 @@ def build_readable_messages(
        show_pic=show_pic,
        message_id_list=message_id_list,
        pic_single=pic_single,
        long_time_notice=long_time_notice,
    )

    read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"