被催离教室了QAQ,今天就先到此为止了,现在是解决prompt问题

pull/914/head
Bakadax 2025-04-30 22:44:39 +08:00
parent 9d28c3660d
commit 925a531058
2 changed files with 41 additions and 47 deletions

View File

@ -232,7 +232,7 @@ class PromptBuilder:
user_ids_in_context = set()
if message_list_before_now:
for msg in message_list_before_now:
sender_id = msg.get('sender_id')
sender_id = msg["user_info"].get('user_id')
if sender_id:
user_ids_in_context.add(str(sender_id))
else:
@ -419,7 +419,8 @@ class PromptBuilder:
user_ids_in_context = set()
if message_list_before_now:
for msg in message_list_before_now:
sender_id = msg.get('sender_id')
print(msg)
sender_id = msg["user_info"].get('user_id')
if sender_id:
user_ids_in_context.add(str(sender_id))
else:

View File

@ -6,6 +6,7 @@ from .person_info import person_info_manager
import time
import random
from typing import List, Dict
from ...common.database import db
# import re
# import traceback
@ -81,91 +82,81 @@ class RelationshipManager:
is_known = person_info_manager.is_person_known(platform, user_id)
return is_known
# --- [修改] 使用全局 db 对象进行查询 ---
@staticmethod
async def get_person_names_batch(platform: str, user_ids: List[str]) -> Dict[str, str]:
"""
批量获取多个用户的 person_name
Args:
platform (str): 平台名称
user_ids (List[str]): 用户 ID 列表
Returns:
Dict[str, str]: 映射 {user_id: person_name}只包含成功获取到名称的用户
"""
if not user_ids:
return {}
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] # 确保 uid 是字符串
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids]
names_map = {}
try:
# 使用 $in 操作符批量查询
cursor = person_info_manager.collection.find(
# --- 修改点:直接使用 db.person_info.find ---
# !!! 确保 'person_info' 是正确的集合名称 !!!
cursor = db.person_info.find(
{"person_id": {"$in": person_ids}},
{"_id": 0, "person_id": 1, "person_name": 1} # 只查询需要的字段
)
async for doc in cursor:
# 从 person_id 反向推导出原始 user_id
# 注意:这依赖于 get_person_id 的实现方式,假设它是 platform_userid 格式
# --- 结束修改点 ---
# 注意pymongo 的 find 返回的是同步游标,如果你的 db 对象是 motor 客户端,需要使用 await cursor.to_list(length=None)
# 假设这里 db 是 pymongo 同步客户端,或者你的环境允许在异步函数中迭代同步游标
for doc in cursor: # 如果 db 是 motor这里会报错需要改为 async for
original_user_id = doc.get("person_id", "").split("_", 1)[-1]
person_name = doc.get("person_name")
if original_user_id and person_name:
names_map[original_user_id] = person_name
logger.debug(f"Batch get person names for {len(user_ids)} users, found {len(names_map)} names.")
logger.debug(f"批量获取 {len(user_ids)} 个用户的 person_name找到 {len(names_map)} 个。")
except AttributeError as e:
# 如果 db 对象没有 person_info 属性,或者 find 方法不存在
logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称。")
except Exception as e:
logger.error(f"Error during batch get person names: {e}", exc_info=True)
logger.error(f"批量获取 person_name 时出错: {e}", exc_info=True)
return names_map
# --- 结束新增 ---
# --- 结束修改 ---
# --- [新增] 批量获取用户群组绰号 ---
# --- [修改] 使用全局 db 对象进行查询 ---
@staticmethod
async def get_users_group_nicknames(platform: str, user_ids: List[str], group_id: str) -> Dict[str, List[Dict[str, int]]]:
"""
批量获取多个用户在指定群组的绰号信息
Args:
platform (str): 平台名称
user_ids (List[str]): 用户 ID 列表
group_id (str): 群组 ID
Returns:
Dict[str, List[Dict[str, int]]]: 映射 {person_name: [{"绰号A": 次数}, ...]}
只包含成功获取到绰号信息的用户
键是用户的 person_name
"""
if not user_ids or not group_id:
return {}
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids]
nicknames_data = {}
group_id_str = str(group_id) # 确保 group_id 是字符串
group_id_str = str(group_id)
try:
# 查询包含目标 person_id 且 group_nickname 字段存在的文档
cursor = person_info_manager.collection.find(
# --- 修改点:直接使用 db.person_info.find ---
# !!! 确保 'person_info' 是正确的集合名称 !!!
cursor = db.person_info.find(
{
"person_id": {"$in": person_ids},
"group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}} # 确保该群组的条目存在
"group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}}
},
{"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1} # 查询所需字段
{"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1}
)
# --- 结束修改点 ---
async for doc in cursor:
# 同样,假设同步迭代可行
for doc in cursor: # 如果 db 是 motor这里需要改为 async for
person_name = doc.get("person_name")
if not person_name: # 如果没有 person_name则跳过此用户
if not person_name:
continue
group_nicknames_list = doc.get("group_nickname", [])
user_group_nicknames = []
# 遍历 group_nickname 列表,找到对应 group_id 的条目
for group_entry in group_nicknames_list:
if group_id_str in group_entry and isinstance(group_entry[group_id_str], list):
# 提取该群组的绰号列表 [{"绰号": 次数}, ...]
user_group_nicknames = group_entry[group_id_str]
break # 找到后即可退出内层循环
break
if user_group_nicknames: # 确保列表非空
# 过滤掉格式不正确的条目
if user_group_nicknames:
valid_nicknames = []
for item in user_group_nicknames:
if isinstance(item, dict) and len(item) == 1:
@ -173,19 +164,21 @@ class RelationshipManager:
if isinstance(key, str) and isinstance(value, int):
valid_nicknames.append(item)
else:
logger.warning(f"Invalid nickname format in DB for user {person_name}, group {group_id_str}: {item}")
logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号格式无效: {item}")
else:
logger.warning(f"Invalid nickname entry format in DB for user {person_name}, group {group_id_str}: {item}")
logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号条目格式无效: {item}")
if valid_nicknames:
nicknames_data[person_name] = valid_nicknames # 使用 person_name 作为 key
nicknames_data[person_name] = valid_nicknames
logger.debug(f"Batch get group nicknames for {len(user_ids)} users in group {group_id_str}, found data for {len(nicknames_data)} users.")
logger.debug(f"批量获取群组 {group_id_str}{len(user_ids)} 个用户的绰号,找到 {len(nicknames_data)} 个用户的数据。")
except AttributeError as e:
logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称 'person_info'")
except Exception as e:
logger.error(f"Error during batch get group nicknames: {e}", exc_info=True)
logger.error(f"批量获取群组绰号时出错: {e}", exc_info=True)
return nicknames_data
# --- 结束修改 ---
@staticmethod
async def is_qved_name(platform, user_id):