MaiBot/src/plugins/PFC/pfc_utils.py

385 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import traceback
import json
import re
from typing import Dict, Any, Optional, Tuple, List, Union
from src.common.logger_manager import get_logger # 确认 logger 的导入路径
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.plugins.heartFC_chat.heartflow_prompt_builder import prompt_builder # 确认 prompt_builder 的导入路径
from src.plugins.chat.chat_stream import ChatStream
from ..person_info.person_info import person_info_manager
import math
from src.plugins.utils.chat_message_builder import build_readable_messages
from .observation_info import ObservationInfo
from src.config.config import global_config
logger = get_logger("pfc_utils")
async def retrieve_contextual_info(text: str, private_name: str) -> Tuple[str, str]:
    """Retrieve memories and knowledge related to a piece of context text.

    Args:
        text: Context used for retrieval (for example a rendered chat history).
        private_name: Name of the private-chat partner; only used in log messages.

    Returns:
        Tuple[str, str]: ``(memory_text, knowledge_text)``. Each element is a
        placeholder string when nothing was found, or an error marker string
        when the corresponding lookup raised an exception.
    """
    retrieved_memory_str = "无相关记忆。"
    retrieved_knowledge_str = "无相关知识。"
    memory_log_msg = "未自动检索到相关记忆。"
    knowledge_log_msg = "未自动检索到相关知识。"

    # Placeholder / error markers that carry no real conversation content.
    # build_chat_history_text() in this module emits its markers with a
    # trailing newline, so the comparison is done on the stripped text;
    # otherwise those markers would never match and retrieval would run on
    # placeholder text.
    placeholder_texts = (
        "还没有聊天记录。",
        "[构建聊天记录出错]",
        "[获取聊天记录时出错]",
        "[处理聊天记录时出错]",
    )
    normalized_text = text.strip() if text else ""
    if not normalized_text or normalized_text in placeholder_texts:
        logger.debug(f"[私聊][{private_name}] (retrieve_contextual_info) 无有效上下文,跳过检索。")
        return retrieved_memory_str, retrieved_knowledge_str

    # 1. Memory retrieval.
    try:
        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
            text=text,
            max_memory_num=2,
            max_memory_length=2,
            max_depth=3,
            fast_retrieval=False,
        )
        if related_memory:
            # Index 1 of every returned entry is used as the memory text
            # (presumably entries are (topic, text) pairs; confirm against
            # HippocampusManager).
            related_memory_info = "".join(memory[1] + "\n" for memory in related_memory).strip()
            if related_memory_info:
                retrieved_memory_str = f"你回忆起:\n{related_memory_info}\n(以上是你的回忆,供参考)\n"
                memory_log_msg = f"自动检索到记忆: {related_memory_info[:100]}..."
            else:
                # Entries were returned but all of them were blank: keep the
                # "no memory" placeholder instead of emitting an empty recall block.
                memory_log_msg = "自动检索记忆返回为空。"
        logger.debug(f"[私聊][{private_name}] (retrieve_contextual_info) 记忆检索: {memory_log_msg}")
    except Exception as e:
        # Best effort: a failed lookup must not break the caller's prompt building.
        logger.error(
            f"[私聊][{private_name}] (retrieve_contextual_info) 自动检索记忆时出错: {e}\n{traceback.format_exc()}"
        )
        retrieved_memory_str = "检索记忆时出错。\n"

    # 2. Knowledge retrieval through the shared prompt_builder instance.
    try:
        knowledge_result = await prompt_builder.get_prompt_info(
            message=text,
            threshold=0.38,  # similarity threshold; tune as needed
        )
        if knowledge_result:
            retrieved_knowledge_str = knowledge_result  # used verbatim
            knowledge_log_msg = "自动检索到相关知识。"
        logger.debug(f"[私聊][{private_name}] (retrieve_contextual_info) 知识检索: {knowledge_log_msg}")
    except Exception as e:
        logger.error(
            f"[私聊][{private_name}] (retrieve_contextual_info) 自动检索知识时出错: {e}\n{traceback.format_exc()}"
        )
        retrieved_knowledge_str = "检索知识时出错。\n"

    return retrieved_memory_str, retrieved_knowledge_str
def _first_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded anywhere in *text*, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one complete JSON value and ignores trailing
            # text, so nested objects and braces inside strings are handled.
            candidate, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def _check_json_fields(
    source: Dict[str, Any],
    items: Tuple[str, ...],
    defaults: Dict[str, Any],
    required_types: Optional[Dict[str, type]],
) -> Tuple[Optional[Dict[str, Any]], Optional[Tuple[str, str]]]:
    """Extract *items* from *source* on top of *defaults* and validate them.

    Returns ``(fields, None)`` on success, or ``(None, (kind, description))``
    where kind is ``"missing"``, ``"type"`` or ``"empty"`` and description is
    the log fragment describing the first problem found.
    """
    extracted = defaults.copy()

    # A requested field must be present in the object or have a default.
    for field in items:
        if field in source:
            extracted[field] = source[field]
        elif field not in defaults:
            return None, ("missing", f"缺少必要字段 '{field}'")

    # Type constraints apply to whatever value ended up in the result
    # (taken from the object or from the defaults).
    for field, expected_type in (required_types or {}).items():
        if field in extracted and not isinstance(extracted[field], expected_type):
            expected_name = getattr(expected_type, "__name__", repr(expected_type))
            actual_name = type(extracted[field]).__name__
            return None, ("type", f"字段 '{field}' 类型错误 (应为 {expected_name}, 实际为 {actual_name})")

    # Requested string fields must not be blank.
    for field in items:
        value = extracted.get(field)
        if isinstance(value, str) and not value.strip():
            return None, ("empty", f"字段 '{field}' 不能为空字符串")

    return extracted, None


def get_items_from_json(
    content: str,
    private_name: str,
    *items: str,
    default_values: Optional[Dict[str, Any]] = None,
    required_types: Optional[Dict[str, type]] = None,
    allow_array: bool = True,
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
    """Extract JSON from text and pull the requested fields out of it.

    The text may be bare JSON, JSON wrapped in a Markdown code fence, or JSON
    embedded in surrounding prose. For embedded JSON the first decodable
    object is used, including objects that contain nested objects.

    Args:
        content: Text that contains JSON. Non-string input is treated as empty.
        private_name: Private-chat name; only used in log messages.
        *items: Names of the fields to extract.
        default_values: Fallback values, as ``{field: default}``. A requested
            field missing from the JSON is an error unless it has a default.
        required_types: Required types, as ``{field: type}``.
        allow_array: Whether a top-level JSON array of objects is accepted.

    Returns:
        Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
        ``(success, data)``. On success ``data`` is the extracted dict, or,
        for an array, the list of dicts for the elements that validated. On
        failure ``data`` is the dict of default values.
    """
    cleaned_content = content.strip() if isinstance(content, str) else ""

    # Accept both ```json ... ``` and plain ``` ... ``` fences.
    markdown_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", cleaned_content, re.IGNORECASE)
    if markdown_match:
        cleaned_content = markdown_match.group(1).strip()
        logger.debug(f"[私聊][{private_name}] 已去除 Markdown 标记,剩余内容: {cleaned_content[:100]}...")

    default_result: Dict[str, Any] = dict(default_values) if default_values else {}

    # Parse once; both the array path and the single-object path reuse the result.
    try:
        parsed: Any = json.loads(cleaned_content)
        parsed_directly = True
    except json.JSONDecodeError:
        parsed = None
        parsed_directly = False

    if allow_array:
        if not parsed_directly:
            logger.debug(f"[私聊][{private_name}] JSON数组直接解析失败尝试解析单个JSON对象")
        elif isinstance(parsed, list):
            try:
                valid_items_list: List[Dict[str, Any]] = []
                for element in parsed:
                    if not isinstance(element, dict):
                        logger.warning(f"[私聊][{private_name}] JSON数组中的元素不是字典: {element}")
                        continue
                    fields, problem = _check_json_fields(element, items, default_result, required_types)
                    if fields is None:
                        logger.warning(f"[私聊][{private_name}] JSON数组元素{problem[1]}: {element}")
                        continue
                    valid_items_list.append(fields)  # only fully valid elements are kept
            except Exception as e:
                # Best effort, as before: an unexpected validation error is
                # logged and handled as a failed parse below.
                logger.error(f"[私聊][{private_name}] 尝试解析JSON数组时发生未知错误: {str(e)}")
            else:
                if valid_items_list:
                    logger.debug(f"[私聊][{private_name}] 成功解析JSON数组包含 {len(valid_items_list)} 个有效项目。")
                    return True, valid_items_list
                logger.debug(f"[私聊][{private_name}] 解析为JSON数组但未找到有效项目尝试解析单个JSON对象。")

    # Single-object path.
    if parsed_directly:
        if not isinstance(parsed, dict):
            logger.error(f"[私聊][{private_name}] 解析为单个对象,但结果不是字典类型: {type(parsed)}")
            return False, default_result
        json_data = parsed
    else:
        # Fallback: the JSON is embedded in other text. Scan for the first
        # decodable object instead of using a non-greedy brace regex, which
        # cut nested objects short at their first closing brace.
        embedded = _first_json_object(cleaned_content)
        if embedded is None:
            logger.error(
                f"[私聊][{private_name}] 无法在返回内容中找到有效的JSON对象部分。原始内容: {cleaned_content[:100]}..."
            )
            return False, default_result
        logger.debug(f"[私聊][{private_name}] 通过扫描提取并成功解析JSON对象。")
        json_data = embedded

    result, problem = _check_json_fields(json_data, items, default_result, required_types)
    if result is None:
        kind, description = problem
        # Only the missing-field message carries the parsed JSON content.
        suffix = f"。JSON内容: {json_data}" if kind == "missing" else ""
        logger.error(f"[私聊][{private_name}] JSON对象{description}{suffix}")
        return False, default_result

    logger.debug(f"[私聊][{private_name}] 成功解析并验证了单个JSON对象。")
    return True, result
async def get_person_id(private_name: str, chat_stream: ChatStream) -> Optional[Tuple[Any, str, str]]:
    """Look up, creating if necessary, the person record for a private-chat partner.

    Args:
        private_name: Name of the private-chat partner; used as the nickname
            passed to ``get_or_create_person`` and in log messages.
        chat_stream: Chat stream whose ``user_info`` supplies the user id and
            platform.

    Returns:
        ``(person_id, platform, user_id_str)`` on success. ``None`` when the
        user id or platform is unavailable, the user id is not an integer,
        or the person record could not be obtained.
    """
    private_user_id_str: Optional[str] = None
    private_platform_str: Optional[str] = None

    user_info = chat_stream.user_info
    if user_info:
        private_user_id_str = str(user_info.user_id)
        private_platform_str = user_info.platform
        logger.debug(
            f"[私聊][{private_name}] 从 ChatStream 获取到私聊对象信息: ID={private_user_id_str}, Platform={private_platform_str}, Name={private_name}"
        )

    if not (private_user_id_str and private_platform_str):
        logger.warning(
            f"[私聊][{private_name}] 未能确定私聊对象的 user_id 或 platform无法获取 person_id。将在收到消息后尝试。"
        )
        return None

    # Only the conversion is guarded by ValueError, so a ValueError raised by
    # the person manager is not misreported as a conversion failure.
    try:
        private_user_id_int = int(private_user_id_str)
    except ValueError:
        logger.error(f"[私聊][{private_name}] 无法将 private_user_id_str ('{private_user_id_str}') 转换为整数。")
        return None

    try:
        # get_or_create_person is used rather than a plain lookup so that the
        # user record exists afterwards (per the original author's note).
        person_id = await person_info_manager.get_or_create_person(
            platform=private_platform_str,
            user_id=private_user_id_int,
            nickname=private_name,
        )
    except Exception as e_pid:
        logger.error(f"[私聊][{private_name}] 获取或创建 person_id 时出错: {e_pid}")
        return None

    if person_id is None:
        logger.error(f"[私聊][{private_name}] get_or_create_person 未能获取或创建 person_id。")
        return None

    return person_id, private_platform_str, private_user_id_str
async def adjust_relationship_value_nonlinear(old_value: float, raw_adjustment: float) -> float:
    """Scale a raw relationship adjustment non-linearly by the current value.

    ``old_value`` is clamped to [-1000, 1000]. When the adjustment has the
    same sign as the clamped value (zero counts as non-negative on both
    sides) it is multiplied by ``cos(pi * old / 2000)``; otherwise it is
    multiplied by ``exp(old / 2000)``. For a clamped value above 500 with a
    non-negative adjustment, the result is further multiplied by
    ``3 / (count + k)``, where ``count`` is the length of the list returned
    by ``person_info_manager.get_specific_value_list`` for
    ``relationship_value`` entries above 700, and ``k`` is 2 when the
    clamped value exceeds 700 and 3 otherwise.

    Args:
        old_value: Current relationship value.
        raw_adjustment: Unscaled change to apply.

    Returns:
        The scaled adjustment; ``0`` when the adjustment is neither
        ``>= 0`` nor ``< 0`` (i.e. NaN).
    """
    clamped = max(-1000, min(1000, old_value))
    delta = raw_adjustment

    # Neither comparison holds only for NaN-like input.
    if not (delta >= 0 or delta < 0):
        return 0

    base_non_negative = clamped >= 0
    delta_non_negative = delta >= 0

    # Opposite directions: exponential factor.
    if base_non_negative != delta_non_negative:
        return delta * math.exp(clamped / 2000)

    # Same direction: cosine factor.
    scaled = delta * math.cos(math.pi * clamped / 2000)

    # clamped > 500 implies both the base and the adjustment are non-negative here.
    if clamped > 500:
        high_entries = await person_info_manager.get_specific_value_list(
            "relationship_value", lambda x: x > 700
        )
        offset = 2 if clamped > 700 else 3
        scaled *= 3 / (len(high_entries) + offset)

    return scaled
async def build_chat_history_text(observation_info: ObservationInfo, private_name: str) -> str:
    """Build the chat-history text, including not-yet-processed messages.

    The base text is ``observation_info.chat_history_str`` when set,
    otherwise the last 20 entries of ``observation_info.chat_history``
    rendered with ``build_readable_messages``, otherwise a "no history yet"
    placeholder. If there are unprocessed messages from senders other than
    the bot, they are rendered and appended, followed by a separator line.

    Args:
        observation_info: Observation state that holds the history and the
            unprocessed messages.
        private_name: Name of the private-chat partner; only used in log messages.

    Returns:
        The history text, or an error marker string if building it failed.
    """
    chat_history_text = ""
    try:
        if hasattr(observation_info, "chat_history_str") and observation_info.chat_history_str:
            chat_history_text = observation_info.chat_history_str
        elif hasattr(observation_info, "chat_history") and observation_info.chat_history:
            history_slice = observation_info.chat_history[-20:]
            chat_history_text = await build_readable_messages(
                history_slice, replace_bot_name=True, merge_messages=False, timestamp_mode="relative", read_mark=0.0
            )
        else:
            chat_history_text = "还没有聊天记录。\n"

        unread_count = getattr(observation_info, "new_messages_count", 0)
        unread_messages = getattr(observation_info, "unprocessed_messages", [])
        if unread_count > 0 and unread_messages:
            bot_qq_str = str(global_config.BOT_QQ)

            def is_from_bot(msg) -> bool:
                # "user_info" may be missing or explicitly None; previously a
                # None value raised AttributeError and discarded the whole
                # history. The id is normalised to str so an integer user_id
                # still matches the bot's id.
                sender_id = (msg.get("user_info") or {}).get("user_id")
                return sender_id is not None and str(sender_id) == bot_qq_str

            other_unread_messages = [msg for msg in unread_messages if not is_from_bot(msg)]
            if other_unread_messages:
                new_messages_str = await build_readable_messages(
                    other_unread_messages,
                    replace_bot_name=True,
                    merge_messages=False,
                    timestamp_mode="relative",
                    read_mark=0.0,
                )
                chat_history_text += f"\n{new_messages_str}\n------\n"
    except AttributeError as e:
        logger.warning(f"[私聊][{private_name}] 构建聊天记录文本时属性错误: {e}")
        chat_history_text = "[获取聊天记录时出错]\n"
    except Exception as e:
        logger.error(f"[私聊][{private_name}] 处理聊天记录时发生未知错误: {e}")
        chat_history_text = "[处理聊天记录时出错]\n"
    return chat_history_text