diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index b5aa6aff..676616e5 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -185,13 +185,11 @@ class HeartFChatting: question_probability = 0 if time.time() - self.last_active_time > 3600: - question_probability = 0.01 - elif time.time() - self.last_active_time > 1200: - question_probability = 0.005 - elif time.time() - self.last_active_time > 600: question_probability = 0.001 - else: + elif time.time() - self.last_active_time > 1200: question_probability = 0.0003 + else: + question_probability = 0.0001 question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id) @@ -210,7 +208,7 @@ class HeartFChatting: if question: logger.info(f"{self.log_prefix} 问题: {question}") await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id) - await self._lift_question_reply(question,context,cycle_timers,thinking_id) + await self._lift_question_reply(question,context,thinking_id) else: logger.info(f"{self.log_prefix} 无问题") # self.end_cycle(cycle_timers, thinking_id) @@ -550,8 +548,8 @@ class HeartFChatting: traceback.print_exc() return False, "" - async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str): - reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论" + async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str): + reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论" new_msg = get_raw_msg_before_timestamp_with_chat( chat_id=self.stream_id, timestamp=time.time(), diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index f3efa943..89e0a019 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -310,7 +310,6 @@ class Expression(BaseModel): last_active_time = FloatField() chat_id = TextField(index=True) - type = TextField() create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据 class Meta: diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py index 56536a21..3ffb6cfa 100644 --- a/src/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -1,11 +1,9 @@ import time -import random import json import os import re from datetime import datetime -import jieba -from typing import List, Dict, Optional, Any, Tuple +from typing import List, Optional, Tuple import traceback import difflib from src.common.logger import get_logger @@ -148,7 +146,7 @@ class ExpressionLearner: return True - async def trigger_learning_for_chat(self) -> bool: + async def trigger_learning_for_chat(self): """ 为指定聊天流触发学习 @@ -159,11 +157,10 @@ class ExpressionLearner: bool: 是否成功触发学习 """ if not self.should_trigger_learning(): - return False + return try: - logger.info(f"为聊天流 {self.chat_name} 触发表达学习") - + logger.info(f"在聊天流 {self.chat_name} 学习表达方式") # 学习语言风格 learnt_style = await self.learn_and_store(num=25) @@ -172,15 +169,13 @@ class ExpressionLearner: if learnt_style: logger.info(f"聊天流 {self.chat_name} 表达学习完成") - return True else: logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果") - return False except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") traceback.print_exc() - return False + return @@ -188,127 +183,87 @@ class ExpressionLearner: """ 学习并存储表达方式 """ - res = await self.learn_expression(num) + learnt_expressions = await self.learn_expression(num) - if res is None: + if learnt_expressions is None: logger.info("没有学习到表达风格") return [] - learnt_expressions = res + + # 展示学到的表达方式 learnt_expressions_str = "" for ( - _chat_id, situation, style, _context, _up_content, ) in learnt_expressions: learnt_expressions_str += f"{situation}->{style}\n" - logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") - # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, Any]]] = {} + current_time = time.time() + + # 存储到数据库 Expression 表并训练 style_learner + has_new_expressions = False # 记录是否有新的表达方式 + learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例 + for ( - chat_id, situation, style, context, up_content, ) in learnt_expressions: - if chat_id not in chat_dict: - chat_dict[chat_id] = [] - chat_dict[chat_id].append( - { - "situation": situation, - "style": style, - "context": context, - "up_content": up_content, - } + # 查找是否已存在相似表达方式 + query = Expression.select().where( + (Expression.chat_id == self.chat_id) + & (Expression.situation == situation) + & (Expression.style == style) ) - - current_time = time.time() - - # 存储到数据库 Expression 表并训练 style_learner - trained_chat_ids = set() # 记录已训练的聊天室 - - for chat_id, expr_list in chat_dict.items(): - for new_expr in expr_list: - # 查找是否已存在相似表达方式 - query = Expression.select().where( - (Expression.chat_id == chat_id) - & (Expression.type == "style") - & (Expression.situation == new_expr["situation"]) - & (Expression.style == new_expr["style"]) + if query.exists(): + # 表达方式完全相同,只更新时间戳 + expr_obj = query.get() + expr_obj.last_active_time = current_time + expr_obj.save() + continue + else: + Expression.create( + situation=situation, + style=style, + last_active_time=current_time, + chat_id=self.chat_id, + create_date=current_time, # 手动设置创建日期 + context=context, + up_content=up_content, ) - if query.exists(): - expr_obj = query.get() - # 50%概率替换内容 - if random.random() < 0.5: - expr_obj.situation = new_expr["situation"] - expr_obj.style = new_expr["style"] - expr_obj.context = new_expr["context"] - expr_obj.up_content = new_expr["up_content"] - expr_obj.last_active_time = current_time - expr_obj.save() - else: - Expression.create( - situation=new_expr["situation"], - style=new_expr["style"], - last_active_time=current_time, - chat_id=chat_id, - type="style", - create_date=current_time, # 手动设置创建日期 - context=new_expr["context"], - up_content=new_expr["up_content"], - ) - - # 训练 style_learner(up_content 和 style 必定存在) - try: - # 获取 learner 实例 - learner = style_learner_manager.get_learner(chat_id) - - # 先添加风格和对应的 situation(如果存在) - if new_expr.get("situation"): - learner.add_style(new_expr["style"], new_expr["situation"]) - else: - learner.add_style(new_expr["style"]) - - # 学习映射关系 - success = style_learner_manager.learn_mapping( - chat_id, - new_expr["up_content"], - new_expr["style"] - ) - if success: - logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" + - (f" (situation: {new_expr['situation']})" if new_expr.get("situation") else "")) - trained_chat_ids.add(chat_id) - else: - logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}") - except Exception as e: - logger.error(f"StyleLearner学习异常: {chat_id} - {e}") + has_new_expressions = True - # 限制最大数量 - # exprs = list( - # Expression.select() - # .where((Expression.chat_id == chat_id) & (Expression.type == "style")) - # .order_by(Expression.last_active_time.asc()) - # ) - # if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除最久未活跃的多余表达方式 - # for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: - # expr.delete_instance() - - # 保存训练好的 style_learner 模型 - if trained_chat_ids: + # 训练 style_learner(up_content 和 style 必定存在) try: - logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...") - save_success = style_learner_manager.save_all_models() + learner.add_style(style, situation) + + # 学习映射关系 + success = style_learner_manager.learn_mapping( + self.chat_id, + up_content, + style + ) + if success: + logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else "")) + else: + logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}") + except Exception as e: + logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}") + + + # 保存当前聊天室的 style_learner 模型 + if has_new_expressions: + try: + logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...") + save_success = learner.save(style_learner_manager.model_save_path) if save_success: - logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}") + logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}") else: - logger.warning("StyleLearner 模型保存失败") + logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}") except Exception as e: logger.error(f"StyleLearner 模型保存异常: {e}") @@ -415,15 +370,12 @@ class ExpressionLearner: async def learn_expression( self, num: int = 10 - ) -> Optional[List[Tuple[str, str, str, List[str], str]]]: + ) -> Optional[List[Tuple[str, str, str, str]]]: """从指定聊天流学习表达方式 Args: num: 学习数量 """ - type_str = "语言风格" - prompt = "learn_style_prompt" - current_time = time.time() # 获取上次学习之后的消息 @@ -436,14 +388,14 @@ class ExpressionLearner: # print(random_msg) if not random_msg or random_msg == []: return None - # 转化成str - _chat_id: str = random_msg[0].chat_id - # random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal") + + # 学习用 random_msg_str: str = await build_anonymous_messages(random_msg) + # 溯源用 random_msg_match_str: str = await build_bare_messages(random_msg) prompt: str = await global_prompt_manager.format_prompt( - prompt, + "learn_style_prompt", chat_str=random_msg_str, ) @@ -453,20 +405,18 @@ class ExpressionLearner: try: response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) except Exception as e: - logger.error(f"学习{type_str}失败: {e}") + logger.error(f"学习表达方式失败,模型生成出错: {e}") return None - + expressions: List[Tuple[str, str]] = self.parse_expression_response(response) # logger.debug(f"学习{type_str}的response: {response}") - expressions: List[Tuple[str, str]] = self.parse_expression_response(response) - + + # 对表达方式溯源 matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( expressions, random_msg_match_str ) - - print(f"matched_expressions: {matched_expressions}") - # 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射 + # 这里有待斟酌,需要进一步处理图片和表情包 bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content) pic_pattern = r"\[picid:[^\]]+\]" reply_pattern = r"回复<[^:<>]+:[^:<>]+>" @@ -479,7 +429,6 @@ class ExpressionLearner: content = content.strip() if content: bare_lines.append((idx, content)) - # 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过) filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content) for situation, style, context in matched_expressions: @@ -503,11 +452,7 @@ class ExpressionLearner: if not filtered_with_up: return None - results: List[Tuple[str, str, str, str]] = [] - for (situation, style, context, up_content) in filtered_with_up: - results.append((self.chat_id, situation, style, context, up_content)) - - return results + return filtered_with_up def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py index cab764c8..9ebef43c 100644 --- a/src/express/expression_selector.py +++ b/src/express/expression_selector.py @@ -136,13 +136,13 @@ class ExpressionSelector: return group_chat_ids return [chat_id] - def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]: + def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]: """ 使用 style_learner 模型预测最合适的表达方式 Args: chat_id: 聊天室ID - chat_info: 聊天内容信息 + target_message: 目标消息内容 total_num: 需要预测的数量 Returns: @@ -152,10 +152,7 @@ class ExpressionSelector: # 支持多chat_id合并预测 related_chat_ids = self.get_related_chat_ids(chat_id) - # 从聊天信息中提取关键内容作为预测输入 - # 这里可以进一步优化,提取更合适的预测输入 - prediction_input = self._extract_prediction_input(chat_info) - + predicted_expressions = [] # 为每个相关的chat_id进行预测 @@ -163,7 +160,7 @@ class ExpressionSelector: try: # 使用 style_learner 预测最合适的风格 best_style, scores = style_learner_manager.predict_style( - related_chat_id, prediction_input, top_k=total_num + related_chat_id, target_message, top_k=total_num ) if best_style and scores: @@ -175,7 +172,6 @@ class ExpressionSelector: # 从数据库查找对应的表达记录 expr_query = Expression.select().where( (Expression.chat_id == related_chat_id) & - (Expression.type == "style") & (Expression.situation == situation) & (Expression.style == best_style) ) @@ -188,11 +184,12 @@ class ExpressionSelector: "style": expr.style, "last_active_time": expr.last_active_time, "source_id": expr.chat_id, - "type": "style", "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "prediction_score": scores.get(best_style, 0.0), - "prediction_input": prediction_input + "prediction_input": target_message }) + else: + logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式") except Exception as e: logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") @@ -208,39 +205,11 @@ class ExpressionSelector: except Exception as e: logger.error(f"模型预测表达方式失败: {e}") # 如果预测失败,回退到随机选择 - return self._fallback_random_expressions(chat_id, total_num) + return self._random_expressions(chat_id, total_num) - def _extract_prediction_input(self, chat_info: str) -> str: + def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: """ - 从聊天信息中提取用于预测的关键内容 - - Args: - chat_info: 聊天内容信息 - - Returns: - str: 提取的预测输入 - """ - try: - # 简单的提取策略:取最后几句话作为预测输入 - lines = chat_info.strip().split('\n') - if not lines: - return "" - - # 取最后3行作为预测输入 - recent_lines = lines[-1:] - prediction_input = ' '.join(recent_lines).strip() - - logger.info(f"提取预测输入: {prediction_input}") - - return prediction_input - - except Exception as e: - logger.warning(f"提取预测输入失败: {e}") - return "" - - def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: - """ - 回退到随机选择表达方式 + 随机选择表达方式 Args: chat_id: 聊天室ID @@ -255,7 +224,7 @@ class ExpressionSelector: # 优化:一次性查询所有相关chat_id的表达方式 style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") + (Expression.chat_id.in_(related_chat_ids)) ) style_exprs = [ @@ -265,7 +234,6 @@ class ExpressionSelector: "style": expr.style, "last_active_time": expr.last_active_time, "source_id": expr.chat_id, - "type": "style", "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, } for expr in style_query @@ -277,7 +245,7 @@ class ExpressionSelector: else: selected_style = [] - logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式") + logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式") return selected_style except Exception as e: @@ -315,7 +283,7 @@ class ExpressionSelector: if expression_mode == "exp_model": # exp_model模式:直接使用模型预测,不经过LLM logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式") - return await self._select_expressions_model_only(chat_id, chat_info, max_num) + return await self._select_expressions_model_only(chat_id, target_message, max_num) elif expression_mode == "classic": # classic模式:随机选择+LLM选择 logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式") @@ -327,7 +295,7 @@ class ExpressionSelector: async def _select_expressions_model_only( self, chat_id: str, - chat_info: str, + target_message: str, max_num: int = 10, ) -> Tuple[List[Dict[str, Any]], List[int]]: """ @@ -335,7 +303,7 @@ class ExpressionSelector: Args: chat_id: 聊天流ID - chat_info: 聊天内容信息 + target_message: 目标消息内容 max_num: 最大选择数量 Returns: @@ -343,11 +311,7 @@ class ExpressionSelector: """ try: # 使用模型预测最合适的表达方式 - style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2) - - - # 直接取前max_num个预测结果 - selected_expressions = style_exprs[:max_num] + selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num) selected_ids = [expr["id"] for expr in selected_expressions] # 更新last_active_time @@ -381,8 +345,8 @@ class ExpressionSelector: Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 """ try: - # 1. 使用模型预测最合适的表达方式 - style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20) + # 1. 使用随机抽样选择表达方式 + style_exprs = self._random_expressions(chat_id, 20) if len(style_exprs) < 10: logger.info(f"聊天流 {chat_id} 表达方式正在积累中") @@ -460,21 +424,6 @@ class ExpressionSelector: logger.error(f"classic模式处理表达方式选择时出错: {e}") return [], [] - async def select_suitable_expressions_llm( - self, - chat_id: str, - chat_info: str, - max_num: int = 10, - target_message: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], List[int]]: - """ - 使用LLM选择适合的表达方式(保持向后兼容) - - 注意:此方法已被 select_suitable_expressions 替代,建议使用新方法 - """ - logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions") - return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message) - def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]): """对一批表达方式更新last_active_time""" if not expressions_to_update: @@ -482,19 +431,17 @@ class ExpressionSelector: updates_by_key = {} for expr in expressions_to_update: source_id: str = expr.get("source_id") # type: ignore - expr_type: str = expr.get("type", "style") situation: str = expr.get("situation") # type: ignore style: str = expr.get("style") # type: ignore if not source_id or not situation or not style: logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") continue - key = (source_id, expr_type, situation, style) + key = (source_id, situation, style) if key not in updates_by_key: updates_by_key[key] = expr - for chat_id, expr_type, situation, style in updates_by_key: + for chat_id, situation, style in updates_by_key: query = Expression.select().where( (Expression.chat_id == chat_id) - & (Expression.type == expr_type) & (Expression.situation == situation) & (Expression.style == style) ) diff --git a/src/express/expressor_model/model.py b/src/express/expressor_model/model.py index d8aec88a..d47873d9 100644 --- a/src/express/expressor_model/model.py +++ b/src/express/expressor_model/model.py @@ -58,8 +58,18 @@ class ExpressorModel: # 取最高分 if not scores: return None, {} - best = max(scores.items(), key=lambda x: x[1])[0] - return best, scores + + # 根据k参数限制返回的候选数量 + if k is not None and k > 0: + # 按分数降序排序,取前k个 + sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + limited_scores = dict(sorted_scores[:k]) + best = sorted_scores[0][0] if sorted_scores else None + return best, limited_scores + else: + # 如果没有指定k,返回所有分数 + best = max(scores.items(), key=lambda x: x[1])[0] + return best, scores def update_positive(self, text: str, cid: str): """更新正反馈学习""" diff --git a/src/memory_system/question_maker.py b/src/memory_system/question_maker.py index 9a814bf2..316afee5 100644 --- a/src/memory_system/question_maker.py +++ b/src/memory_system/question_maker.py @@ -60,13 +60,8 @@ class QuestionMaker: # 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个 conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0] - if not conflicts_with_zero: - if random.random() >= 0.01: - return None - # 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重) - chosen_conflict = random.choice(conflicts) - else: - # 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.05 + if conflicts_with_zero: + # 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01 weights = [] for conflict in conflicts: current_raise_time = getattr(conflict, "raise_time", 0) or 0 @@ -77,12 +72,9 @@ class QuestionMaker: chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0] # 选中后,自增 raise_time 并保存 - try: chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1 chosen_conflict.save() - except Exception: - # 静默失败不影响流程 - pass + return chosen_conflict diff --git a/src/memory_system/questions.py b/src/memory_system/questions.py index ac2cf341..3816213c 100644 --- a/src/memory_system/questions.py +++ b/src/memory_system/questions.py @@ -288,7 +288,7 @@ class ConflictTracker: if existing_conflict: # 检查raise_time是否大于3且没有答案 current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0 - if current_raise_time > 1 and not existing_conflict.answer: + if current_raise_time > 0 and not existing_conflict.answer: # 删除该条目 await self.delete_conflict(original_question, tracker.chat_id) logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}") diff --git a/test_expression_style_situation_integration.py b/test_expression_style_situation_integration.py deleted file mode 100644 index 5fedf8e5..00000000 --- a/test_expression_style_situation_integration.py +++ /dev/null @@ -1,188 +0,0 @@ -""" -测试修改后的 expression_learner 与 style_learner 的集成 -验证学习新表达时是否正确处理 situation 字段 -""" - -import os -import sys -import asyncio -import time - -# 添加项目根目录到Python路径 -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from src.express.expression_learner import ExpressionLearner -from src.express.style_learner import style_learner_manager -from src.common.logger import get_logger - -logger = get_logger("expression_style_integration_test") - - -async def test_expression_style_integration(): - """测试 expression_learner 与 style_learner 的集成(包含 situation)""" - print("=== Expression Learner 与 Style Learner 集成测试(含 Situation) ===\n") - - # 创建测试聊天室ID - test_chat_id = "test_integration_situation_chat" - - # 创建 ExpressionLearner 实例 - expression_learner = ExpressionLearner(test_chat_id) - - print(f"测试聊天室: {test_chat_id}") - - # 模拟学习到的表达数据(包含 situation) - mock_learnt_expressions = [ - (test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"), - (test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"), - (test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"), - (test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"), - (test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"), - ] - - print("模拟学习到的表达数据(包含 situation):") - for chat_id, situation, style, context, up_content in mock_learnt_expressions: - print(f" {situation} -> {style} (输入: {up_content})") - - # 模拟 learn_and_store 方法的处理逻辑 - print(f"\n开始处理学习数据...") - - # 按chat_id分组 - chat_dict = {} - for chat_id, situation, style, context, up_content in mock_learnt_expressions: - if chat_id not in chat_dict: - chat_dict[chat_id] = [] - chat_dict[chat_id].append({ - "situation": situation, - "style": style, - "context": context, - "up_content": up_content, - }) - - # 训练 style_learner(包含 situation 处理) - trained_chat_ids = set() - - for chat_id, expr_list in chat_dict.items(): - print(f"\n处理聊天室: {chat_id}") - - for new_expr in expr_list: - # 训练 style_learner(包含 situation) - if new_expr.get("up_content") and new_expr.get("style"): - try: - # 获取 learner 实例 - learner = style_learner_manager.get_learner(chat_id) - - # 先添加风格和对应的 situation(如果不存在) - if new_expr.get("situation"): - learner.add_style(new_expr["style"], new_expr["situation"]) - print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')") - else: - learner.add_style(new_expr["style"]) - print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)") - - # 学习映射关系 - success = style_learner_manager.learn_mapping( - chat_id, - new_expr["up_content"], - new_expr["style"] - ) - if success: - print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" + - (f" (situation: {new_expr['situation']})" if new_expr.get("situation") else "")) - trained_chat_ids.add(chat_id) - else: - print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}") - except Exception as e: - print(f" ✗ StyleLearner学习异常: {e}") - - # 保存模型 - if trained_chat_ids: - print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...") - try: - save_success = style_learner_manager.save_all_models() - - if save_success: - print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}") - else: - print("✗ StyleLearner 模型保存失败") - - except Exception as e: - print(f"✗ StyleLearner 模型保存异常: {e}") - - # 测试预测功能 - print(f"\n测试 StyleLearner 预测功能:") - test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"] - - for test_input in test_inputs: - try: - best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3) - if best_style: - # 获取对应的 situation - learner = style_learner_manager.get_learner(test_chat_id) - situation = learner.get_situation(best_style) - print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')") - if scores: - top_scores = dict(list(scores.items())[:3]) - print(f" 分数: {top_scores}") - else: - print(f" 输入: '{test_input}' -> 无预测结果") - except Exception as e: - print(f" 输入: '{test_input}' -> 预测异常: {e}") - - # 获取统计信息 - print(f"\nStyleLearner 统计信息:") - try: - stats = style_learner_manager.get_all_stats() - if test_chat_id in stats: - chat_stats = stats[test_chat_id] - print(f" 聊天室: {test_chat_id}") - print(f" 总样本数: {chat_stats['total_samples']}") - print(f" 当前风格数: {chat_stats['style_count']}") - print(f" 最大风格数: {chat_stats['max_styles']}") - print(f" 风格列表: {chat_stats['all_styles']}") - - # 显示每个风格的 situation 信息 - print(f" 风格和 situation 信息:") - for style in chat_stats['all_styles']: - situation = learner.get_situation(style) - print(f" '{style}' -> situation: '{situation}'") - else: - print(f" 未找到聊天室 {test_chat_id} 的统计信息") - except Exception as e: - print(f" 获取统计信息异常: {e}") - - # 测试模型保存和加载 - print(f"\n测试模型保存和加载...") - try: - # 创建新的管理器并加载模型 - new_manager = style_learner_manager # 使用同一个管理器 - new_learner = new_manager.get_learner(test_chat_id) - - # 验证加载后的 situation 信息 - loaded_style_info = new_learner.get_all_style_info() - print(f" 加载后风格数: {len(loaded_style_info)}") - for style, (style_id, situation) in loaded_style_info.items(): - print(f" 加载验证: '{style}' -> situation: '{situation}'") - - print("✓ 模型保存和加载测试通过") - except Exception as e: - print(f"✗ 模型保存和加载测试失败: {e}") - - print(f"\n=== 集成测试完成 ===") - print(f"✅ 所有功能测试通过!") - print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner") - print(f"✓ StyleLearner 正确存储和获取 situation 信息") - print(f"✓ 预测功能正常工作,可以获取对应的 situation") - print(f"✓ 模型保存和加载支持 situation 字段") - - -def main(): - """主函数""" - print("Expression Learner 与 Style Learner 集成测试(含 Situation)") - print("=" * 70) - - # 运行异步测试 - asyncio.run(test_expression_style_integration()) - - -if __name__ == "__main__": - main()