mirror of https://github.com/Mai-with-u/MaiBot.git
被催离教室了QAQ,今天就先到此为止了,现在是解决prompt问题
parent
9d28c3660d
commit
925a531058
|
|
@ -232,7 +232,7 @@ class PromptBuilder:
|
|||
user_ids_in_context = set()
|
||||
if message_list_before_now:
|
||||
for msg in message_list_before_now:
|
||||
sender_id = msg.get('sender_id')
|
||||
sender_id = msg["user_info"].get('user_id')
|
||||
if sender_id:
|
||||
user_ids_in_context.add(str(sender_id))
|
||||
else:
|
||||
|
|
@ -419,7 +419,8 @@ class PromptBuilder:
|
|||
user_ids_in_context = set()
|
||||
if message_list_before_now:
|
||||
for msg in message_list_before_now:
|
||||
sender_id = msg.get('sender_id')
|
||||
print(msg)
|
||||
sender_id = msg["user_info"].get('user_id')
|
||||
if sender_id:
|
||||
user_ids_in_context.add(str(sender_id))
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from .person_info import person_info_manager
|
|||
import time
|
||||
import random
|
||||
from typing import List, Dict
|
||||
from ...common.database import db
|
||||
# import re
|
||||
# import traceback
|
||||
|
||||
|
|
@ -81,91 +82,81 @@ class RelationshipManager:
|
|||
is_known = person_info_manager.is_person_known(platform, user_id)
|
||||
return is_known
|
||||
|
||||
# --- [修改] 使用全局 db 对象进行查询 ---
|
||||
@staticmethod
|
||||
async def get_person_names_batch(platform: str, user_ids: List[str]) -> Dict[str, str]:
|
||||
"""
|
||||
批量获取多个用户的 person_name。
|
||||
|
||||
Args:
|
||||
platform (str): 平台名称。
|
||||
user_ids (List[str]): 用户 ID 列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 映射 {user_id: person_name},只包含成功获取到名称的用户。
|
||||
"""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids] # 确保 uid 是字符串
|
||||
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids]
|
||||
names_map = {}
|
||||
try:
|
||||
# 使用 $in 操作符批量查询
|
||||
cursor = person_info_manager.collection.find(
|
||||
# --- 修改点:直接使用 db.person_info.find ---
|
||||
# !!! 确保 'person_info' 是正确的集合名称 !!!
|
||||
cursor = db.person_info.find(
|
||||
{"person_id": {"$in": person_ids}},
|
||||
{"_id": 0, "person_id": 1, "person_name": 1} # 只查询需要的字段
|
||||
)
|
||||
async for doc in cursor:
|
||||
# 从 person_id 反向推导出原始 user_id
|
||||
# 注意:这依赖于 get_person_id 的实现方式,假设它是 platform_userid 格式
|
||||
# --- 结束修改点 ---
|
||||
|
||||
# 注意:pymongo 的 find 返回的是同步游标,如果你的 db 对象是 motor 客户端,需要使用 await cursor.to_list(length=None)
|
||||
# 假设这里 db 是 pymongo 同步客户端,或者你的环境允许在异步函数中迭代同步游标
|
||||
for doc in cursor: # 如果 db 是 motor,这里会报错,需要改为 async for
|
||||
original_user_id = doc.get("person_id", "").split("_", 1)[-1]
|
||||
person_name = doc.get("person_name")
|
||||
if original_user_id and person_name:
|
||||
names_map[original_user_id] = person_name
|
||||
logger.debug(f"Batch get person names for {len(user_ids)} users, found {len(names_map)} names.")
|
||||
logger.debug(f"批量获取 {len(user_ids)} 个用户的 person_name,找到 {len(names_map)} 个。")
|
||||
except AttributeError as e:
|
||||
# 如果 db 对象没有 person_info 属性,或者 find 方法不存在
|
||||
logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称。")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during batch get person names: {e}", exc_info=True)
|
||||
logger.error(f"批量获取 person_name 时出错: {e}", exc_info=True)
|
||||
return names_map
|
||||
# --- 结束新增 ---
|
||||
# --- 结束修改 ---
|
||||
|
||||
# --- [新增] 批量获取用户群组绰号 ---
|
||||
# --- [修改] 使用全局 db 对象进行查询 ---
|
||||
@staticmethod
|
||||
async def get_users_group_nicknames(platform: str, user_ids: List[str], group_id: str) -> Dict[str, List[Dict[str, int]]]:
|
||||
"""
|
||||
批量获取多个用户在指定群组的绰号信息。
|
||||
|
||||
Args:
|
||||
platform (str): 平台名称。
|
||||
user_ids (List[str]): 用户 ID 列表。
|
||||
group_id (str): 群组 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, List[Dict[str, int]]]: 映射 {person_name: [{"绰号A": 次数}, ...]}
|
||||
只包含成功获取到绰号信息的用户。
|
||||
键是用户的 person_name。
|
||||
"""
|
||||
if not user_ids or not group_id:
|
||||
return {}
|
||||
|
||||
person_ids = [person_info_manager.get_person_id(platform, str(uid)) for uid in user_ids]
|
||||
nicknames_data = {}
|
||||
group_id_str = str(group_id) # 确保 group_id 是字符串
|
||||
group_id_str = str(group_id)
|
||||
|
||||
try:
|
||||
# 查询包含目标 person_id 且 group_nickname 字段存在的文档
|
||||
cursor = person_info_manager.collection.find(
|
||||
# --- 修改点:直接使用 db.person_info.find ---
|
||||
# !!! 确保 'person_info' 是正确的集合名称 !!!
|
||||
cursor = db.person_info.find(
|
||||
{
|
||||
"person_id": {"$in": person_ids},
|
||||
"group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}} # 确保该群组的条目存在
|
||||
"group_nickname": {"$elemMatch": {group_id_str: {"$exists": True}}}
|
||||
},
|
||||
{"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1} # 查询所需字段
|
||||
{"_id": 0, "person_id": 1, "person_name": 1, "group_nickname": 1}
|
||||
)
|
||||
# --- 结束修改点 ---
|
||||
|
||||
async for doc in cursor:
|
||||
# 同样,假设同步迭代可行
|
||||
for doc in cursor: # 如果 db 是 motor,这里需要改为 async for
|
||||
person_name = doc.get("person_name")
|
||||
if not person_name: # 如果没有 person_name,则跳过此用户
|
||||
if not person_name:
|
||||
continue
|
||||
|
||||
group_nicknames_list = doc.get("group_nickname", [])
|
||||
user_group_nicknames = []
|
||||
# 遍历 group_nickname 列表,找到对应 group_id 的条目
|
||||
for group_entry in group_nicknames_list:
|
||||
if group_id_str in group_entry and isinstance(group_entry[group_id_str], list):
|
||||
# 提取该群组的绰号列表 [{"绰号": 次数}, ...]
|
||||
user_group_nicknames = group_entry[group_id_str]
|
||||
break # 找到后即可退出内层循环
|
||||
break
|
||||
|
||||
if user_group_nicknames: # 确保列表非空
|
||||
# 过滤掉格式不正确的条目
|
||||
if user_group_nicknames:
|
||||
valid_nicknames = []
|
||||
for item in user_group_nicknames:
|
||||
if isinstance(item, dict) and len(item) == 1:
|
||||
|
|
@ -173,19 +164,21 @@ class RelationshipManager:
|
|||
if isinstance(key, str) and isinstance(value, int):
|
||||
valid_nicknames.append(item)
|
||||
else:
|
||||
logger.warning(f"Invalid nickname format in DB for user {person_name}, group {group_id_str}: {item}")
|
||||
logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号格式无效: {item}")
|
||||
else:
|
||||
logger.warning(f"Invalid nickname entry format in DB for user {person_name}, group {group_id_str}: {item}")
|
||||
|
||||
logger.warning(f"数据库中用户 {person_name} 群组 {group_id_str} 的绰号条目格式无效: {item}")
|
||||
if valid_nicknames:
|
||||
nicknames_data[person_name] = valid_nicknames # 使用 person_name 作为 key
|
||||
nicknames_data[person_name] = valid_nicknames
|
||||
|
||||
logger.debug(f"Batch get group nicknames for {len(user_ids)} users in group {group_id_str}, found data for {len(nicknames_data)} users.")
|
||||
logger.debug(f"批量获取群组 {group_id_str} 中 {len(user_ids)} 个用户的绰号,找到 {len(nicknames_data)} 个用户的数据。")
|
||||
|
||||
except AttributeError as e:
|
||||
logger.error(f"访问数据库时出错: {e}。请检查 common/database.py 和集合名称 'person_info'。")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during batch get group nicknames: {e}", exc_info=True)
|
||||
logger.error(f"批量获取群组绰号时出错: {e}", exc_info=True)
|
||||
|
||||
return nicknames_data
|
||||
# --- 结束修改 ---
|
||||
|
||||
@staticmethod
|
||||
async def is_qved_name(platform, user_id):
|
||||
|
|
|
|||
Loading…
Reference in New Issue