From 925a5310583dc6dc175aa5458879f63c47e360d5 Mon Sep 17 00:00:00 2001 From: Bakadax Date: Wed, 30 Apr 2025 22:44:39 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A2=AB=E5=82=AC=E7=A6=BB=E6=95=99=E5=AE=A4?= =?UTF-8?q?=E4=BA=86QAQ=EF=BC=8C=E4=BB=8A=E5=A4=A9=E5=B0=B1=E5=85=88?= =?UTF-8?q?=E5=88=B0=E6=AD=A4=E4=B8=BA=E6=AD=A2=E4=BA=86=EF=BC=8C=E7=8E=B0?= =?UTF-8?q?=E5=9C=A8=E6=98=AF=E8=A7=A3=E5=86=B3prompt=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../heartFC_chat/heartflow_prompt_builder.py | 5 +- .../person_info/relationship_manager.py | 83 +++++++++---------- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/src/plugins/heartFC_chat/heartflow_prompt_builder.py b/src/plugins/heartFC_chat/heartflow_prompt_builder.py index 216955c3..afef18c8 100644 --- a/src/plugins/heartFC_chat/heartflow_prompt_builder.py +++ b/src/plugins/heartFC_chat/heartflow_prompt_builder.py @@ -232,7 +232,7 @@ class PromptBuilder: user_ids_in_context = set() if message_list_before_now: for msg in message_list_before_now: - sender_id = msg.get('sender_id') + sender_id = msg["user_info"].get('user_id') if sender_id: user_ids_in_context.add(str(sender_id)) else: @@ -419,7 +419,8 @@ class PromptBuilder: user_ids_in_context = set() if message_list_before_now: for msg in message_list_before_now: - sender_id = msg.get('sender_id') + print(msg) + sender_id = msg["user_info"].get('user_id') if sender_id: user_ids_in_context.add(str(sender_id)) else: diff --git a/src/plugins/person_info/relationship_manager.py b/src/plugins/person_info/relationship_manager.py index c8bda8f3..846f3408 100644 --- a/src/plugins/person_info/relationship_manager.py +++ b/src/plugins/person_info/relationship_manager.py @@ -6,6 +6,7 @@ from .person_info import person_info_manager import time import random from typing import List, Dict +from ...common.database import db # import re # import traceback @@ -81,91 +82,81 @@ class RelationshipManager: is_known = person_info_manager.is_person_known(platform, user_id) return is_known + # --- [修改] 使用全局 db 对象进行查询 --- @staticmethod async def get_person_names_batch(platform: str, user_ids: List[str]) -> Dict[str, str]: """ 批量获取多个用户的 person_name。 - - Args: - platform (str): 平台名称。 - user_ids (List[str]): 用户 ID 列表。 - - Returns: - Dict[str, str]: 映射 {user_id: person_name},只包含成功获取到名称的用户。 """ if not user_ids: return {} - person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] # 确保 uid 是字符串 + person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] names_map = {} try: - # 使用 $in 操作符批量查询 - cursor = person_info_manager.collection.find( + # --- 修改点:直接使用 db.person_info.find --- + # !!! 确保 'person_info' 是正确的集合名称 !!! + cursor = db.person_info.find( {"person_id": {"$in": person_ids}}, {"_id": 0, "person_id": 1, "person_name": 1} # 只查询需要的字段 ) - async for doc in cursor: - # 从 person_id 反向推导出原始 user_id - # 注意:这依赖于 get_person_id 的实现方式,假设它是 platform_userid 格式 + # --- 结束修改点 --- + + # 注意:pymongo 的 find 返回的是同步游标,如果你的 db 对象是 motor 客户端,需要使用 await cursor.to_list(length=None) + # 假设这里 db 是 pymongo 同步客户端,或者你的环境允许在异步函数中迭代同步游标 + for doc in cursor: # 如果 db 是 motor,这里会报错,需要改为 async for original_user_id = doc.get("person_id", "").split("_", 1)[-1] person_name = doc.get("person_name") if original_user_id and person_name: names_map[original_user_id] = person_name - logger.debug(f"Batch get person names for {len(user_ids)} users, found {len(names_map)} names.") + logger.debug(f"批量获取 {len(user_ids)} 个用户的 person_name,找到 {len(names_map)} 个。") + except AttributeError as e: + # 如果 db 对象没有 person_info 属性,或者 find 方法不存在 + logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称。") except Exception as e: - logger.error(f"Error during batch get person names: {e}", exc_info=True) + logger.error(f"批量获取 person_name 时出错: {e}", exc_info=True) return names_map - # --- 结束新增 --- + # --- 结束修改 --- - # --- [新增] 批量获取用户群组绰号 --- + # --- [修改] 使用全局 db 对象进行查询 --- @staticmethod async def get_users_group_nicknames(platform: str, user_ids: List[str], group_id: str) -> Dict[str, List[Dict[str, int]]]: """ 批量获取多个用户在指定群组的绰号信息。 - - Args: - platform (str): 平台名称。 - user_ids (List[str]): 用户 ID 列表。 - group_id (str): 群组 ID。 - - Returns: - Dict[str, List[Dict[str, int]]]: 映射 {person_name: [{"绰号A": 次数}, ...]} - 只包含成功获取到绰号信息的用户。 - 键是用户的 person_name。 """ if not user_ids or not group_id: return {} person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] nicknames_data = {} - group_id_str = str(group_id) # 确保 group_id 是字符串 + group_id_str = str(group_id) try: - # 查询包含目标 person_id 且 group_nickname 字段存在的文档 - cursor = person_info_manager.collection.find( + # --- 修改点:直接使用 db.person_info.find --- + # !!! 确保 'person_info' 是正确的集合名称 !!! + cursor = db.person_info.find( { "person_id": {"$in": person_ids}, - "group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}} # 确保该群组的条目存在 + "group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}} }, - {"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1} # 查询所需字段 + {"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1} ) + # --- 结束修改点 --- - async for doc in cursor: + # 同样,假设同步迭代可行 + for doc in cursor: # 如果 db 是 motor,这里需要改为 async for person_name = doc.get("person_name") - if not person_name: # 如果没有 person_name,则跳过此用户 + if not person_name: continue group_nicknames_list = doc.get("group_nickname", []) user_group_nicknames = [] - # 遍历 group_nickname 列表,找到对应 group_id 的条目 for group_entry in group_nicknames_list: if group_id_str in group_entry and isinstance(group_entry[group_id_str], list): - # 提取该群组的绰号列表 [{"绰号": 次数}, ...] user_group_nicknames = group_entry[group_id_str] - break # 找到后即可退出内层循环 + break - if user_group_nicknames: # 确保列表非空 - # 过滤掉格式不正确的条目 + if user_group_nicknames: valid_nicknames = [] for item in user_group_nicknames: if isinstance(item, dict) and len(item) == 1: @@ -173,19 +164,21 @@ class RelationshipManager: if isinstance(key, str) and isinstance(value, int): valid_nicknames.append(item) else: - logger.warning(f"Invalid nickname format in DB for user {person_name}, group {group_id_str}: {item}") + logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号格式无效: {item}") else: - logger.warning(f"Invalid nickname entry format in DB for user {person_name}, group {group_id_str}: {item}") - + logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号条目格式无效: {item}") if valid_nicknames: - nicknames_data[person_name] = valid_nicknames # 使用 person_name 作为 key + nicknames_data[person_name] = valid_nicknames - logger.debug(f"Batch get group nicknames for {len(user_ids)} users in group {group_id_str}, found data for {len(nicknames_data)} users.") + logger.debug(f"批量获取群组 {group_id_str} 中 {len(user_ids)} 个用户的绰号,找到 {len(nicknames_data)} 个用户的数据。") + except AttributeError as e: + logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称 'person_info'。") except Exception as e: - logger.error(f"Error during batch get group nicknames: {e}", exc_info=True) + logger.error(f"批量获取群组绰号时出错: {e}", exc_info=True) return nicknames_data + # --- 结束修改 --- @staticmethod async def is_qved_name(platform, user_id):