diff --git a/src/common/database.py b/src/common/database.py index ee0ead0b..17a71709 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -52,6 +52,20 @@ class DBWrapper: def __getitem__(self, key): return get_db()[key] +def close_db(): + """关闭全局 MongoDB 客户端连接。""" + global _client, _db + if _client: + try: + _client.close() + # print(f"数据库连接已由进程 {os.getpid()} 关闭。") # 可选:添加日志 + except Exception as e: + # print(f"关闭数据库连接时出错: {e}") # 可选:记录关闭错误 + pass # 关闭期间避免程序崩溃 + finally: + # 重置全局变量,以便下次 get_db 能重新连接(如果需要) + _client = None + _db = None # 全局数据库访问点 db: Database = DBWrapper() diff --git a/src/plugins/group_nickname/nickname_mapper.py b/src/plugins/group_nickname/nickname_mapper.py index b5e258a4..3cf2687f 100644 --- a/src/plugins/group_nickname/nickname_mapper.py +++ b/src/plugins/group_nickname/nickname_mapper.py @@ -1,32 +1,17 @@ import json from typing import Dict, Any, Optional +import asyncio # 可能需要用于锁 + from src.common.logger_manager import get_logger from src.plugins.models.utils_model import LLMRequest # 从全局配置导入 from src.config.config import global_config - logger = get_logger("nickname_mapper") -llm_mapper: Optional[LLMRequest] = None -if global_config.ENABLE_NICKNAME_MAPPING: # 使用全局开关 - try: - # 从全局配置获取模型设置 - model_config = global_config.llm_nickname_mapping - if not model_config or not model_config.get("name"): - logger.error("在全局配置中未找到有效的 'llm_nickname_mapping' 配置或缺少 'name' 字段。") - else: - llm_mapper = LLMRequest( # <-- LLM 初始化 - model=global_config.llm_nickname_mapping, - temperature=global_config.llm_nickname_mapping["temp"], - max_tokens=256, - request_type="nickname_mapping", - ) - logger.info("绰号映射 LLM 初始化成功 (使用全局配置)。") - - except Exception as e: - logger.error(f"使用全局配置初始化绰号映射 LLM 失败: {e}", exc_info=True) - llm_mapper = None +if global_config.ENABLE_NICKNAME_MAPPING: + _llm_mapper_instance: Optional[LLMRequest] = None + _llm_mapper_init_lock = asyncio.Lock() # 使用异步锁,因为下面的函数是 async def _build_mapping_prompt(chat_history_str: str, bot_reply: str, user_name_map: Dict[str, str]) -> str: user_list_str = "\n".join([f"- {uid}: {name}" for uid, name in user_name_map.items()]) @@ -68,6 +53,39 @@ def _build_mapping_prompt(chat_history_str: str, bot_reply: str, user_name_map: """ return prompt +async def _get_or_initialize_llm_mapper() -> Optional[LLMRequest]: + """获取或在需要时初始化绰号映射 LLM 的单例。""" + global _llm_mapper_instance + # 双重检查锁定模式(适用于 asyncio) + if _llm_mapper_instance is None: + async with _llm_mapper_init_lock: + # 再次检查,防止在等待锁时其他协程已完成初始化 + if _llm_mapper_instance is None: + logger.info("首次调用,尝试初始化绰号映射 LLM...") + if not global_config.ENABLE_NICKNAME_MAPPING: + logger.info("绰号映射功能已禁用,LLM 初始化跳过。") + # 可以选择返回 None 或者设置一个特殊标记 + # 这里我们假设如果禁用,就不应该尝试使用,所以保持 None + # _llm_mapper_instance = None # 已经是 None + else: + try: + model_config = global_config.llm_nickname_mapping + if not model_config or not model_config.get("name"): + logger.error("在全局配置中未找到有效的 'llm_nickname_mapping' 配置或缺少 'name' 字段。") + # 初始化失败,保持 None + else: + _llm_mapper_instance = LLMRequest( + model=global_config.llm_nickname_mapping, + temperature=global_config.llm_nickname_mapping["temp"], + max_tokens=256, + request_type="nickname_mapping", + ) + logger.info("绰号映射 LLM 初始化成功。") + except Exception as e: + logger.error(f"初始化绰号映射 LLM 失败: {e}", exc_info=True) + # 初始化失败,保持 None + _llm_mapper_instance = None # 确保显式设置为 None + return _llm_mapper_instance async def analyze_chat_for_nicknames( chat_history_str: str, @@ -83,9 +101,7 @@ async def analyze_chat_for_nicknames( logger.debug("绰号映射功能已禁用。") return {"is_exist": False} - if llm_mapper is None: - logger.error("绰号映射 LLM 未初始化。无法执行分析。") - return {"is_exist": False} + llm_mapper = await _get_or_initialize_llm_mapper() prompt = _build_mapping_prompt(chat_history_str, bot_reply, user_name_map) logger.debug(f"构建的绰号映射 Prompt:\n{prompt}") diff --git a/src/plugins/group_nickname/nickname_processor.py b/src/plugins/group_nickname/nickname_processor.py index 669b3094..2f01492e 100644 --- a/src/plugins/group_nickname/nickname_processor.py +++ b/src/plugins/group_nickname/nickname_processor.py @@ -1,3 +1,4 @@ +import os import asyncio import traceback from multiprocessing import Process, Queue as mpQueue, Event @@ -9,7 +10,7 @@ from pymongo.errors import OperationFailure from src.common.logger_manager import get_logger # 导入日志管理器 from src.config.config import global_config # 导入全局配置 from .nickname_mapper import analyze_chat_for_nicknames # 导入绰号分析函数 -from src.common.database import db # 导入数据库初始化和关闭函数 +from src.common.database import get_db, close_db logger = get_logger("nickname_processor") # 获取日志记录器实例 # --- 运行时状态 (用于安全停止进程) --- @@ -20,48 +21,37 @@ mongo_client: Optional[MongoClient] = None # MongoDB 客户端实例 person_info_collection = None # 用户信息集合对象 # --- 数据库更新逻辑 (使用推荐的新结构) --- -async def update_nickname_counts(group_id: str, nickname_map: Dict[str, str]): +async def update_nickname_counts(group_id: str, nickname_map: Dict[str, str], current_db): """ 更新数据库中用户的群组绰号计数。 - 使用新的数据结构: - { - "user_id": 12345, - "group_nicknames": [ # <--- 字段名统一为 group_nicknames - { - "group_id": "群号1", - "nicknames": [ { "name": "绰号A", "count": 5 }, ... ] - }, ... - ] - } + 使用传入的数据库实例。 """ - person_info_collection = db.person_info # 获取集合对象 + # 从传入的 db 实例获取 collection + person_info_collection = current_db.person_info # <--- 使用 current_db if not nickname_map: logger.debug("提供的用于更新的绰号映射为空。") return - logger.info(f"尝试更新群组 '{group_id}' 的绰号计数 (新结构),映射为: {nickname_map}") + logger.info(f"尝试更新群组 '{group_id}' 的绰号计数,映射为: {nickname_map}") - for user_id_str, nickname in nickname_map.items(): # user_id 从 map 中取出是 str + for user_id_str, nickname in nickname_map.items(): if not user_id_str or not nickname: logger.warning(f"跳过绰号映射中的无效条目: user_id='{user_id_str}', nickname='{nickname}'") continue - group_id_str = str(group_id) # 确保 group_id 是字符串 + group_id_str = str(group_id) try: - # 假设数据库中存储的用户ID是整数类型,如果不是请移除 int() user_id_int = int(user_id_str) except ValueError: logger.warning(f"无效的用户ID格式: '{user_id_str}',跳过。") continue try: - # 步骤 1: 确保用户文档存在,且有 group_nicknames 字段 (如果不存在则添加空数组) - # 注意:这里不再使用 $setOnInsert 添加 group_nicknames,因为 $addToSet 或 $push 在字段不存在时会自动创建。 - # upsert=True 确保用户文档存在。 + # 确保后续所有的数据库操作都使用从 current_db 获取的 person_info_collection person_info_collection.update_one( {"user_id": user_id_int}, - {"$setOnInsert": {"user_id": user_id_int}}, # 确保 upsert 时 user_id 被正确设置 + {"$setOnInsert": {"user_id": user_id_int}}, upsert=True ) # 确保 group_nicknames 字段存在且为数组 (如果不存在则创建) @@ -71,7 +61,7 @@ async def update_nickname_counts(group_id: str, nickname_map: Dict[str, str]): ) - # 步骤 2: 尝试直接增加现有绰号的计数 + # 尝试直接增加现有绰号的计数 # 条件:用户存在,且 group_nicknames 数组中存在一个元素其 group_id 匹配,且该元素的 nicknames 数组中存在一个元素的 name 匹配 update_result = person_info_collection.update_one( { @@ -93,14 +83,14 @@ async def update_nickname_counts(group_id: str, nickname_map: Dict[str, str]): logger.debug(f"用户 '{user_id_str}' 在群组 '{group_id_str}' 中的绰号 '{nickname}' 计数已增加。") continue # 处理完成,进行下一次循环 - # 步骤 3: 如果步骤 2 未修改任何内容,尝试将新绰号添加到现有群组的 nicknames 数组中 + # 如果未修改任何内容,尝试将新绰号添加到现有群组的 nicknames 数组中 # 条件:用户存在,且 group_nicknames 数组中存在一个元素其 group_id 匹配 update_result = person_info_collection.update_one( { "user_id": user_id_int, - "group_nicknames.group_id": group_id_str # <--- 确保使用 group_nicknames + "group_nicknames.group_id": group_id_str }, - { # <--- 确保使用 group_nicknames + { "$push": {"group_nicknames.$[group].nicknames": {"name": nickname, "count": 1}} }, array_filters=[ @@ -112,15 +102,15 @@ async def update_nickname_counts(group_id: str, nickname_map: Dict[str, str]): logger.debug(f"为用户 '{user_id_str}' 在群组 '{group_id_str}' 中添加了新绰号 '{nickname}',计数为 1。") continue # 处理完成,进行下一次循环 - # 步骤 4: 如果步骤 2 和 3 都未修改任何内容,说明群组条目本身可能不存在于 group_nicknames 数组中,尝试添加新的群组条目 + # 如果未修改任何内容,说明群组条目本身可能不存在于 group_nicknames 数组中,尝试添加新的群组条目 # 条件:用户存在,且 group_nicknames 数组中 *不包含* 指定 group_id 的元素 update_result = person_info_collection.update_one( { "user_id": user_id_int, - "group_nicknames.group_id": {"$ne": group_id_str} # <--- 检查 group_id 是否不存在 + "group_nicknames.group_id": {"$ne": group_id_str} }, { - "$push": { # <--- 确保使用 group_nicknames + "$push": { "group_nicknames": { "group_id": group_id_str, "nicknames": [{"name": nickname, "count": 1}] @@ -172,49 +162,84 @@ async def add_to_nickname_queue( logger.warning(f"无法将项目添加到绰号队列(可能已满): {e}", exc_info=True) -async def _nickname_processing_loop(queue: mpQueue, stop_event): - """独立进程中的主循环,处理队列任务。""" - - logger.info("绰号处理循环已启动。") +async def _nickname_processing_loop(queue: mpQueue, stop_event, current_db): + """独立进程中的主循环,处理队列任务,使用传入的数据库连接。""" + pid = os.getpid() # 获取进程ID用于日志 + logger.info(f"绰号处理循环已启动 (PID: {pid})。 使用数据库: {current_db.name}") while not stop_event.is_set(): try: if not queue.empty(): + # 或者使用 queue.get(timeout=...) 来避免忙等待,并处理 Empty 异常 item = queue.get() if isinstance(item, tuple) and len(item) == 4: chat_history_str, bot_reply, group_id, user_name_map = item - logger.debug(f"正在处理群组 {group_id} 的绰号映射任务...") + logger.debug(f"(PID: {pid}) 正在处理群组 {group_id} 的绰号映射任务...") analysis_result = await analyze_chat_for_nicknames(chat_history_str, bot_reply, user_name_map) if analysis_result.get("is_exist") and analysis_result.get("data"): - await update_nickname_counts(group_id, analysis_result["data"]) + # 将数据库实例传递下去 + await update_nickname_counts(group_id, analysis_result["data"], current_db) else: - logger.warning(f"从队列接收到意外的项目类型: {type(item)}") - - await asyncio.sleep(5) + logger.warning(f"(PID: {pid}) 从队列接收到意外的项目类型: {type(item)}") + # 处理完一个任务后短暂休眠,避免CPU空转 + await asyncio.sleep(0.1) else: + # 队列为空时,休眠更长时间 await asyncio.sleep(global_config.NICKNAME_PROCESS_SLEEP_INTERVAL) except asyncio.CancelledError: - logger.info("绰号处理循环已取消。") + logger.info(f"绰号处理循环已取消 (PID: {pid})。") break except Exception as e: - logger.error(f"绰号处理循环出错: {e}\n{traceback.format_exc()}") - await asyncio.sleep(5) + logger.error(f"(PID: {pid}) 绰号处理循环出错: {e}\n{traceback.format_exc()}") + await asyncio.sleep(5) # 出错后等待一段时间 - logger.info("绰号处理循环已结束。") + logger.info(f"绰号处理循环已结束 (PID: {pid})。") def _run_processor_process(queue: mpQueue, stop_event): - """进程启动函数,运行异步循环。""" + """进程启动函数,管理自己的数据库连接并运行异步循环。""" + db_instance = None # 初始化数据库实例变量 + loop = None + pid = os.getpid() + logger.info(f"绰号处理器进程启动中 (PID: {pid})...") + try: + # 调用 get_db() 会触发此进程的懒加载逻辑 + logger.info(f"子进程 (PID: {pid}) - 即将调用 get_db()") + db_instance = get_db() + logger.info(f"子进程 (PID: {pid}) - 完成 get_db(), 连接到数据库: {db_instance.name}") + logger.info(f"绰号处理器进程 (PID: {pid}) 已获取数据库连接: {db_instance.name}") + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(_nickname_processing_loop(queue, stop_event)) - loop.close() + # 将获取到的数据库实例传递给异步循环 + logger.info(f"子进程 (PID: {pid}) - 即将运行 _nickname_processing_loop") + loop.run_until_complete(_nickname_processing_loop(queue, stop_event, db_instance)) + logger.info(f"子进程 (PID: {pid}) - 完成 _nickname_processing_loop") + except Exception as e: - logger.error(f"运行绰号处理器进程时出错: {e}", exc_info=True) + logger.error(f"(PID: {pid}) 运行绰号处理器进程时出错: {e}", exc_info=True) + finally: + # --- 清理工作 --- + if loop: + try: + # 关闭事件循环 + if loop.is_running(): + loop.stop() # 先停止 + loop.close() + logger.info(f"(PID: {pid}) asyncio 事件循环已关闭。") + except Exception as loop_close_err: + logger.error(f"(PID: {pid}) 关闭 asyncio 事件循环时出错: {loop_close_err}", exc_info=True) + + try: + close_db() + logger.info(f"(PID: {pid}) 数据库连接已通过 close_db() 关闭。") + except Exception as db_close_err: + logger.error(f"(PID: {pid}) 关闭数据库连接时出错: {db_close_err}", exc_info=True) + logger.info(f"绰号处理器进程已结束 (PID: {pid})。") def start_nickname_processor(): """启动绰号映射处理进程。""" diff --git a/src/plugins/heartFC_chat/heartFC_chat.py b/src/plugins/heartFC_chat/heartFC_chat.py index 0ad2e061..a707b3a7 100644 --- a/src/plugins/heartFC_chat/heartFC_chat.py +++ b/src/plugins/heartFC_chat/heartFC_chat.py @@ -737,7 +737,7 @@ class HeartFChatting: # 4. 获取当前上下文中涉及的用户 ID 及其已知名称 user_ids_in_history = set() for msg in history_messages: - sender_id = msg.get('sender_id') + sender_id = msg["user_info"].get('user_id') if sender_id: user_ids_in_history.add(str(sender_id)) # 确保是字符串 @@ -747,13 +747,6 @@ class HeartFChatting: try: names_data = await relationship_manager.get_person_names_batch(platform, list(user_ids_in_history)) - except AttributeError: - logger.warning("relationship_manager does not have get_person_names_batch method. Falling back to single lookups.") - names_data = {} - for user_id in user_ids_in_history: - name = await relationship_manager.get_person_name(platform, user_id) - if name: - names_data[user_id] = name except Exception as e: logger.error(f"Error getting person names: {e}", exc_info=True) names_data = {} # 出错时置空 diff --git a/src/plugins/person_info/relationship_manager.py b/src/plugins/person_info/relationship_manager.py index 846f3408..fe229b2c 100644 --- a/src/plugins/person_info/relationship_manager.py +++ b/src/plugins/person_info/relationship_manager.py @@ -118,57 +118,69 @@ class RelationshipManager: return names_map # --- 结束修改 --- - # --- [修改] 使用全局 db 对象进行查询 --- @staticmethod async def get_users_group_nicknames(platform: str, user_ids: List[str], group_id: str) -> Dict[str, List[Dict[str, int]]]: """ 批量获取多个用户在指定群组的绰号信息。 + + Args: + platform (str): 平台名称。 + user_ids (List[str]): 用户 ID 列表。 + group_id (str): 群组 ID。 + + Returns: + Dict[str, List[Dict[str, int]]]: 映射 {person_name: [{"绰号A": 次数}, ...]} """ if not user_ids or not group_id: return {} person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] nicknames_data = {} - group_id_str = str(group_id) + group_id_str = str(group_id) # 确保 group_id 是字符串 try: - # --- 修改点:直接使用 db.person_info.find --- - # !!! 确保 'person_info' 是正确的集合名称 !!! + # 查询包含目标 person_id 的文档 cursor = db.person_info.find( - { - "person_id": {"$in": person_ids}, - "group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}} - }, - {"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1} + {"person_id": {"$in": person_ids}}, + {"_id": 0, "person_id": 1, "person_name": 1, "group_nicknames": 1} # 查询所需字段 ) - # --- 结束修改点 --- - # 同样,假设同步迭代可行 - for doc in cursor: # 如果 db 是 motor,这里需要改为 async for + # 假设同步迭代可行 + for doc in cursor: person_name = doc.get("person_name") if not person_name: - continue + continue # 跳过没有 person_name 的用户 - group_nicknames_list = doc.get("group_nickname", []) - user_group_nicknames = [] + group_nicknames_list = doc.get("group_nicknames", []) # 获取 group_nicknames 数组 + target_group_nicknames = [] # 存储目标群组的绰号列表 + + # 遍历 group_nicknames 数组,查找匹配的 group_id for group_entry in group_nicknames_list: - if group_id_str in group_entry and isinstance(group_entry[group_id_str], list): - user_group_nicknames = group_entry[group_id_str] - break + # 确保 group_entry 是字典且包含 group_id 键 + if isinstance(group_entry, dict) and group_entry.get("group_id") == group_id_str: + # 提取 nicknames 列表 + nicknames_raw = group_entry.get("nicknames", []) + if isinstance(nicknames_raw, list): + target_group_nicknames = nicknames_raw + break # 找到匹配的 group_id 后即可退出内层循环 - if user_group_nicknames: - valid_nicknames = [] - for item in user_group_nicknames: - if isinstance(item, dict) and len(item) == 1: - key, value = list(item.items())[0] - if isinstance(key, str) and isinstance(value, int): - valid_nicknames.append(item) - else: - logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号格式无效: {item}") + # 如果找到了目标群组的绰号列表 + if target_group_nicknames: + valid_nicknames_formatted = [] # 存储格式化后的绰号 + for item in target_group_nicknames: + # 校验每个绰号条目的格式 { "name": str, "count": int } + if isinstance(item, dict) and \ + isinstance(item.get("name"), str) and \ + isinstance(item.get("count"), int) and \ + item["count"] > 0: # 确保 count 是正整数 + # --- 格式转换:从 { "name": "xxx", "count": y } 转为 { "xxx": y } --- + valid_nicknames_formatted.append({item["name"]: item["count"]}) + # --- 结束格式转换 --- else: - logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号条目格式无效: {item}") - if valid_nicknames: - nicknames_data[person_name] = valid_nicknames + logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号格式无效或 count <= 0: {item}") + + if valid_nicknames_formatted: # 如果存在有效的、格式化后的绰号 + nicknames_data[person_name] = valid_nicknames_formatted # 使用 person_name 作为 key logger.debug(f"批量获取群组 {group_id_str} 中 {len(user_ids)} 个用户的绰号,找到 {len(nicknames_data)} 个用户的数据。") @@ -178,7 +190,6 @@ class RelationshipManager: logger.error(f"批量获取群组绰号时出错: {e}", exc_info=True) return nicknames_data - # --- 结束修改 --- @staticmethod async def is_qved_name(platform, user_id):