From 474010a0270443f5fd6201c49e1e3c48f73e5d0b Mon Sep 17 00:00:00 2001 From: MySxan Date: Tue, 2 Dec 2025 06:45:13 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DChatManager=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E5=88=9B=E5=BB=BA=E5=90=8E=E5=8F=B0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E5=8F=AF=E8=83=BD=E5=AF=BC=E8=87=B4=E7=9A=84=E5=86=85=E5=AD=98?= =?UTF-8?q?=E6=B3=84=E6=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/chat_stream.py | 121 +++++++++++++++++------- 1 file changed, 88 insertions(+), 33 deletions(-) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 850a5033..92288da7 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -134,6 +134,12 @@ class ChatManager: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: OrderedDict[str, "MessageRecv"] = OrderedDict() # stream_id -> last_message self.last_message_timestamps: Dict[str, float] = {} # stream_id -> timestamp + + # 保存 task handler + self._cleanup_task: Optional[asyncio.Task] = None + self._auto_save_task: Optional[asyncio.Task] = None + self._initialize_task: Optional[asyncio.Task] = None + try: db.connect(reuse_if_open=True) # 确保 ChatStreams 表存在 @@ -142,51 +148,98 @@ class ChatManager: logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") self._initialized = True - # 在事件循环中启动初始化 - # asyncio.create_task(self._initialize()) - # # 启动自动保存任务 - # asyncio.create_task(self._auto_save_task()) - # # 启动 TTL 清理任务 - asyncio.create_task(self._cleanup_expired_messages()) + def start(self): + """启动后台任务""" + if self._cleanup_task is not None or self._auto_save_task is not None: + logger.warning("后台任务已经在运行,跳过重复启动") + return + + # 创建并保存 task handler + self._cleanup_task = asyncio.create_task(self._cleanup_expired_messages()) + # self._auto_save_task = asyncio.create_task(self._auto_save_task_loop()) + # self._initialize_task = asyncio.create_task(self._initialize()) + logger.info("ChatManager 后台任务已启动") + + async def shutdown(self): + """安全关闭所有后台任务""" + + # 取消所有任务 + tasks_to_cancel = [] + if self._cleanup_task: + tasks_to_cancel.append(self._cleanup_task) + if self._auto_save_task: + tasks_to_cancel.append(self._auto_save_task) + if self._initialize_task: + tasks_to_cancel.append(self._initialize_task) + + for task in tasks_to_cancel: + task.cancel() + if tasks_to_cancel: + await asyncio.gather(*tasks_to_cancel, return_exceptions=True) + + # 清空 task handler + self._cleanup_task = None + self._auto_save_task = None + self._initialize_task = None + + logger.info("ChatManager 后台任务已安全关闭") + async def _initialize(self): """异步初始化""" try: await self.load_all_streams() logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流") + except asyncio.CancelledError: + logger.info("初始化任务已取消") + raise except Exception as e: logger.error(f"聊天管理器启动失败: {str(e)}") - async def _auto_save_task(self): + async def _auto_save_task_loop(self): """定期自动保存所有聊天流""" - while True: - await asyncio.sleep(300) # 每5分钟保存一次 - try: - await self._save_all_streams() - logger.info("聊天流自动保存完成") - except Exception as e: - logger.error(f"聊天流自动保存失败: {str(e)}") + try: + while True: + await asyncio.sleep(300) # 每5分钟保存一次 + try: + await self._save_all_streams() + logger.info("聊天流自动保存完成") + except Exception as e: + logger.error(f"聊天流自动保存失败: {str(e)}") + except asyncio.CancelledError: + logger.info("自动保存任务已取消") + return async def _cleanup_expired_messages(self): """定期清理过期的 last_messages""" - while True: - await asyncio.sleep(300) # 每5分钟清理一次 - try: - current_time = time.time() - expired_keys = [] - - for stream_id, timestamp in self.last_message_timestamps.items(): - if current_time - timestamp > LAST_MESSAGE_TTL: - expired_keys.append(stream_id) - - for key in expired_keys: - self.last_messages.pop(key, None) - self.last_message_timestamps.pop(key, None) - - if expired_keys: - logger.info(f"清理了 {len(expired_keys)} 条过期的 last_messages") - except Exception as e: - logger.error(f"清理过期消息失败: {str(e)}") + try: + while True: + await asyncio.sleep(300) # 每5分钟清理一次 + self._cleanup_once() + except asyncio.CancelledError: + logger.info("清理任务已取消") + return + except Exception as e: + logger.error(f"清理过期消息失败: {str(e)}") + + def _cleanup_once(self): + """执行一次过期消息清理""" + try: + current_time = time.time() + expired_keys = [] + + for stream_id, timestamp in self.last_message_timestamps.items(): + if current_time - timestamp > LAST_MESSAGE_TTL: + expired_keys.append(stream_id) + + for key in expired_keys: + self.last_messages.pop(key, None) + self.last_message_timestamps.pop(key, None) + + if expired_keys: + logger.info(f"清理了 {len(expired_keys)} 条过期的 last_messages") + except Exception as e: + logger.error(f"执行清理失败: {str(e)}") def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" @@ -441,8 +494,10 @@ class ChatManager: chat_manager = None -def get_chat_manager(): +def get_chat_manager(auto_start: bool = True): global chat_manager if chat_manager is None: chat_manager = ChatManager() + if auto_start: + chat_manager.start() return chat_manager