diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index cc29d6f2..8a4b0986 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -65,24 +65,20 @@ class ExpressionLearner: self.chat_id = chat_id self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id - # 维护每个chat的上次学习时间 self.last_learning_time: float = time.time() - + # 学习参数 self.min_messages_for_learning = 25 # 触发学习所需的最少消息数 self.min_learning_interval = 300 # 最短学习时间间隔(秒) - - - def can_learn_for_chat(self) -> bool: """ 检查指定聊天流是否允许学习表达 - + Args: chat_id: 聊天流ID - + Returns: bool: 是否允许学习 """ @@ -96,10 +92,10 @@ class ExpressionLearner: def should_trigger_learning(self) -> bool: """ 检查是否应该触发学习 - + Args: chat_id: 聊天流ID - + Returns: bool: 是否应该触发学习 """ @@ -107,23 +103,25 @@ class ExpressionLearner: # 获取该聊天流的学习强度 try: - _, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id) + _, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat( + self.chat_id + ) except Exception as e: logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}") return False - + # 检查是否允许学习 if not enable_learning: return False - + # 根据学习强度计算最短学习时间间隔 min_interval = self.min_learning_interval / learning_intensity - + # 检查时间间隔 time_diff = current_time - self.last_learning_time if time_diff < min_interval: return False - + # 检查消息数量(只检查指定聊天流的消息) recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, @@ -133,69 +131,42 @@ class ExpressionLearner: if not recent_messages or len(recent_messages) < self.min_messages_for_learning: return False - + return True async def trigger_learning_for_chat(self) -> bool: """ 为指定聊天流触发学习 - + Args: chat_id: 聊天流ID - + Returns: bool: 是否成功触发学习 """ if not self.should_trigger_learning(): return False - + try: logger.info(f"为聊天流 {self.chat_name} 触发表达学习") - + # 学习语言风格 learnt_style = await self.learn_and_store(num=25) - + # 更新学习时间 self.last_learning_time = time.time() - + if learnt_style: logger.info(f"聊天流 {self.chat_name} 表达学习完成") return True else: logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果") return False - + except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") return False - # def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: - # """ - # 获取指定chat_id的style表达方式(已禁用grammar的获取) - # 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 - # """ - # learnt_style_expressions = [] - - # # 直接从数据库查询 - # style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style")) - # for expr in style_query: - # # 确保create_date存在,如果不存在则使用last_active_time - # create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - # learnt_style_expressions.append( - # { - # "situation": expr.situation, - # "style": expr.style, - # "count": expr.count, - # "last_active_time": expr.last_active_time, - # "source_id": self.chat_id, - # "type": "style", - # "create_date": create_date, - # } - # ) - # return learnt_style_expressions - - - def _apply_global_decay_to_database(self, current_time: float) -> None: """ 对数据库中的所有表达方式应用全局衰减 @@ -345,7 +316,7 @@ class ExpressionLearner: prompt = "learn_style_prompt" current_time = time.time() - + # 获取上次学习时间 random_msg = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, @@ -414,19 +385,20 @@ class ExpressionLearner: init_prompt() + class ExpressionLearnerManager: def __init__(self): self.expression_learners = {} - + self._ensure_expression_directories() self._auto_migrate_json_to_db() self._migrate_old_data_create_date() - + def get_expression_learner(self, chat_id: str) -> ExpressionLearner: if chat_id not in self.expression_learners: self.expression_learners[chat_id] = ExpressionLearner(chat_id) return self.expression_learners[chat_id] - + def _ensure_expression_directories(self): """ 确保表达方式相关的目录结构存在 @@ -445,7 +417,6 @@ class ExpressionLearnerManager: except Exception as e: logger.error(f"创建目录失败 {directory}: {e}") - def _auto_migrate_json_to_db(self): """ 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 @@ -564,7 +535,7 @@ class ExpressionLearnerManager: try: deleted_count = self.delete_all_grammar_expressions() logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达") - + # 创建done.done2标记文件 with open(done_flag2, "w", encoding="utf-8") as f: f.write("done\n") @@ -598,7 +569,7 @@ class ExpressionLearnerManager: def delete_all_grammar_expressions(self) -> int: """ 检查expression库中所有type为"grammar"的表达并全部删除 - + Returns: int: 删除的grammar表达数量 """ @@ -606,13 +577,13 @@ class ExpressionLearnerManager: # 查询所有type为"grammar"的表达 grammar_expressions = Expression.select().where(Expression.type == "grammar") grammar_count = grammar_expressions.count() - + if grammar_count == 0: logger.info("expression库中没有找到grammar类型的表达") return 0 - + logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...") - + # 删除所有grammar类型的表达 deleted_count = 0 for expr in grammar_expressions: @@ -622,10 +593,10 @@ class ExpressionLearnerManager: except Exception as e: logger.error(f"删除grammar表达失败: {e}") continue - + logger.info(f"成功删除 {deleted_count} 个grammar类型的表达") return deleted_count - + except Exception as e: logger.error(f"删除grammar表达过程中发生错误: {e}") return 0 diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 7d2591ff..7e8355c6 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -4,9 +4,10 @@ import os import pickle import random import asyncio -from typing import List, Dict, Any, TYPE_CHECKING +from typing import List, Dict, Any from src.config.config import global_config from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.person_info.relationship_manager import get_relationship_manager from src.person_info.person_info import Person, get_person_id from src.chat.message_receive.chat_stream import get_chat_manager @@ -17,8 +18,6 @@ from src.chat.utils.chat_message_builder import ( num_new_messages_since, ) -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("relationship_builder")