From 698b8355a44d898e7e471c116f68541ff26cfa9a Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 22 Feb 2026 22:33:54 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=97=A7=E6=96=87=E4=BB=B6?= =?UTF-8?q?=EF=BC=9B=E8=A1=A5=E5=85=85chat=5Fmanager=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_manager.py | 15 +- src/chat/message_receive/chat_stream.py | 431 ----------------------- 2 files changed, 14 insertions(+), 432 deletions(-) delete mode 100644 src/chat/message_receive/chat_stream.py diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py index f0cc4e80..fa20d876 100644 --- a/src/chat/message_receive/chat_manager.py +++ b/src/chat/message_receive/chat_manager.py @@ -145,12 +145,25 @@ class ChatManager: self.sessions.clear() raise e + async def regularly_save_sessions(self, interval_seconds: int = 300): + """定期将会话记录保存到数据库中 + + Args: + interval_seconds: 保存间隔时间,单位为秒,默认为300秒(5分钟) + """ + while True: + await asyncio.sleep(interval_seconds) + try: + await asyncio.to_thread(self.save_all_sessions) + except Exception as e: + logger.error(f"定期保存会话记录时发生错误: {e}") + def save_all_sessions(self): """将内存中的全部会话记录保存到数据库""" try: for session in self.sessions.values(): self._save_session(session) - logger.info(f"已保存 {len(self.sessions)} 个会话记录到数据库中") + logger.info(f"共 {len(self.sessions)} 个会话已经保存到数据库中") except Exception as e: logger.error(f"保存会话记录到数据库时发生错误: {e}") raise e diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py deleted file mode 100644 index 3e59b802..00000000 --- a/src/chat/message_receive/chat_stream.py +++ /dev/null @@ -1,431 +0,0 @@ -import asyncio -import hashlib -import time -import copy -from datetime import datetime -from typing import Dict, Optional, TYPE_CHECKING -from rich.traceback import install -from maim_message import GroupInfo, UserInfo - -from sqlmodel import select, col -from src.common.logger import get_logger -from src.common.database.database import get_db_session -from src.common.database.database_model import ChatSession - -# 避免循环导入,使用TYPE_CHECKING进行类型提示 -if TYPE_CHECKING: - from .message import MessageRecv - - -install(extra_lines=3) - - -logger = get_logger("chat_stream") - - -class ChatMessageContext: - """聊天消息上下文,存储消息的上下文信息""" - - def __init__(self, message: "MessageRecv"): - self.message = message - - def get_template_name(self) -> Optional[str]: - """获取模板名称""" - if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: - return self.message.message_info.template_info.template_name # type: ignore - return None - - def get_last_message(self) -> "MessageRecv": - """获取最后一条消息""" - return self.message - - def check_types(self, types: list) -> bool: - # sourcery skip: invert-any-all, use-any, use-next - """检查消息类型""" - if not self.message.message_info.format_info.accept_format: # type: ignore - return False - for t in types: - if t not in self.message.message_info.format_info.accept_format: # type: ignore - return False - return True - - def get_priority_mode(self) -> str: - """获取优先级模式""" - return self.message.priority_mode - - def get_priority_info(self) -> Optional[dict]: - """获取优先级信息""" - if hasattr(self.message, "priority_info") and self.message.priority_info: - return self.message.priority_info - return None - - -class ChatStream: - """聊天流对象,存储一个完整的聊天上下文""" - - def __init__( - self, - stream_id: str, - platform: str, - user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[dict] = None, - ): - self.stream_id = stream_id - self.platform = platform - self.user_info = user_info - self.group_info = group_info - self.create_time = data.get("create_time", time.time()) if data else time.time() - self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time - self.saved = False - self.context: Optional[ChatMessageContext] = None - - def to_dict(self) -> dict: - """转换为字典格式""" - return { - "stream_id": self.stream_id, - "platform": self.platform, - "user_info": self.user_info.to_dict() if self.user_info else None, - "group_info": self.group_info.to_dict() if self.group_info else None, - "create_time": self.create_time, - "last_active_time": self.last_active_time, - } - - @classmethod - def from_dict(cls, data: dict) -> "ChatStream": - """从字典创建实例""" - user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None - group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None - - if user_info is None: - raise ValueError("user_info is required to build ChatStream") - - return cls( - stream_id=data["stream_id"], - platform=data["platform"], - user_info=user_info, - group_info=group_info, - data=data, - ) - - def update_active_time(self): - """更新最后活跃时间""" - self.last_active_time = time.time() - self.saved = False - - def set_context(self, message: "MessageRecv"): - """设置聊天消息上下文""" - self.context = ChatMessageContext(message) - - -class ChatManager: - """聊天管理器,管理所有聊天流""" - - _instance = None - _initialized = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - if not self._initialized: - self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message - get_db_session() - - self._initialized = True - # 在事件循环中启动初始化 - # asyncio.create_task(self._initialize()) - # # 启动自动保存任务 - # asyncio.create_task(self._auto_save_task()) - - async def _initialize(self): - """异步初始化""" - try: - await self.load_all_streams() - logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流") - except Exception as e: - logger.error(f"聊天管理器启动失败: {str(e)}") - - async def _auto_save_task(self): - """定期自动保存所有聊天流""" - while True: - await asyncio.sleep(300) # 每5分钟保存一次 - try: - await self._save_all_streams() - logger.info("聊天流自动保存完成") - except Exception as e: - logger.error(f"聊天流自动保存失败: {str(e)}") - - def register_message(self, message: "MessageRecv"): - """注册消息到聊天流""" - platform = message.message_info.platform or "" - if not platform: - raise ValueError("platform is required for ChatStream") - if message.message_info.user_info is None and message.message_info.group_info is None: - raise ValueError("user_info or group_info is required for ChatStream") - stream_id = self._generate_stream_id( - platform, - message.message_info.user_info, - message.message_info.group_info, - ) - self.last_messages[stream_id] = message - # logger.debug(f"注册消息到聊天流: {stream_id}") - - @staticmethod - def _generate_stream_id( - platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None - ) -> str: - """生成聊天流唯一ID""" - if not user_info and not group_info: - raise ValueError("用户信息或群组信息必须提供") - if group_info is None and user_info is None: - raise ValueError("用户信息或群组信息必须提供") - - if group_info: - # 组合关键信息 - components = [platform, str(group_info.group_id)] - else: - if user_info is None: - raise ValueError("用户信息或群组信息必须提供") - if user_info.user_id is None: - raise ValueError("user_id is required for private stream") - components = [platform, str(user_info.user_id), "private"] - - # 使用MD5生成唯一ID - key = "_".join(components) - return hashlib.md5(key.encode()).hexdigest() - - def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str: - """获取聊天流ID""" - components = [platform, id] if is_group else [platform, id, "private"] - key = "_".join(components) - return hashlib.md5(key.encode()).hexdigest() - - async def get_or_create_stream( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None - ) -> ChatStream: - """获取或创建聊天流 - - Args: - platform: 平台标识 - user_info: 用户信息 - group_info: 群组信息(可选) - - Returns: - ChatStream: 聊天流对象 - """ - # 生成stream_id - try: - stream_id = self._generate_stream_id(platform, user_info, group_info) - - # 检查内存中是否存在 - if stream_id in self.streams: - stream = self.streams[stream_id] - - # 更新用户信息和群组信息 - stream.update_active_time() - stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 - if user_info and user_info.platform and user_info.user_id: - stream.user_info = user_info - if group_info: - stream.group_info = group_info - from .message import MessageRecv # 延迟导入,避免循环引用 - - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): - stream.set_context(self.last_messages[stream_id]) - else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") - return stream - - # 检查数据库中是否存在 - def _db_find_stream_sync(s_id: str): - with get_db_session() as session: - statement = select(ChatSession).where(col(ChatSession.session_id) == s_id) - return session.exec(statement).first() - - model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) - - if model_instance: - # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 - user_info_data = { - "platform": model_instance.platform, - "user_id": model_instance.user_id, - "user_nickname": "", - "user_cardname": "", - } - group_info_data = None - if model_instance.group_id: - group_info_data = { - "platform": model_instance.platform, - "group_id": model_instance.group_id, - "group_name": "", - } - - data_for_from_dict = { - "stream_id": model_instance.session_id, - "platform": model_instance.platform, - "user_info": user_info_data, - "group_info": group_info_data, - "create_time": model_instance.created_timestamp.timestamp(), - "last_active_time": model_instance.last_active_timestamp.timestamp(), - } - stream = ChatStream.from_dict(data_for_from_dict) - # 更新用户信息和群组信息 - stream.user_info = user_info - if group_info: - stream.group_info = group_info - stream.update_active_time() - else: - # 创建新的聊天流 - stream = ChatStream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info, - ) - except Exception as e: - logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) - raise e - - stream = copy.deepcopy(stream) - from .message import MessageRecv # 延迟导入,避免循环引用 - - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): - stream.set_context(self.last_messages[stream_id]) - else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") - # 保存到内存和数据库 - self.streams[stream_id] = stream - await self._save_stream(stream) - return stream - - def get_stream(self, stream_id: str) -> Optional[ChatStream]: - """通过stream_id获取聊天流""" - stream = self.streams.get(stream_id) - if not stream: - return None - if stream_id in self.last_messages: - stream.set_context(self.last_messages[stream_id]) - return stream - - def get_stream_by_info( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None - ) -> Optional[ChatStream]: - """通过信息获取聊天流""" - stream_id = self._generate_stream_id(platform, user_info, group_info) - return self.streams.get(stream_id) - - def get_stream_name(self, stream_id: str) -> Optional[str]: - """根据 stream_id 获取聊天流名称""" - stream = self.get_stream(stream_id) - if not stream: - return None - - if stream.group_info and stream.group_info.group_name: - return stream.group_info.group_name - elif stream.user_info and stream.user_info.user_nickname: - return f"{stream.user_info.user_nickname}的私聊" - else: - return None - - @staticmethod - async def _save_stream(stream: ChatStream): - """保存聊天流到数据库""" - if stream.saved: - return - stream_data_dict = stream.to_dict() - - def _db_save_stream_sync(s_data_dict: dict): - user_info_d = s_data_dict.get("user_info") - group_info_d = s_data_dict.get("group_info") - - with get_db_session() as session: - statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"]) - record = session.exec(statement).first() - if record is None: - record = ChatSession( - session_id=s_data_dict["stream_id"], - platform=s_data_dict["platform"], - user_id=user_info_d["user_id"] if user_info_d else None, - group_id=group_info_d["group_id"] if group_info_d else None, - created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]), - last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]), - ) - session.add(record) - return - - record.platform = s_data_dict["platform"] - record.user_id = user_info_d["user_id"] if user_info_d else None - record.group_id = group_info_d["group_id"] if group_info_d else None - record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"]) - record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"]) - - try: - await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) - stream.saved = True - except Exception as e: - logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) - - async def _save_all_streams(self): - """保存所有聊天流""" - for stream in self.streams.values(): - await self._save_stream(stream) - - async def load_all_streams(self): - """从数据库加载所有聊天流""" - logger.info("正在从数据库加载所有聊天流") - - def _db_load_all_streams_sync(): - loaded_streams_data = [] - with get_db_session() as session: - statement = select(ChatSession) - for model_instance in session.exec(statement).all(): - user_info_data = { - "platform": model_instance.platform, - "user_id": model_instance.user_id or "", - "user_nickname": "", - "user_cardname": "", - } - group_info_data = None - if model_instance.group_id: - group_info_data = { - "platform": model_instance.platform, - "group_id": model_instance.group_id, - "group_name": "", - } - - data_for_from_dict = { - "stream_id": model_instance.session_id, - "platform": model_instance.platform, - "user_info": user_info_data, - "group_info": group_info_data, - "create_time": model_instance.created_timestamp.timestamp(), - "last_active_time": model_instance.last_active_timestamp.timestamp(), - } - loaded_streams_data.append(data_for_from_dict) - return loaded_streams_data - - try: - all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) - self.streams.clear() - for data in all_streams_data_list: - stream = ChatStream.from_dict(data) - stream.saved = True - self.streams[stream.stream_id] = stream - if stream.stream_id in self.last_messages: - stream.set_context(self.last_messages[stream.stream_id]) - except Exception as e: - logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) - - -chat_manager = None - - -def get_chat_manager(): - global chat_manager - if chat_manager is None: - chat_manager = ChatManager() - return chat_manager