fix: 修复ChatManager重复创建后台任务可能导致的内存泄漏

pull/1396/head
MySxan 2025-12-02 06:45:13 +00:00
parent 940204c072
commit 474010a027
1 changed files with 88 additions and 33 deletions

View File

@ -134,6 +134,12 @@ class ChatManager:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: OrderedDict[str, "MessageRecv"] = OrderedDict() # stream_id -> last_message self.last_messages: OrderedDict[str, "MessageRecv"] = OrderedDict() # stream_id -> last_message
self.last_message_timestamps: Dict[str, float] = {} # stream_id -> timestamp self.last_message_timestamps: Dict[str, float] = {} # stream_id -> timestamp
# 保存 task handler
self._cleanup_task: Optional[asyncio.Task] = None
self._auto_save_task: Optional[asyncio.Task] = None
self._initialize_task: Optional[asyncio.Task] = None
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在 # 确保 ChatStreams 表存在
@ -142,51 +148,98 @@ class ChatManager:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize()) def start(self):
# # 启动自动保存任务 """启动后台任务"""
# asyncio.create_task(self._auto_save_task()) if self._cleanup_task is not None or self._auto_save_task is not None:
# # 启动 TTL 清理任务 logger.warning("后台任务已经在运行,跳过重复启动")
asyncio.create_task(self._cleanup_expired_messages()) return
# 创建并保存 task handler
self._cleanup_task = asyncio.create_task(self._cleanup_expired_messages())
# self._auto_save_task = asyncio.create_task(self._auto_save_task_loop())
# self._initialize_task = asyncio.create_task(self._initialize())
logger.info("ChatManager 后台任务已启动")
async def shutdown(self):
"""安全关闭所有后台任务"""
# 取消所有任务
tasks_to_cancel = []
if self._cleanup_task:
tasks_to_cancel.append(self._cleanup_task)
if self._auto_save_task:
tasks_to_cancel.append(self._auto_save_task)
if self._initialize_task:
tasks_to_cancel.append(self._initialize_task)
for task in tasks_to_cancel:
task.cancel()
if tasks_to_cancel:
await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
# 清空 task handler
self._cleanup_task = None
self._auto_save_task = None
self._initialize_task = None
logger.info("ChatManager 后台任务已安全关闭")
async def _initialize(self): async def _initialize(self):
"""异步初始化""" """异步初始化"""
try: try:
await self.load_all_streams() await self.load_all_streams()
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流") logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except asyncio.CancelledError:
logger.info("初始化任务已取消")
raise
except Exception as e: except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}") logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self): async def _auto_save_task_loop(self):
"""定期自动保存所有聊天流""" """定期自动保存所有聊天流"""
while True: try:
await asyncio.sleep(300) # 每5分钟保存一次 while True:
try: await asyncio.sleep(300) # 每5分钟保存一次
await self._save_all_streams() try:
logger.info("聊天流自动保存完成") await self._save_all_streams()
except Exception as e: logger.info("聊天流自动保存完成")
logger.error(f"聊天流自动保存失败: {str(e)}") except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
except asyncio.CancelledError:
logger.info("自动保存任务已取消")
return
async def _cleanup_expired_messages(self): async def _cleanup_expired_messages(self):
"""定期清理过期的 last_messages""" """定期清理过期的 last_messages"""
while True: try:
await asyncio.sleep(300) # 每5分钟清理一次 while True:
try: await asyncio.sleep(300) # 每5分钟清理一次
current_time = time.time() self._cleanup_once()
expired_keys = [] except asyncio.CancelledError:
logger.info("清理任务已取消")
return
except Exception as e:
logger.error(f"清理过期消息失败: {str(e)}")
for stream_id, timestamp in self.last_message_timestamps.items(): def _cleanup_once(self):
if current_time - timestamp > LAST_MESSAGE_TTL: """执行一次过期消息清理"""
expired_keys.append(stream_id) try:
current_time = time.time()
expired_keys = []
for key in expired_keys: for stream_id, timestamp in self.last_message_timestamps.items():
self.last_messages.pop(key, None) if current_time - timestamp > LAST_MESSAGE_TTL:
self.last_message_timestamps.pop(key, None) expired_keys.append(stream_id)
if expired_keys: for key in expired_keys:
logger.info(f"清理了 {len(expired_keys)} 条过期的 last_messages") self.last_messages.pop(key, None)
except Exception as e: self.last_message_timestamps.pop(key, None)
logger.error(f"清理过期消息失败: {str(e)}")
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 条过期的 last_messages")
except Exception as e:
logger.error(f"执行清理失败: {str(e)}")
def register_message(self, message: "MessageRecv"): def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流""" """注册消息到聊天流"""
@ -441,8 +494,10 @@ class ChatManager:
chat_manager = None chat_manager = None
def get_chat_manager(): def get_chat_manager(auto_start: bool = True):
global chat_manager global chat_manager
if chat_manager is None: if chat_manager is None:
chat_manager = ChatManager() chat_manager = ChatManager()
if auto_start:
chat_manager.start()
return chat_manager return chat_manager