From 99665e79182df5da809ae099d741d6e775cc868c Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 27 Dec 2025 17:20:11 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E6=96=B0=E5=A2=9E=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E8=A1=A8=E8=BE=BE=E4=BC=98=E5=8C=96=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E8=A1=A8=E8=BE=BE=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=E7=9A=84=E6=8F=90=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../evaluate_expressions_count_analysis.py | 556 +++++++++++++++++ src/bw_learner/expression_auto_check_task.py | 235 ++++++++ src/bw_learner/expression_learner.py | 558 ++++++------------ src/bw_learner/expression_reflector.py | 4 +- src/bw_learner/expression_selector.py | 20 +- src/bw_learner/learner_utils.py | 184 ++++-- src/bw_learner/message_recorder.py | 35 +- src/chat/heart_flow/heartFC_chat.py | 2 +- src/common/database/database_model.py | 5 - src/config/official_configs.py | 29 +- src/hippo_memorizer/memory_forget_task.py | 362 ------------ src/main.py | 10 +- .../chat_history_summarizer.py | 0 template/bot_config_template.toml | 14 +- 14 files changed, 1177 insertions(+), 837 deletions(-) create mode 100644 scripts/evaluate_expressions_count_analysis.py create mode 100644 src/bw_learner/expression_auto_check_task.py delete mode 100644 src/hippo_memorizer/memory_forget_task.py rename src/{hippo_memorizer => memory_system}/chat_history_summarizer.py (100%) diff --git a/scripts/evaluate_expressions_count_analysis.py b/scripts/evaluate_expressions_count_analysis.py new file mode 100644 index 00000000..db1f4e71 --- /dev/null +++ b/scripts/evaluate_expressions_count_analysis.py @@ -0,0 +1,556 @@ +""" +表达方式按count分组的LLM评估和统计分析脚本 + +功能: +1. 随机选择50条表达,至少要有20条count>1的项目,然后进行LLM评估 +2. 比较不同count之间的LLM评估合格率是否有显著差异 + - 首先每个count分开比较 + - 然后比较count为1和count大于1的两种 +""" + +import asyncio +import random +import json +import sys +import os +import re +from typing import List, Dict, Set, Tuple +from datetime import datetime +from collections import defaultdict + +# 添加项目根目录到路径 +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, project_root) + +from src.common.database.database_model import Expression +from src.common.database.database import db +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config + +logger = get_logger("expression_evaluator_count_analysis_llm") + +# 评估结果文件路径 +TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp") +COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.json") + + +def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]: + """ + 加载已有的评估结果 + + Returns: + (已有结果列表, 已评估的项目(situation, style)元组集合) + """ + if not os.path.exists(COUNT_ANALYSIS_FILE): + return [], set() + + try: + with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f: + data = json.load(f) + results = data.get("evaluation_results", []) + # 使用 (situation, style) 作为唯一标识 + evaluated_pairs = {(r["situation"], r["style"]) for r in results if "situation" in r and "style" in r} + logger.info(f"已加载 {len(results)} 条已有评估结果") + return results, evaluated_pairs + except Exception as e: + logger.error(f"加载已有评估结果失败: {e}") + return [], set() + + +def save_results(evaluation_results: List[Dict]): + """ + 保存评估结果到文件 + + Args: + evaluation_results: 评估结果列表 + """ + try: + os.makedirs(TEMP_DIR, exist_ok=True) + + data = { + "last_updated": datetime.now().isoformat(), + "total_count": len(evaluation_results), + "evaluation_results": evaluation_results + } + + with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}") + print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)") + except Exception as e: + logger.error(f"保存评估结果失败: {e}") + print(f"\n✗ 保存评估结果失败: {e}") + + +def select_expressions_for_evaluation( + evaluated_pairs: Set[Tuple[str, str]] = None +) -> List[Expression]: + """ + 选择用于评估的表达方式 + 选择所有count>1的项目,然后选择两倍数量的count=1的项目 + + Args: + evaluated_pairs: 已评估的项目集合,用于避免重复 + + Returns: + 选中的表达方式列表 + """ + if evaluated_pairs is None: + evaluated_pairs = set() + + try: + # 查询所有表达方式 + all_expressions = list(Expression.select()) + + if not all_expressions: + logger.warning("数据库中没有表达方式记录") + return [] + + # 过滤出未评估的项目 + unevaluated = [ + expr for expr in all_expressions + if (expr.situation, expr.style) not in evaluated_pairs + ] + + if not unevaluated: + logger.warning("所有项目都已评估完成") + return [] + + # 按count分组 + count_eq1 = [expr for expr in unevaluated if expr.count == 1] + count_gt1 = [expr for expr in unevaluated if expr.count > 1] + + logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条") + + # 选择所有count>1的项目 + selected_count_gt1 = count_gt1.copy() + + # 选择count=1的项目,数量为count>1数量的2倍 + count_gt1_count = len(selected_count_gt1) + count_eq1_needed = count_gt1_count * 2 + + if len(count_eq1) < count_eq1_needed: + logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条") + count_eq1_needed = len(count_eq1) + + # 随机选择count=1的项目 + selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else [] + + selected = selected_count_gt1 + selected_count_eq1 + random.shuffle(selected) # 打乱顺序 + + logger.info(f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)") + + return selected + + except Exception as e: + logger.error(f"选择表达方式失败: {e}") + import traceback + logger.error(traceback.format_exc()) + return [] + + +def create_evaluation_prompt(situation: str, style: str) -> str: + """ + 创建评估提示词 + + Args: + situation: 情境 + style: 风格 + + Returns: + 评估提示词 + """ + prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适: +使用条件或使用情景:{situation} +表达方式或言语风格:{style} + +请从以下方面进行评估: +1. 表达方式或言语风格 是否与使用条件或使用情景 匹配 +2. 允许部分语法错误或口头化或缺省出现 +3. 表达方式不能太过特指,需要具有泛用性 +4. 一般不涉及具体的人名或名称 + +请以JSON格式输出评估结果: +{{ + "suitable": true/false, + "reason": "评估理由(如果不合适,请说明原因)" + +}} +如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。 +请严格按照JSON格式输出,不要包含其他内容。""" + + return prompt + + +async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]: + """ + 执行单次LLM评估 + + Args: + situation: 情境 + style: 风格 + llm: LLM请求实例 + + Returns: + (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 + """ + try: + prompt = create_evaluation_prompt(situation, style) + logger.debug(f"正在评估表达方式: situation={situation}, style={style}") + + response, (reasoning, model_name, _) = await llm.generate_response_async( + prompt=prompt, + temperature=0.6, + max_tokens=1024 + ) + + logger.debug(f"LLM响应: {response}") + + # 解析JSON响应 + try: + evaluation = json.loads(response) + except json.JSONDecodeError as e: + json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) + if json_match: + evaluation = json.loads(json_match.group()) + else: + raise ValueError("无法从响应中提取JSON格式的评估结果") from e + + suitable = evaluation.get("suitable", False) + reason = evaluation.get("reason", "未提供理由") + + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") + return suitable, reason, None + + except Exception as e: + logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") + return False, f"评估过程出错: {str(e)}", str(e) + + +async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict: + """ + 使用LLM评估单个表达方式 + + Args: + expression: 表达方式对象 + llm: LLM请求实例 + + Returns: + 评估结果字典 + """ + logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}") + + suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm) + + if error: + suitable = False + + logger.info(f"评估完成: {'通过' if suitable else '不通过'}") + + return { + "situation": expression.situation, + "style": expression.style, + "count": expression.count, + "suitable": suitable, + "reason": reason, + "error": error, + "evaluator": "llm", + "evaluated_at": datetime.now().isoformat() + } + + +def perform_statistical_analysis(evaluation_results: List[Dict]): + """ + 对评估结果进行统计分析 + + Args: + evaluation_results: 评估结果列表 + """ + if not evaluation_results: + print("\n没有评估结果可供分析") + return + + print("\n" + "=" * 60) + print("统计分析结果") + print("=" * 60) + + # 按count分组统计 + count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0}) + + for result in evaluation_results: + count = result.get("count", 1) + suitable = result.get("suitable", False) + count_groups[count]["total"] += 1 + if suitable: + count_groups[count]["suitable"] += 1 + else: + count_groups[count]["unsuitable"] += 1 + + # 显示每个count的统计 + print("\n【按count分组统计】") + print("-" * 60) + for count in sorted(count_groups.keys()): + group = count_groups[count] + total = group["total"] + suitable = group["suitable"] + unsuitable = group["unsuitable"] + pass_rate = (suitable / total * 100) if total > 0 else 0 + + print(f"Count = {count}:") + print(f" 总数: {total}") + print(f" 通过: {suitable} ({pass_rate:.2f}%)") + print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)") + print() + + # 比较count=1和count>1 + count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0} + count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0} + + for result in evaluation_results: + count = result.get("count", 1) + suitable = result.get("suitable", False) + + if count == 1: + count_eq1_group["total"] += 1 + if suitable: + count_eq1_group["suitable"] += 1 + else: + count_eq1_group["unsuitable"] += 1 + else: + count_gt1_group["total"] += 1 + if suitable: + count_gt1_group["suitable"] += 1 + else: + count_gt1_group["unsuitable"] += 1 + + print("\n【Count=1 vs Count>1 对比】") + print("-" * 60) + + eq1_total = count_eq1_group["total"] + eq1_suitable = count_eq1_group["suitable"] + eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0 + + gt1_total = count_gt1_group["total"] + gt1_suitable = count_gt1_group["suitable"] + gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0 + + print("Count = 1:") + print(f" 总数: {eq1_total}") + print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)") + print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)") + print() + print("Count > 1:") + print(f" 总数: {gt1_total}") + print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)") + print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)") + print() + + # 进行卡方检验(简化版,使用2x2列联表) + if eq1_total > 0 and gt1_total > 0: + print("【统计显著性检验】") + print("-" * 60) + + # 构建2x2列联表 + # 通过 不通过 + # count=1 a b + # count>1 c d + a = eq1_suitable + b = eq1_total - eq1_suitable + c = gt1_suitable + d = gt1_total - gt1_suitable + + # 计算卡方统计量(简化版,使用Pearson卡方检验) + n = eq1_total + gt1_total + if n > 0: + # 期望频数 + e_a = (eq1_total * (a + c)) / n + e_b = (eq1_total * (b + d)) / n + e_c = (gt1_total * (a + c)) / n + e_d = (gt1_total * (b + d)) / n + + # 检查期望频数是否足够大(卡方检验要求每个期望频数>=5) + min_expected = min(e_a, e_b, e_c, e_d) + if min_expected < 5: + print("警告:期望频数小于5,卡方检验可能不准确") + print("建议使用Fisher精确检验") + + # 计算卡方值 + chi_square = 0 + if e_a > 0: + chi_square += ((a - e_a) ** 2) / e_a + if e_b > 0: + chi_square += ((b - e_b) ** 2) / e_b + if e_c > 0: + chi_square += ((c - e_c) ** 2) / e_c + if e_d > 0: + chi_square += ((d - e_d) ** 2) / e_d + + # 自由度 = (行数-1) * (列数-1) = 1 + df = 1 + + # 临界值(α=0.05) + chi_square_critical_005 = 3.841 + chi_square_critical_001 = 6.635 + + print(f"卡方统计量: {chi_square:.4f}") + print(f"自由度: {df}") + print(f"临界值 (α=0.05): {chi_square_critical_005}") + print(f"临界值 (α=0.01): {chi_square_critical_001}") + + if chi_square >= chi_square_critical_001: + print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)") + elif chi_square >= chi_square_critical_005: + print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)") + else: + print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)") + + # 计算差异大小 + diff = abs(eq1_pass_rate - gt1_pass_rate) + print(f"\n合格率差异: {diff:.2f}%") + if diff > 10: + print("差异较大(>10%)") + elif diff > 5: + print("差异中等(5-10%)") + else: + print("差异较小(<5%)") + else: + print("数据不足,无法进行统计检验") + else: + print("数据不足,无法进行count=1和count>1的对比分析") + + # 保存统计分析结果 + analysis_result = { + "analysis_time": datetime.now().isoformat(), + "count_groups": {str(k): v for k, v in count_groups.items()}, + "count_eq1": count_eq1_group, + "count_gt1": count_gt1_group, + "total_evaluated": len(evaluation_results) + } + + try: + analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json") + with open(analysis_file, "w", encoding="utf-8") as f: + json.dump(analysis_result, f, ensure_ascii=False, indent=2) + print(f"\n✓ 统计分析结果已保存到: {analysis_file}") + except Exception as e: + logger.error(f"保存统计分析结果失败: {e}") + + +async def main(): + """主函数""" + logger.info("=" * 60) + logger.info("开始表达方式按count分组的LLM评估和统计分析") + logger.info("=" * 60) + + # 初始化数据库连接 + try: + db.connect(reuse_if_open=True) + logger.info("数据库连接成功") + except Exception as e: + logger.error(f"数据库连接失败: {e}") + return + + # 加载已有评估结果 + existing_results, evaluated_pairs = load_existing_results() + evaluation_results = existing_results.copy() + + if evaluated_pairs: + print(f"\n已加载 {len(existing_results)} 条已有评估结果") + print(f"已评估项目数: {len(evaluated_pairs)}") + + # 检查是否需要继续评估(检查是否还有未评估的count>1项目) + # 先查询未评估的count>1项目数量 + try: + all_expressions = list(Expression.select()) + unevaluated_count_gt1 = [ + expr for expr in all_expressions + if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs + ] + has_unevaluated = len(unevaluated_count_gt1) > 0 + except Exception as e: + logger.error(f"查询未评估项目失败: {e}") + has_unevaluated = False + + if has_unevaluated: + print("\n" + "=" * 60) + print("开始LLM评估") + print("=" * 60) + print("评估结果会自动保存到文件\n") + + # 创建LLM实例 + print("创建LLM实例...") + try: + llm = LLMRequest( + model_set=model_config.model_task_config.tool_use, + request_type="expression_evaluator_count_analysis_llm" + ) + print("✓ LLM实例创建成功\n") + except Exception as e: + logger.error(f"创建LLM实例失败: {e}") + import traceback + logger.error(traceback.format_exc()) + print(f"\n✗ 创建LLM实例失败: {e}") + db.close() + return + + # 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目) + expressions = select_expressions_for_evaluation( + evaluated_pairs=evaluated_pairs + ) + + if not expressions: + print("\n没有可评估的项目") + else: + print(f"\n已选择 {len(expressions)} 条表达方式进行评估") + print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条") + print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n") + + batch_results = [] + for i, expression in enumerate(expressions, 1): + print(f"LLM评估进度: {i}/{len(expressions)}") + print(f" Situation: {expression.situation}") + print(f" Style: {expression.style}") + print(f" Count: {expression.count}") + + llm_result = await llm_evaluate_expression(expression, llm) + + print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}") + if llm_result.get('error'): + print(f" 错误: {llm_result['error']}") + print() + + batch_results.append(llm_result) + # 使用 (situation, style) 作为唯一标识 + evaluated_pairs.add((llm_result["situation"], llm_result["style"])) + + # 添加延迟以避免API限流 + await asyncio.sleep(0.3) + + # 将当前批次结果添加到总结果中 + evaluation_results.extend(batch_results) + + # 保存结果 + save_results(evaluation_results) + else: + print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果") + + # 进行统计分析 + if len(evaluation_results) > 0: + perform_statistical_analysis(evaluation_results) + else: + print("\n没有评估结果可供分析") + + # 关闭数据库连接 + try: + db.close() + logger.info("数据库连接已关闭") + except Exception as e: + logger.warning(f"关闭数据库连接时出错: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/src/bw_learner/expression_auto_check_task.py b/src/bw_learner/expression_auto_check_task.py new file mode 100644 index 00000000..028604c1 --- /dev/null +++ b/src/bw_learner/expression_auto_check_task.py @@ -0,0 +1,235 @@ +""" +表达方式自动检查定时任务 + +功能: +1. 定期随机选取指定数量的表达方式 +2. 使用LLM进行评估 +3. 通过评估的:rejected=0, checked=1 +4. 未通过评估的:rejected=1, checked=1 +""" + +import asyncio +import json +import random +from typing import List + +from src.common.database.database_model import Expression +from src.common.logger import get_logger +from src.config.config import global_config +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.manager.async_task_manager import AsyncTask + +logger = get_logger("expression_auto_check_task") + + +def create_evaluation_prompt(situation: str, style: str) -> str: + """ + 创建评估提示词 + + Args: + situation: 情境 + style: 风格 + + Returns: + 评估提示词 + """ + prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适: +使用条件或使用情景:{situation} +表达方式或言语风格:{style} + +请从以下方面进行评估: +1. 表达方式或言语风格 是否与使用条件或使用情景 匹配 +2. 允许部分语法错误或口头化或缺省出现 +3. 表达方式不能太过特指,需要具有泛用性 +4. 一般不涉及具体的人名或名称 + +请以JSON格式输出评估结果: +{{ + "suitable": true/false, + "reason": "评估理由(如果不合适,请说明原因)" + +}} +如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。 +请严格按照JSON格式输出,不要包含其他内容。""" + + return prompt + +judge_llm = LLMRequest( + model_set=model_config.model_task_config.tool_use, + request_type="expression_check" +) + +async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str]: + """ + 执行单次LLM评估 + + Args: + situation: 情境 + style: 风格 + + Returns: + (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 + """ + try: + prompt = create_evaluation_prompt(situation, style) + logger.debug(f"正在评估表达方式: situation={situation}, style={style}") + + response, (reasoning, model_name, _) = await judge_llm.generate_response_async( + prompt=prompt, + temperature=0.6, + max_tokens=1024 + ) + + logger.debug(f"LLM响应: {response}") + + # 解析JSON响应 + try: + evaluation = json.loads(response) + except json.JSONDecodeError as e: + import re + json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) + if json_match: + evaluation = json.loads(json_match.group()) + else: + raise ValueError("无法从响应中提取JSON格式的评估结果") from e + + suitable = evaluation.get("suitable", False) + reason = evaluation.get("reason", "未提供理由") + + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") + return suitable, reason, None + + except Exception as e: + logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") + return False, f"评估过程出错: {str(e)}", str(e) + + +class ExpressionAutoCheckTask(AsyncTask): + """表达方式自动检查定时任务""" + + def __init__(self): + # 从配置中获取检查间隔和一次检查数量 + check_interval = global_config.expression.expression_auto_check_interval + super().__init__( + task_name="Expression Auto Check Task", + wait_before_start=60, # 启动后等待60秒再开始第一次检查 + run_interval=check_interval + ) + + async def _select_expressions(self, count: int) -> List[Expression]: + """ + 随机选择指定数量的未检查表达方式 + + Args: + count: 需要选择的数量 + + Returns: + 选中的表达方式列表 + """ + try: + # 查询所有未检查的表达方式(checked=False) + unevaluated_expressions = list( + Expression.select().where(~Expression.checked) + ) + + if not unevaluated_expressions: + logger.info("没有未检查的表达方式") + return [] + + # 随机选择指定数量 + selected_count = min(count, len(unevaluated_expressions)) + selected = random.sample(unevaluated_expressions, selected_count) + + logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条") + return selected + + except Exception as e: + logger.error(f"选择表达方式时出错: {e}") + return [] + + async def _evaluate_expression(self, expression: Expression) -> bool: + """ + 评估单个表达方式 + + Args: + expression: 要评估的表达方式 + + Returns: + True表示通过,False表示不通过 + """ + + suitable, reason, error = await single_expression_check( + expression.situation, + expression.style, + ) + + # 更新数据库 + try: + expression.checked = True + expression.rejected = not suitable # 通过则rejected=0,不通过则rejected=1 + expression.save() + + status = "通过" if suitable else "不通过" + logger.info( + f"表达方式评估完成 [ID: {expression.id}] - {status} | " + f"Situation: {expression.situation}... | " + f"Style: {expression.style}... | " + f"Reason: {reason[:50]}..." + ) + + if error: + logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}") + + return suitable + + except Exception as e: + logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}") + return False + + async def run(self): + """执行检查任务""" + try: + # 检查是否启用自动检查 + if not global_config.expression.expression_self_reflect: + logger.debug("表达方式自动检查未启用,跳过本次执行") + return + + check_count = global_config.expression.expression_auto_check_count + if check_count <= 0: + logger.warning(f"检查数量配置无效: {check_count},跳过本次执行") + return + + logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条") + + + # 选择要检查的表达方式 + expressions = await self._select_expressions(check_count) + + if not expressions: + logger.info("没有需要检查的表达方式") + return + + # 逐个评估 + passed_count = 0 + failed_count = 0 + + for i, expression in enumerate(expressions, 1): + logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}") + + if await self._evaluate_expression(expression): + passed_count += 1 + else: + failed_count += 1 + + # 避免请求过快 + await asyncio.sleep(0.3) + + logger.info( + f"表达方式自动检查完成: 总计 {len(expressions)} 条," + f"通过 {passed_count} 条,不通过 {failed_count} 条" + ) + + except Exception as e: + logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True) + diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index 52e53b2d..1e19ee46 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -18,10 +18,13 @@ from src.bw_learner.learner_utils import ( is_bot_message, build_context_paragraph, contains_bot_self_name, - calculate_style_similarity, + calculate_similarity, + parse_expression_response, ) from src.bw_learner.jargon_miner import miner_manager -from json_repair import repair_json +from src.bw_learner.expression_auto_check_task import ( + single_expression_check, +) # MAX_EXPRESSION_COUNT = 300 @@ -91,6 +94,7 @@ class ExpressionLearner: self.summary_model: LLMRequest = LLMRequest( model_set=model_config.model_task_config.utils, request_type="expression.summary" ) + self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化 self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id @@ -136,11 +140,10 @@ class ExpressionLearner: # 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号) expressions: List[Tuple[str, str, str]] jargon_entries: List[Tuple[str, str]] # (content, source_id) - expressions, jargon_entries = self.parse_expression_response(response) - expressions = self._filter_self_reference_styles(expressions) + expressions, jargon_entries = parse_expression_response(response) # 检查表达方式数量,如果超过10个则放弃本次表达学习 - if len(expressions) > 10: + if len(expressions) > 20: logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习") expressions = [] @@ -155,7 +158,7 @@ class ExpressionLearner: # 如果没有表达方式,直接返回 if not expressions: - logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)") + logger.info("解析后没有可用的表达方式") return [] logger.info(f"学习的prompt: {prompt}") @@ -163,9 +166,60 @@ class ExpressionLearner: logger.info(f"学习的jargon_entries: {jargon_entries}") logger.info(f"学习的response: {response}") - # 直接根据 source_id 在 random_msg 中溯源,获取 context + # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则 + learnt_expressions = self._filter_expressions(expressions, random_msg) + + if learnt_expressions is None: + logger.info("没有学习到表达风格") + return [] + + # 展示学到的表达方式 + learnt_expressions_str = "" + for (situation,style) in learnt_expressions: + learnt_expressions_str += f"{situation}->{style}\n" + logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") + + current_time = time.time() + + # 存储到数据库 Expression 表 + for (situation,style) in learnt_expressions: + await self._upsert_expression_record( + situation=situation, + style=style, + current_time=current_time, + ) + + return learnt_expressions + + def _filter_expressions( + self, + expressions: List[Tuple[str, str, str]], + messages: List[Any], + ) -> List[Tuple[str, str, str]]: + """ + 过滤表达方式,移除不符合条件的条目 + + Args: + expressions: 表达方式列表,每个元素是 (situation, style, source_id) + messages: 原始消息列表,用于溯源和验证 + + Returns: + 过滤后的表达方式列表,每个元素是 (situation, style, context) + """ filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context) + # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达) + banned_names = set() + bot_nickname = (global_config.bot.nickname or "").strip() + if bot_nickname: + banned_names.add(bot_nickname) + alias_names = global_config.bot.alias_names or [] + for alias in alias_names: + alias = alias.strip() + if alias: + banned_names.add(alias) + banned_casefold = {name.casefold() for name in banned_names if name} + for situation, style, source_id in expressions: source_id_str = (source_id or "").strip() if not source_id_str.isdigit(): @@ -173,12 +227,12 @@ class ExpressionLearner: continue line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始 - if line_index < 0 or line_index >= len(random_msg): + if line_index < 0 or line_index >= len(messages): # 超出范围,跳过 continue # 当前行的原始内容 - current_msg = random_msg[line_index] + current_msg = messages[line_index] # 过滤掉从bot自己发言中提取到的表达方式 if is_bot_message(current_msg): @@ -195,251 +249,53 @@ class ExpressionLearner: ) continue - filtered_expressions.append((situation, style, context)) - - learnt_expressions = filtered_expressions - - if learnt_expressions is None: - logger.info("没有学习到表达风格") - return [] - - # 展示学到的表达方式 - learnt_expressions_str = "" - for ( - situation, - style, - _context, - ) in learnt_expressions: - learnt_expressions_str += f"{situation}->{style}\n" - logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") - - current_time = time.time() - - # 存储到数据库 Expression 表 - for ( - situation, - style, - context, - ) in learnt_expressions: - await self._upsert_expression_record( - situation=situation, - style=style, - context=context, - current_time=current_time, - ) - - return learnt_expressions - - def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: - """ - 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 - - 期望的 JSON 结构: - [ - {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 - {"content": "词条", "source_id": "12"}, // 黑话 - ... - ] - - Returns: - Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: - 第一个列表是表达方式 (situation, style, source_id) - 第二个列表是黑话 (content, source_id) - """ - if not response: - return [], [] - - raw = response.strip() - - # 尝试提取 ```json 代码块 - json_block_pattern = r"```json\s*(.*?)\s*```" - match = re.search(json_block_pattern, raw, re.DOTALL) - if match: - raw = match.group(1).strip() - else: - # 去掉可能存在的通用 ``` 包裹 - raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) - raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) - raw = raw.strip() - - parsed = None - expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) - jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) - - try: - # 优先尝试直接解析 - if raw.startswith("[") and raw.endswith("]"): - parsed = json.loads(raw) - else: - repaired = repair_json(raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - else: - parsed = repaired - except Exception as parse_error: - # 如果解析失败,尝试修复中文引号问题 - # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 - try: - - def fix_chinese_quotes_in_json(text): - """使用状态机修复 JSON 字符串值中的中文引号""" - result = [] - i = 0 - in_string = False - escape_next = False - - while i < len(text): - char = text[i] - - if escape_next: - # 当前字符是转义字符后的字符,直接添加 - result.append(char) - escape_next = False - i += 1 - continue - - if char == "\\": - # 转义字符 - result.append(char) - escape_next = True - i += 1 - continue - - if char == '"' and not escape_next: - # 遇到英文引号,切换字符串状态 - in_string = not in_string - result.append(char) - i += 1 - continue - - if in_string: - # 在字符串值内部,将中文引号替换为转义的英文引号 - if char == '"': # 中文左引号 U+201C - result.append('\\"') - elif char == '"': # 中文右引号 U+201D - result.append('\\"') - else: - result.append(char) - else: - # 不在字符串内,直接添加 - result.append(char) - - i += 1 - - return "".join(result) - - fixed_raw = fix_chinese_quotes_in_json(raw) - - # 再次尝试解析 - if fixed_raw.startswith("[") and fixed_raw.endswith("]"): - parsed = json.loads(fixed_raw) - else: - repaired = repair_json(fixed_raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - else: - parsed = repaired - except Exception as fix_error: - logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") - logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") - logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") - logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") - return [], [] - - if isinstance(parsed, dict): - parsed_list = [parsed] - elif isinstance(parsed, list): - parsed_list = parsed - else: - logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") - return [], [] - - for item in parsed_list: - if not isinstance(item, dict): + # 过滤掉 style 与机器人名称/昵称重复的表达 + normalized_style = (style or "").strip() + if normalized_style and normalized_style.casefold() in banned_casefold: + logger.debug( + f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}" + ) continue - # 检查是否是表达方式条目(有 situation 和 style) - situation = str(item.get("situation", "")).strip() - style = str(item.get("style", "")).strip() - source_id = str(item.get("source_id", "")).strip() + # 过滤掉包含 "表情:" 或 "表情:" 的内容 + if "表情:" in (situation or "") or "表情:" in (situation or "") or \ + "表情:" in (style or "") or "表情:" in (style or "") or \ + "表情:" in context or "表情:" in context: + logger.info( + f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}" + ) + continue - if situation and style and source_id: - # 表达方式条目 - expressions.append((situation, style, source_id)) - elif item.get("content"): - # 黑话条目(有 content 字段) - content = str(item.get("content", "")).strip() - source_id = str(item.get("source_id", "")).strip() - if content and source_id: - jargon_entries.append((content, source_id)) + # 过滤掉包含 "[图片" 的内容 + if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context: + logger.info( + f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}" + ) + continue - return expressions, jargon_entries + filtered_expressions.append((situation, style)) - def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]: - """ - 过滤掉style与机器人名称/昵称重复的表达 - """ - banned_names = set() - bot_nickname = (global_config.bot.nickname or "").strip() - if bot_nickname: - banned_names.add(bot_nickname) - - alias_names = global_config.bot.alias_names or [] - for alias in alias_names: - alias = alias.strip() - if alias: - banned_names.add(alias) - - banned_casefold = {name.casefold() for name in banned_names if name} - - filtered: List[Tuple[str, str, str]] = [] - removed_count = 0 - for situation, style, source_id in expressions: - normalized_style = (style or "").strip() - if normalized_style and normalized_style.casefold() not in banned_casefold: - filtered.append((situation, style, source_id)) - else: - removed_count += 1 - - if removed_count: - logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式") - - return filtered + return filtered_expressions async def _upsert_expression_record( self, situation: str, style: str, - context: str, current_time: float, ) -> None: - # 第一层:检查是否有完全一致的 style(检查 style 字段和 style_list) - expr_obj = await self._find_exact_style_match(style) + # 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list) + # 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理 + expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75) if expr_obj: - # 找到完全匹配的 style,合并到现有记录(不使用 LLM 总结) + # 根据相似度决定是否使用 LLM 总结 + # 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结 + use_llm_summary = similarity < 1.0 await self._update_existing_expression( expr_obj=expr_obj, situation=situation, - style=style, - context=context, current_time=current_time, - use_llm_summary=False, - ) - return - - # 第二层:检查是否有相似的 style(相似度 >= 0.75,检查 style 字段和 style_list) - similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75) - - if similar_expr_obj: - # 找到相似的 style,合并到现有记录(使用 LLM 总结) - await self._update_existing_expression( - expr_obj=similar_expr_obj, - situation=situation, - style=style, - context=context, - current_time=current_time, - use_llm_summary=True, + use_llm_summary=use_llm_summary, ) return @@ -447,7 +303,6 @@ class ExpressionLearner: await self._create_expression_record( situation=situation, style=style, - context=context, current_time=current_time, ) @@ -455,7 +310,6 @@ class ExpressionLearner: self, situation: str, style: str, - context: str, current_time: float, ) -> None: content_list = [situation] @@ -466,26 +320,22 @@ class ExpressionLearner: situation=formatted_situation, style=style, content_list=json.dumps(content_list, ensure_ascii=False), - style_list=None, # 新记录初始时 style_list 为空 count=1, last_active_time=current_time, chat_id=self.chat_id, create_date=current_time, - context=context, ) async def _update_existing_expression( self, expr_obj: Expression, situation: str, - style: str, - context: str, current_time: float, use_llm_summary: bool = True, ) -> None: """ - 更新现有 Expression 记录(style 完全匹配或相似的情况) - 将新的 situation 添加到 content_list,将新的 style 添加到 style_list(如果不同) + 更新现有 Expression 记录(situation 完全匹配或相似的情况) + 将新的 situation 添加到 content_list,不合并 style Args: use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True @@ -495,43 +345,24 @@ class ExpressionLearner: content_list.append(situation) expr_obj.content_list = json.dumps(content_list, ensure_ascii=False) - # 更新 style_list(如果 style 不同,添加到 style_list) - style_list = self._parse_style_list(expr_obj.style_list) - # 将原有的 style 也加入 style_list(如果还没有的话) - if expr_obj.style and expr_obj.style not in style_list: - style_list.append(expr_obj.style) - # 如果新的 style 不在 style_list 中,添加它 - if style not in style_list: - style_list.append(style) - expr_obj.style_list = json.dumps(style_list, ensure_ascii=False) - # 更新其他字段 expr_obj.count = (expr_obj.count or 0) + 1 + expr_obj.checked = False # count 增加时重置 checked 为 False expr_obj.last_active_time = current_time - expr_obj.context = context if use_llm_summary: - # 相似匹配时,使用 LLM 重新组合 situation 和 style + # 相似匹配时,使用 LLM 重新组合 situation new_situation = await self._compose_situation_text( content_list=content_list, - count=expr_obj.count, fallback=expr_obj.situation, ) expr_obj.situation = new_situation - new_style = await self._compose_style_text( - style_list=style_list, - count=expr_obj.count, - fallback=expr_obj.style or style, - ) - expr_obj.style = new_style - else: - # 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变 - # 只更新 content_list 和 style_list - pass - expr_obj.save() + # count 增加后,立即进行一次检查 + await self._check_expression_immediately(expr_obj) + def _parse_content_list(self, stored_list: Optional[str]) -> List[str]: if not stored_list: return [] @@ -541,49 +372,19 @@ class ExpressionLearner: return [] return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else [] - def _parse_style_list(self, stored_list: Optional[str]) -> List[str]: - """解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同""" - if not stored_list: - return [] - try: - data = json.loads(stored_list) - except json.JSONDecodeError: - return [] - return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else [] - - async def _find_exact_style_match(self, style: str) -> Optional[Expression]: + async def _find_similar_situation_expression(self, situation: str, similarity_threshold: float = 0.75) -> Tuple[Optional[Expression], float]: """ - 查找具有完全匹配 style 的 Expression 记录 - 只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述) + 查找具有相似 situation 的 Expression 记录 + 检查 content_list 中的每一项 Args: - style: 要查找的 style - - Returns: - 找到的 Expression 对象,如果没有找到则返回 None - """ - # 查询同一 chat_id 的所有记录 - all_expressions = Expression.select().where(Expression.chat_id == self.chat_id) - - for expr in all_expressions: - # 只检查 style_list 中的每一项 - style_list = self._parse_style_list(expr.style_list) - if style in style_list: - return expr - - return None - - async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]: - """ - 查找具有相似 style 的 Expression 记录 - 只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述) - - Args: - style: 要查找的 style + situation: 要查找的 situation similarity_threshold: 相似度阈值,默认 0.75 Returns: - 找到的最相似的 Expression 对象,如果没有找到则返回 None + Tuple[Optional[Expression], float]: + - 找到的最相似的 Expression 对象,如果没有找到则返回 None + - 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间) """ # 查询同一 chat_id 的所有记录 all_expressions = Expression.select().where(Expression.chat_id == self.chat_id) @@ -592,96 +393,28 @@ class ExpressionLearner: best_similarity = 0.0 for expr in all_expressions: - # 只检查 style_list 中的每一项 - style_list = self._parse_style_list(expr.style_list) - for existing_style in style_list: - similarity = calculate_style_similarity(style, existing_style) + # 检查 content_list 中的每一项 + content_list = self._parse_content_list(expr.content_list) + for existing_situation in content_list: + similarity = calculate_similarity(situation, existing_situation) if similarity >= similarity_threshold and similarity > best_similarity: best_similarity = similarity best_match = expr if best_match: - logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'") + logger.debug(f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'") - return best_match + return best_match, best_similarity - async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str: + async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str: sanitized = [c.strip() for c in content_list if c.strip()] - summary = await self._summarize_situations(sanitized) - if summary: - return summary - return "/".join(sanitized) if sanitized else fallback - - async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str: - """ - 组合 style 文本,如果 style_list 有多个元素则尝试总结 - """ - sanitized = [s.strip() for s in style_list if s.strip()] - if len(sanitized) > 1: - # 只有当有多个 style 时才尝试总结 - summary = await self._summarize_styles(sanitized) - if summary: - return summary - # 如果只有一个或总结失败,返回第一个或 fallback - return sanitized[0] if sanitized else fallback - - async def _summarize_styles(self, styles: List[str]) -> Optional[str]: - """总结多个 style,生成一个概括性的 style 描述""" - if not styles or len(styles) <= 1: - return None - - # 计算输入列表中最长项目的长度 - max_input_length = max(len(s) for s in styles) if styles else 0 - max_summary_length = max_input_length * 2 - - # 最多重试3次 - max_retries = 3 - retry_count = 0 - - while retry_count < max_retries: - # 如果是重试,在 prompt 中强调要更简洁 - length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字," - - prompt = ( - "请阅读以下多个语言风格/表达方式,对其进行总结。" - "不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文" - f"{length_hint}保留共同特点:\n" - f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容" - ) - - try: - summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2) - summary = summary.strip() - if summary: - # 检查总结长度是否超过限制 - if len(summary) <= max_summary_length: - return summary - else: - retry_count += 1 - logger.debug( - f"总结长度 {len(summary)} 超过限制 {max_summary_length} " - f"(输入最长项长度: {max_input_length}),重试第 {retry_count} 次" - ) - continue - except Exception as e: - logger.error(f"概括表达风格失败: {e}") - return None - - # 如果重试多次后仍然超过长度,返回 None(不进行总结) - logger.warning( - f"总结多次后仍超过长度限制,放弃总结。" - f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}" - ) - return None - - async def _summarize_situations(self, situations: List[str]) -> Optional[str]: - if not situations: - return None + if not sanitized: + return fallback prompt = ( "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话," "长度不超过20个字,保留共同特点:\n" - f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。" + f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。" ) try: @@ -691,7 +424,64 @@ class ExpressionLearner: return summary except Exception as e: logger.error(f"概括表达情境失败: {e}") - return None + return "/".join(sanitized) if sanitized else fallback + + async def _init_check_model(self) -> None: + """初始化检查用的 LLM 实例""" + if self.check_model is None: + try: + self.check_model = LLMRequest( + model_set=model_config.model_task_config.tool_use, + request_type="expression.check" + ) + logger.debug("检查用 LLM 实例初始化成功") + except Exception as e: + logger.error(f"创建检查用 LLM 实例失败: {e}") + + async def _check_expression_immediately(self, expr_obj: Expression) -> None: + """ + 立即检查表达方式(在 count 增加后调用) + + Args: + expr_obj: 要检查的表达方式对象 + """ + try: + # 检查是否启用自动检查 + if not global_config.expression.expression_self_reflect: + logger.debug("表达方式自动检查未启用,跳过立即检查") + return + + # 初始化检查用的 LLM + await self._init_check_model() + if self.check_model is None: + logger.warning("检查用 LLM 实例初始化失败,跳过立即检查") + return + + # 执行 LLM 评估 + suitable, reason, error = await single_expression_check( + expr_obj.situation, + expr_obj.style + ) + + # 更新数据库 + expr_obj.checked = True + expr_obj.rejected = not suitable # 通过则 rejected=False,不通过则 rejected=True + expr_obj.save() + + status = "通过" if suitable else "不通过" + logger.info( + f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | " + f"Situation: {expr_obj.situation[:30]}... | " + f"Style: {expr_obj.style[:30]}... | " + f"Reason: {reason[:50] if reason else '无'}..." + ) + + if error: + logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}") + + except Exception as e: + logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True) + # 检查失败时,保持 checked=False,等待后续自动检查任务处理 async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: """ diff --git a/src/bw_learner/expression_reflector.py b/src/bw_learner/expression_reflector.py index c627b5b7..c98f0012 100644 --- a/src/bw_learner/expression_reflector.py +++ b/src/bw_learner/expression_reflector.py @@ -28,11 +28,11 @@ class ExpressionReflector: try: logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})") - if not global_config.expression.reflect: + if not global_config.expression.expression_self_reflect: logger.debug("[Expression Reflection] 表达反思功能未启用,跳过") return False - operator_config = global_config.expression.reflect_operator_id + operator_config = global_config.expression.manual_reflect_operator_id if not operator_config: logger.debug("[Expression Reflection] Operator ID 未配置,跳过") return False diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 6768ffa3..457e5610 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -123,9 +123,11 @@ class ExpressionSelector: related_chat_ids = self.get_related_chat_ids(chat_id) # 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的 - style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1) - ) + # 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的 + base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1) + if global_config.expression.expression_checked_only: + base_conditions = base_conditions & (Expression.checked) + style_query = Expression.select().where(base_conditions) style_exprs = [ { @@ -202,7 +204,11 @@ class ExpressionSelector: related_chat_ids = self.get_related_chat_ids(chat_id) # 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达 - style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)) + # 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的 + base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) + if global_config.expression.expression_checked_only: + base_conditions = base_conditions & (Expression.checked) + style_query = Expression.select().where(base_conditions) style_exprs = [ { @@ -295,7 +301,11 @@ class ExpressionSelector: # think_level == 1: 先选高count,再从所有表达方式中随机抽样 # 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的 related_chat_ids = self.get_related_chat_ids(chat_id) - style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)) + # 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的 + base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) + if global_config.expression.expression_checked_only: + base_conditions = base_conditions & (Expression.checked) + style_query = Expression.select().where(base_conditions) all_style_exprs = [ { diff --git a/src/bw_learner/learner_utils.py b/src/bw_learner/learner_utils.py index e871dde7..fabf555e 100644 --- a/src/bw_learner/learner_utils.py +++ b/src/bw_learner/learner_utils.py @@ -2,8 +2,7 @@ import re import difflib import random import json -from datetime import datetime -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Tuple from src.common.logger import get_logger from src.config.config import global_config @@ -11,6 +10,7 @@ from src.chat.utils.chat_message_builder import ( build_readable_messages, ) from src.chat.utils.utils import parse_platform_accounts +from json_repair import repair_json logger = get_logger("learner_utils") @@ -88,33 +88,15 @@ def calculate_style_similarity(style1: str, style2: str) -> float: return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() -def format_create_date(timestamp: float) -> str: - """ - 将时间戳格式化为可读的日期字符串 - - Args: - timestamp: 时间戳 - - Returns: - str: 格式化后的日期字符串 - """ - try: - return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - except (ValueError, OSError): - return "未知时间" - - def _compute_weights(population: List[Dict]) -> List[float]: """ 根据表达的count计算权重,范围限定在1~5之间。 count越高,权重越高,但最多为基础权重的5倍。 - 如果表达已checked,权重会再乘以3倍。 """ if not population: return [] counts = [] - checked_flags = [] for item in population: count = item.get("count", 1) try: @@ -122,29 +104,19 @@ def _compute_weights(population: List[Dict]) -> List[float]: except (TypeError, ValueError): count_value = 1.0 counts.append(max(count_value, 0.0)) - # 获取checked状态 - checked = item.get("checked", False) - checked_flags.append(bool(checked)) min_count = min(counts) max_count = max(counts) if max_count == min_count: - base_weights = [1.0 for _ in counts] + weights = [1.0 for _ in counts] else: - base_weights = [] + weights = [] for count_value in counts: # 线性映射到[1,5]区间 normalized = (count_value - min_count) / (max_count - min_count) - base_weights.append(1.0 + normalized * 4.0) # 1~5 + weights.append(1.0 + normalized * 4.0) # 1~5 - # 如果checked,权重乘以3 - weights = [] - for base_weight, checked in zip(base_weights, checked_flags, strict=False): - if checked: - weights.append(base_weight * 3.0) - else: - weights.append(base_weight) return weights @@ -378,3 +350,149 @@ def is_bot_message(msg: Any) -> bool: bot_account = bot_accounts.get(platform) return bool(bot_account and user_id == bot_account) + + +def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + """ + 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 + + 期望的 JSON 结构: + [ + {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 + {"content": "词条", "source_id": "12"}, // 黑话 + ... + ] + + Returns: + Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + 第一个列表是表达方式 (situation, style, source_id) + 第二个列表是黑话 (content, source_id) + """ + if not response: + return [], [] + + raw = response.strip() + + # 尝试提取 ```json 代码块 + json_block_pattern = r"```json\s*(.*?)\s*```" + match = re.search(json_block_pattern, raw, re.DOTALL) + if match: + raw = match.group(1).strip() + else: + # 去掉可能存在的通用 ``` 包裹 + raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) + raw = raw.strip() + + parsed = None + expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) + jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) + + try: + # 优先尝试直接解析 + if raw.startswith("[") and raw.endswith("]"): + parsed = json.loads(raw) + else: + repaired = repair_json(raw) + if isinstance(repaired, str): + parsed = json.loads(repaired) + else: + parsed = repaired + except Exception as parse_error: + # 如果解析失败,尝试修复中文引号问题 + # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 + try: + + def fix_chinese_quotes_in_json(text): + """使用状态机修复 JSON 字符串值中的中文引号""" + result = [] + i = 0 + in_string = False + escape_next = False + + while i < len(text): + char = text[i] + + if escape_next: + # 当前字符是转义字符后的字符,直接添加 + result.append(char) + escape_next = False + i += 1 + continue + + if char == "\\": + # 转义字符 + result.append(char) + escape_next = True + i += 1 + continue + + if char == '"' and not escape_next: + # 遇到英文引号,切换字符串状态 + in_string = not in_string + result.append(char) + i += 1 + continue + + if in_string: + # 在字符串值内部,将中文引号替换为转义的英文引号 + if char == '"': # 中文左引号 U+201C + result.append('\\"') + elif char == '"': # 中文右引号 U+201D + result.append('\\"') + else: + result.append(char) + else: + # 不在字符串内,直接添加 + result.append(char) + + i += 1 + + return "".join(result) + + fixed_raw = fix_chinese_quotes_in_json(raw) + + # 再次尝试解析 + if fixed_raw.startswith("[") and fixed_raw.endswith("]"): + parsed = json.loads(fixed_raw) + else: + repaired = repair_json(fixed_raw) + if isinstance(repaired, str): + parsed = json.loads(repaired) + else: + parsed = repaired + except Exception as fix_error: + logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") + logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") + logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") + logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") + return [], [] + + if isinstance(parsed, dict): + parsed_list = [parsed] + elif isinstance(parsed, list): + parsed_list = parsed + else: + logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") + return [], [] + + for item in parsed_list: + if not isinstance(item, dict): + continue + + # 检查是否是表达方式条目(有 situation 和 style) + situation = str(item.get("situation", "")).strip() + style = str(item.get("style", "")).strip() + source_id = str(item.get("source_id", "")).strip() + + if situation and style and source_id: + # 表达方式条目 + expressions.append((situation, style, source_id)) + elif item.get("content"): + # 黑话条目(有 content 字段) + content = str(item.get("content", "")).strip() + source_id = str(item.get("source_id", "")).strip() + if content and source_id: + jargon_entries.append((content, source_id)) + + return expressions, jargon_entries \ No newline at end of file diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py index 4d8a5015..fc570909 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder.py @@ -116,20 +116,12 @@ class MessageRecorder: f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" ) - # 分别触发 expression_learner 和 jargon_miner 的处理 - # 传递提取的消息,避免它们重复获取 - # 触发 expression 学习(如果启用) + # 触发 expression_learner 和 jargon_miner 的处理 if self.enable_expression_learning: asyncio.create_task( - self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages) + self._trigger_expression_learning(messages) ) - # 触发 jargon 提取(如果启用),传递消息 - # if self.enable_jargon_learning: - # asyncio.create_task( - # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) - # ) - except Exception as e: logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") import traceback @@ -138,7 +130,7 @@ class MessageRecorder: # 即使失败也保持时间戳更新,避免频繁重试 async def _trigger_expression_learning( - self, timestamp_start: float, timestamp_end: float, messages: List[Any] + self, messages: List[Any] ) -> None: """ 触发 expression 学习,使用指定的消息列表 @@ -162,27 +154,6 @@ class MessageRecorder: traceback.print_exc() - async def _trigger_jargon_extraction( - self, timestamp_start: float, timestamp_end: float, messages: List[Any] - ) -> None: - """ - 触发 jargon 提取,使用指定的消息列表 - - Args: - timestamp_start: 开始时间戳 - timestamp_end: 结束时间戳 - messages: 消息列表 - """ - try: - # 传递消息给 JargonMiner,避免它重复获取 - await self.jargon_miner.run_once(messages=messages) - - except Exception as e: - logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") - import traceback - - traceback.print_exc() - class MessageRecorderManager: """MessageRecorder 管理器""" diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 45e79308..6b53fc04 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -30,7 +30,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, ) from src.chat.utils.utils import record_replyer_action_temp -from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer +from src.memory_system.chat_history_summarizer import ChatHistorySummarizer if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index a59f68e2..4c7dbc47 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -321,12 +321,7 @@ class Expression(BaseModel): situation = TextField() style = TextField() - - # new mode fields - context = TextField(null=True) - content_list = TextField(null=True) - style_list = TextField(null=True) # 存储相似的 style,格式与 content_list 相同(JSON 数组) count = IntegerField(default=1) last_active_time = FloatField() chat_id = TextField(index=True) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 99d47dd5..e02a5afb 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -323,10 +323,13 @@ class ExpressionConfig(ConfigBase): 格式: [["qq:12345:group", "qq:67890:private"]] """ - reflect: bool = False - """是否启用表达反思""" + expression_self_reflect: bool = False + """是否启用自动表达优化""" + + expression_manual_reflect: bool = False + """是否启用手动表达优化""" - reflect_operator_id: str = "" + manual_reflect_operator_id: str = "" """表达反思操作员ID""" allow_reflect: list[str] = field(default_factory=list) @@ -350,6 +353,26 @@ class ExpressionConfig(ConfigBase): - "planner": 仅使用 Planner 在 reply 动作中给出的 unknown_words 列表进行黑话检索 """ + expression_checked_only: bool = False + """ + 是否仅选择已检查且未拒绝的表达方式 + 当设置为 true 时,只有 checked=True 且 rejected=False 的表达方式才会被选择 + 当设置为 false 时,保留旧的筛选原则(仅排除 rejected=True 的表达方式) + """ + + + expression_auto_check_interval: int = 3600 + """ + 表达方式自动检查的间隔时间(单位:秒) + 默认值:3600秒(1小时) + """ + + expression_auto_check_count: int = 10 + """ + 每次自动检查时随机选取的表达方式数量 + 默认值:10条 + """ + def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id diff --git a/src/hippo_memorizer/memory_forget_task.py b/src/hippo_memorizer/memory_forget_task.py deleted file mode 100644 index ce2a9b2a..00000000 --- a/src/hippo_memorizer/memory_forget_task.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -记忆遗忘任务 -每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆 -""" - -import time -import random -from typing import List - -from src.common.logger import get_logger -from src.common.database.database_model import ChatHistory -from src.manager.async_task_manager import AsyncTask - -logger = get_logger("memory_forget_task") - - -class MemoryForgetTask(AsyncTask): - """记忆遗忘任务,每5分钟执行一次""" - - def __init__(self): - # 每5分钟执行一次(300秒) - super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300) - - async def run(self): - """执行遗忘检查""" - try: - current_time = time.time() - # logger.info("[记忆遗忘] 开始遗忘检查...") - - # 执行4个阶段的遗忘检查 - # await self._forget_stage_1(current_time) - # await self._forget_stage_2(current_time) - # await self._forget_stage_3(current_time) - # await self._forget_stage_4(current_time) - - # logger.info("[记忆遗忘] 遗忘检查完成") - except Exception as e: - logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True) - - async def _forget_stage_1(self, current_time: float): - """ - 第一次遗忘检查: - 搜集所有:记忆还未被遗忘检查过(forget_times=0),且已经是30分钟之外的记忆 - 取count最高25%和最低25%,删除,然后标记被遗忘检查次数为1 - """ - try: - # 30分钟 = 1800秒 - time_threshold = current_time - 1800 - - # 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold - candidates = list( - ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold)) - ) - - if not candidates: - logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆") - return - - logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆") - - # 按count排序 - candidates.sort(key=lambda x: x.count, reverse=True) - - # 计算要删除的数量(最高25%和最低25%) - total_count = len(candidates) - delete_count = int(total_count * 0.25) # 25% - - if delete_count == 0: - logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过") - return - - # 选择要删除的记录(处理count相同的情况:随机选择) - to_delete = [] - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - - # 去重(避免重复删除),使用id去重 - seen_ids = set() - unique_to_delete = [] - for record in to_delete: - if record.id not in seen_ids: - seen_ids.add(record.id) - unique_to_delete.append(record) - to_delete = unique_to_delete - - # 删除记录并更新forget_times - deleted_count = 0 - for record in to_delete: - try: - record.delete_instance() - deleted_count += 1 - except Exception as e: - logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}") - - # 更新剩余记录的forget_times为1 - to_delete_ids = {r.id for r in to_delete} - remaining = [r for r in candidates if r.id not in to_delete_ids] - if remaining: - # 批量更新 - ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute() - - logger.info( - f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1" - ) - - except Exception as e: - logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True) - - async def _forget_stage_2(self, current_time: float): - """ - 第二次遗忘检查: - 搜集所有:记忆遗忘检查为1,且已经是8小时之外的记忆 - 取count最高7%和最低7%,删除,然后标记被遗忘检查次数为2 - """ - try: - # 8小时 = 28800秒 - time_threshold = current_time - 28800 - - # 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold - candidates = list( - ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold)) - ) - - if not candidates: - logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆") - return - - logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆") - - # 按count排序 - candidates.sort(key=lambda x: x.count, reverse=True) - - # 计算要删除的数量(最高7%和最低7%) - total_count = len(candidates) - delete_count = int(total_count * 0.07) # 7% - - if delete_count == 0: - logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过") - return - - # 选择要删除的记录 - to_delete = [] - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - - # 去重 - to_delete = list(set(to_delete)) - - # 删除记录 - deleted_count = 0 - for record in to_delete: - try: - record.delete_instance() - deleted_count += 1 - except Exception as e: - logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}") - - # 更新剩余记录的forget_times为2 - to_delete_ids = {r.id for r in to_delete} - remaining = [r for r in candidates if r.id not in to_delete_ids] - if remaining: - ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute() - - logger.info( - f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2" - ) - - except Exception as e: - logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True) - - async def _forget_stage_3(self, current_time: float): - """ - 第三次遗忘检查: - 搜集所有:记忆遗忘检查为2,且已经是48小时之外的记忆 - 取count最高5%和最低5%,删除,然后标记被遗忘检查次数为3 - """ - try: - # 48小时 = 172800秒 - time_threshold = current_time - 172800 - - # 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold - candidates = list( - ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold)) - ) - - if not candidates: - logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆") - return - - logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆") - - # 按count排序 - candidates.sort(key=lambda x: x.count, reverse=True) - - # 计算要删除的数量(最高5%和最低5%) - total_count = len(candidates) - delete_count = int(total_count * 0.05) # 5% - - if delete_count == 0: - logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过") - return - - # 选择要删除的记录 - to_delete = [] - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - - # 去重 - to_delete = list(set(to_delete)) - - # 删除记录 - deleted_count = 0 - for record in to_delete: - try: - record.delete_instance() - deleted_count += 1 - except Exception as e: - logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}") - - # 更新剩余记录的forget_times为3 - to_delete_ids = {r.id for r in to_delete} - remaining = [r for r in candidates if r.id not in to_delete_ids] - if remaining: - ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute() - - logger.info( - f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3" - ) - - except Exception as e: - logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True) - - async def _forget_stage_4(self, current_time: float): - """ - 第四次遗忘检查: - 搜集所有:记忆遗忘检查为3,且已经是7天之外的记忆 - 取count最高2%和最低2%,删除,然后标记被遗忘检查次数为4 - """ - try: - # 7天 = 604800秒 - time_threshold = current_time - 604800 - - # 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold - candidates = list( - ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold)) - ) - - if not candidates: - logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆") - return - - logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆") - - # 按count排序 - candidates.sort(key=lambda x: x.count, reverse=True) - - # 计算要删除的数量(最高2%和最低2%) - total_count = len(candidates) - delete_count = int(total_count * 0.02) # 2% - - if delete_count == 0: - logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过") - return - - # 选择要删除的记录 - to_delete = [] - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) - to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) - - # 去重 - to_delete = list(set(to_delete)) - - # 删除记录 - deleted_count = 0 - for record in to_delete: - try: - record.delete_instance() - deleted_count += 1 - except Exception as e: - logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}") - - # 更新剩余记录的forget_times为4 - to_delete_ids = {r.id for r in to_delete} - remaining = [r for r in candidates if r.id not in to_delete_ids] - if remaining: - ids_to_update = [r.id for r in remaining] - ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute() - - logger.info( - f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4" - ) - - except Exception as e: - logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True) - - def _handle_same_count_random( - self, candidates: List[ChatHistory], delete_count: int, mode: str - ) -> List[ChatHistory]: - """ - 处理count相同的情况,随机选择要删除的记录 - - Args: - candidates: 候选记录列表(已按count排序) - delete_count: 要删除的数量 - mode: "high" 表示选择最高count的记录,"low" 表示选择最低count的记录 - - Returns: - 要删除的记录列表 - """ - if not candidates or delete_count == 0: - return [] - - to_delete = [] - - if mode == "high": - # 从最高count开始选择 - start_idx = 0 - while start_idx < len(candidates) and len(to_delete) < delete_count: - # 找到所有count相同的记录 - current_count = candidates[start_idx].count - same_count_records = [] - idx = start_idx - while idx < len(candidates) and candidates[idx].count == current_count: - same_count_records.append(candidates[idx]) - idx += 1 - - # 如果相同count的记录数量 <= 还需要删除的数量,全部选择 - needed = delete_count - len(to_delete) - if len(same_count_records) <= needed: - to_delete.extend(same_count_records) - else: - # 随机选择需要的数量 - to_delete.extend(random.sample(same_count_records, needed)) - - start_idx = idx - - else: # mode == "low" - # 从最低count开始选择 - start_idx = len(candidates) - 1 - while start_idx >= 0 and len(to_delete) < delete_count: - # 找到所有count相同的记录 - current_count = candidates[start_idx].count - same_count_records = [] - idx = start_idx - while idx >= 0 and candidates[idx].count == current_count: - same_count_records.append(candidates[idx]) - idx -= 1 - - # 如果相同count的记录数量 <= 还需要删除的数量,全部选择 - needed = delete_count - len(to_delete) - if len(same_count_records) <= needed: - to_delete.extend(same_count_records) - else: - # 随机选择需要的数量 - to_delete.extend(random.sample(same_count_records, needed)) - - start_idx = idx - - return to_delete diff --git a/src/main.py b/src/main.py index 57950bf9..aa985ce0 100644 --- a/src/main.py +++ b/src/main.py @@ -24,6 +24,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager # 导入消息API和traceback模块 from src.common.message import get_global_api from src.dream.dream_agent import start_dream_scheduler +from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask # 插件系统现在使用统一的插件加载器 @@ -87,16 +88,11 @@ class MainSystem: # 添加统计信息输出任务 await async_task_manager.add_task(StatisticOutputTask()) - # 添加聊天流统计任务(每5分钟生成一次报告,统计最近30天的数据) - # await async_task_manager.add_task(TokenStatisticsTask()) - # 添加遥测心跳任务 await async_task_manager.add_task(TelemetryHeartBeatTask()) - # 添加记忆遗忘任务 - from src.hippo_memorizer.memory_forget_task import MemoryForgetTask - - await async_task_manager.add_task(MemoryForgetTask()) + # 添加表达方式自动检查任务 + await async_task_manager.add_task(ExpressionAutoCheckTask()) # 启动API服务器 # start_api_server() diff --git a/src/hippo_memorizer/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py similarity index 100% rename from src/hippo_memorizer/chat_history_summarizer.py rename to src/memory_system/chat_history_summarizer.py diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index f006a600..ee80feee 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.2.8" +version = "7.3.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- # 如果你想要修改配置文件,请递增version的值 @@ -91,15 +91,23 @@ expression_groups = [ # 注意:如果为群聊,则需要设置为group,如果设置为私聊,则需要设置为private ] -reflect = false # 是否启用表达反思(Bot主动向管理员询问表达方式是否合适) -reflect_operator_id = "" # 表达反思操作员ID,格式:platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group") +expression_checked_only = true # 是否仅选择已检查且未拒绝的表达方式。当设置为 true 时,只有 checked=True 且 rejected=False 的表达方式才会被选择;当设置为 false 时,保留旧的筛选原则(仅排除 rejected=True 的表达方式) + +expression_self_reflect = true # 是否启用自动表达优化(Bot主动向管理员询问表达方式是否合适) +expression_auto_check_interval = 600 # 表达方式自动检查的间隔时间(单位:秒),默认值:3600秒(1小时) +expression_auto_check_count = 20 # 每次自动检查时随机选取的表达方式数量,默认值:10条 + +expression_manual_reflect = false # 是否启用手动表达优化(Bot主动向管理员询问表达方式是否合适) +manual_reflect_operator_id = "" # 手动表达优化操作员ID,格式:platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group") allow_reflect = [] # 允许进行表达反思的聊天流ID列表,格式:["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true) + all_global_jargon = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除 enable_jargon_explanation = true # 是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习) jargon_mode = "planner" # 黑话解释来源模式,可选: "context"(使用上下文自动匹配黑话) 或 "planner"(仅使用Planner在reply动作中给出的unknown_words列表) + [chat] # 麦麦的聊天设置 talk_value = 1 # 聊天频率,越小越沉默,范围0-1 mentioned_bot_reply = true # 是否启用提及必回复