附属函数参数修改

pull/1201/head
UnCLAS-Prommer 2025-08-21 00:46:04 +08:00
parent e1a21c5a45
commit e8922672aa
No known key found for this signature in database
11 changed files with 292 additions and 337 deletions

View File

@ -9,7 +9,6 @@ import networkx as nx
import numpy as np import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict from typing import List, Tuple, Set, Coroutine, Any, Dict
from collections import Counter from collections import Counter
from itertools import combinations
import traceback import traceback
from rich.traceback import install from rich.traceback import install
@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import (
build_readable_messages, build_readable_messages,
get_raw_msg_by_timestamp_with_chat_inclusive, get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages ) # 导入 build_readable_messages
# 添加cosine_similarity函数 # 添加cosine_similarity函数
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
"""计算余弦相似度""" """计算余弦相似度"""
@ -51,18 +52,9 @@ def calculate_information_content(text):
return entropy return entropy
logger = get_logger("memory") logger = get_logger("memory")
class MemoryGraph: class MemoryGraph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
@ -203,7 +195,9 @@ class MemoryGraph:
整合后的记忆""" 整合后的记忆"""
# 调用LLM进行整合 # 调用LLM进行整合
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt) content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
integration_prompt
)
if content and content.strip(): if content and content.strip():
integrated_content = content.strip() integrated_content = content.strip()
@ -238,7 +232,11 @@ class MemoryGraph:
if memory_items: if memory_items:
# 删除整个节点 # 删除整个节点
self.G.remove_node(topic) self.G.remove_node(topic)
return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}" return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
)
else: else:
# 如果没有记忆项,删除该节点 # 如果没有记忆项,删除该节点
self.G.remove_node(topic) self.G.remove_node(topic)
@ -263,7 +261,9 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图 # 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db() self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify") self.model_small = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
)
def get_all_node_names(self) -> list: def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@ -333,8 +333,6 @@ class Hippocampus:
f"如果确定找不出主题或者没有明显主题,返回<none>。" f"如果确定找不出主题或者没有明显主题,返回<none>。"
) )
return prompt return prompt
@staticmethod @staticmethod
@ -419,15 +417,12 @@ class Hippocampus:
text_length = len(text) text_length = len(text)
topic_num: int | list[int] = 0 topic_num: int | list[int] = 0
words = jieba.cut(text) words = jieba.cut(text)
keywords_lite = [word for word in words if len(word) > 1] keywords_lite = [word for word in words if len(word) > 1]
keywords_lite = list(set(keywords_lite)) keywords_lite = list(set(keywords_lite))
if keywords_lite: if keywords_lite:
logger.debug(f"提取关键词极简版: {keywords_lite}") logger.debug(f"提取关键词极简版: {keywords_lite}")
if text_length <= 12: if text_length <= 12:
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本) topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
elif text_length <= 20: elif text_length <= 20:
@ -455,7 +450,7 @@ class Hippocampus:
if keywords: if keywords:
logger.debug(f"提取关键词: {keywords}") logger.debug(f"提取关键词: {keywords}")
return keywords,keywords_lite return keywords, keywords_lite
async def get_memory_from_topic( async def get_memory_from_topic(
self, self,
@ -570,15 +565,12 @@ class Hippocampus:
for node, activation in remember_map.items(): for node, activation in remember_map.items():
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "") if memory_items := node_data.get("memory_items", ""):
# 直接使用完整的记忆内容
if memory_items:
logger.debug("节点包含完整记忆") logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度 # 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items)) memory_words = set(jieba.cut(memory_items))
text_words = set(keywords) text_words = set(keywords)
all_words = memory_words | text_words if all_words := memory_words | text_words:
if all_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性) # 计算相似度(虽然这里没有使用,但保持逻辑一致性)
v1 = [1 if word in memory_words else 0 for word in all_words] v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words] v2 = [1 if word in text_words else 0 for word in all_words]
@ -613,7 +605,9 @@ class Hippocampus:
return result return result
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]: async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str], list[str]]:
"""从文本中提取关键词并获取相关记忆。 """从文本中提取关键词并获取相关记忆。
Args: Args:
@ -627,13 +621,13 @@ class Hippocampus:
float: 激活节点数与总节点数的比值 float: 激活节点数与总节点数的比值
list[str]: 有效的关键词 list[str]: 有效的关键词
""" """
keywords,keywords_lite = await self.get_keywords_from_text(text) keywords, keywords_lite = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词 # 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords: if not valid_keywords:
# logger.info("没有找到有效的关键词节点") # logger.info("没有找到有效的关键词节点")
return 0, keywords,keywords_lite return 0, keywords, keywords_lite
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@ -700,7 +694,7 @@ class Hippocampus:
activation_ratio = activation_ratio * 50 activation_ratio = activation_ratio * 50
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
return activation_ratio, keywords,keywords_lite return activation_ratio, keywords, keywords_lite
# 负责海马体与其他部分的交互 # 负责海马体与其他部分的交互
@ -865,7 +859,9 @@ class EntorhinalCortex:
end_time = time.time() end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}") logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边") logger.info(
f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
)
async def resync_memory_to_db(self): async def resync_memory_to_db(self):
"""清空数据库并重新同步所有记忆数据""" """清空数据库并重新同步所有记忆数据"""
@ -999,11 +995,15 @@ class EntorhinalCortex:
last_modified = node.last_modified or current_time last_modified = node.last_modified or current_time
# 获取权重属性 # 获取权重属性
weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0 weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 添加节点到图中 # 添加节点到图中
self.memory_graph.G.add_node( self.memory_graph.G.add_node(
concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified concept,
memory_items=memory_items,
weight=weight,
created_time=created_time,
last_modified=last_modified,
) )
loaded_nodes += 1 loaded_nodes += 1
except Exception as e: except Exception as e:
@ -1046,7 +1046,9 @@ class EntorhinalCortex:
logger.info("[数据库] 已为缺失的时间字段进行补充") logger.info("[数据库] 已为缺失的时间字段进行补充")
# 输出加载统计信息 # 输出加载统计信息
logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}") logger.info(
f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}"
)
# 负责整合,遗忘,合并记忆 # 负责整合,遗忘,合并记忆
@ -1055,9 +1057,11 @@ class ParahippocampalGyrus:
self.hippocampus = hippocampus self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph self.memory_graph = hippocampus.memory_graph
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify") self.memory_modify_model = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="memory.modify"
)
async def memory_compress(self, messages: list, compress_rate=0.1): async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。 """压缩和总结消息内容,生成记忆主题和摘要。
Args: Args:
@ -1314,8 +1318,6 @@ class ParahippocampalGyrus:
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}") logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
class HippocampusManager: class HippocampusManager:
def __init__(self): def __init__(self):
self._hippocampus: Hippocampus = None # type: ignore self._hippocampus: Hippocampus = None # type: ignore
@ -1374,7 +1376,10 @@ class HippocampusManager:
logger.info(f"{chat_id} 构建记忆,消息数量: {len(messages)}") logger.info(f"{chat_id} 构建记忆,消息数量: {len(messages)}")
# 调用记忆压缩和构建 # 调用记忆压缩和构建
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress( (
compressed_memory,
similar_topics_dict,
) = await self._hippocampus.parahippocampal_gyrus.memory_compress(
messages, global_config.memory.memory_compress_rate messages, global_config.memory.memory_compress_rate
) )
@ -1390,10 +1395,11 @@ class HippocampusManager:
if topic != similar_topic: if topic != similar_topic:
strength = int(similarity * 10) strength = int(similarity * 10)
self._hippocampus.memory_graph.G.add_edge( self._hippocampus.memory_graph.G.add_edge(
topic, similar_topic, topic,
similar_topic,
strength=strength, strength=strength,
created_time=current_time, created_time=current_time,
last_modified=current_time last_modified=current_time,
) )
# 同步到数据库 # 同步到数据库
@ -1407,7 +1413,6 @@ class HippocampusManager:
return False return False
async def get_memory_from_topic( async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list: ) -> list:
@ -1423,16 +1428,20 @@ class HippocampusManager:
response = [] response = []
return response return response
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]: async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str]]:
"""从文本中获取激活值的公共接口""" """从文本中获取激活值的公共接口"""
if not self._initialized: if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try: try:
response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
text, max_depth, fast_retrieval
)
except Exception as e: except Exception as e:
logger.error(f"文本产生激活值失败: {e}") logger.error(f"文本产生激活值失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return 0.0, [],[] return 0.0, [], []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆的公共接口""" """从关键词获取相关记忆的公共接口"""
@ -1469,7 +1478,7 @@ class MemoryBuilder:
# 检查时间间隔 # 检查时间间隔
time_diff = current_time - self.last_update_time time_diff = current_time - self.last_update_time
if time_diff < 600 /global_config.memory.memory_build_frequency: if time_diff < 600 / global_config.memory.memory_build_frequency:
return False return False
# 检查消息数量 # 检查消息数量
@ -1482,30 +1491,27 @@ class MemoryBuilder:
logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}") logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
if not recent_messages or len(recent_messages) < 30/global_config.memory.memory_build_frequency : if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
return False return False
return True return True
def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]: def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]:
"""获取用于记忆构建的消息""" """获取用于记忆构建的消息"""
current_time = time.time() current_time = time.time()
messages = get_raw_msg_by_timestamp_with_chat_inclusive( messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_update_time, timestamp_start=self.last_update_time,
timestamp_end=current_time, timestamp_end=current_time,
limit=threshold, limit=threshold,
) )
tmp_msg = [msg.__dict__ for msg in messages] if messages else []
if messages: if messages:
# 更新最后处理时间 # 更新最后处理时间
self.last_processed_time = current_time self.last_processed_time = current_time
self.last_update_time = current_time self.last_update_time = current_time
return tmp_msg or [] return messages or []
class MemorySegmentManager: class MemorySegmentManager:
@ -1528,7 +1534,7 @@ class MemorySegmentManager:
builder = self.get_or_create_builder(chat_id) builder = self.get_or_create_builder(chat_id)
return builder.should_trigger_memory_build() return builder.should_trigger_memory_build()
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]: def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]:
"""获取指定chat_id用于记忆构建的消息""" """获取指定chat_id用于记忆构建的消息"""
if chat_id not in self.builders: if chat_id not in self.builders:
return [] return []
@ -1537,4 +1543,3 @@ class MemorySegmentManager:
# 创建全局实例 # 创建全局实例
memory_segment_manager = MemorySegmentManager() memory_segment_manager = MemorySegmentManager()

View File

@ -1,17 +1,17 @@
import json import json
import random
from json_repair import repair_json from json_repair import repair_json
from typing import List, Tuple from typing import List, Tuple
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.utils import parse_keywords_string from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
import random from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator") logger = get_logger("memory_activator")
@ -75,8 +75,9 @@ class MemoryActivator:
request_type="memory.selection", request_type="memory.selection",
) )
async def activate_memory_with_chat_history(
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]: self, target_message, chat_history: List[DatabaseMessages]
) -> List[Tuple[str, str]]:
""" """
激活记忆 激活记忆
""" """
@ -86,8 +87,8 @@ class MemoryActivator:
keywords_list = set() keywords_list = set()
for msg in chat_history_prompt: for msg in chat_history:
keywords = parse_keywords_string(msg.get("key_words", "")) keywords = parse_keywords_string(msg.key_words)
if keywords: if keywords:
if len(keywords_list) < 30: if len(keywords_list) < 30:
# 最多容纳30个关键词 # 最多容纳30个关键词
@ -112,20 +113,13 @@ class MemoryActivator:
logger.debug("海马体没有返回相关记忆") logger.debug("海马体没有返回相关记忆")
return [] return []
used_ids = set() used_ids = set()
candidate_memories = [] candidate_memories = []
# 为每个记忆分配随机ID并过滤相关记忆 # 为每个记忆分配随机ID并过滤相关记忆
for memory in related_memory: for memory in related_memory:
keyword, content = memory keyword, content = memory
found = False found = any(kw in content for kw in keywords_list)
for kw in keywords_list:
if kw in content:
found = True
break
if found: if found:
# 随机分配一个不重复的2位数id # 随机分配一个不重复的2位数id
while True: while True:
@ -145,12 +139,11 @@ class MemoryActivator:
# 转换为 (keyword, content) 格式 # 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories] return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
# 使用 LLM 选择合适的记忆 return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
return selected_memories async def _select_memories_with_llm(
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]: ) -> List[Tuple[str, str]]:
""" """
使用 LLM 选择合适的记忆 使用 LLM 选择合适的记忆
@ -165,14 +158,13 @@ class MemoryActivator:
try: try:
# 构建聊天历史字符串 # 构建聊天历史字符串
obs_info_text = build_readable_messages( obs_info_text = build_readable_messages(
chat_history_prompt, chat_history,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="relative", timestamp_mode="relative",
read_mark=0.0, read_mark=0.0,
show_actions=True, show_actions=True,
) )
# 构建记忆信息字符串 # 构建记忆信息字符串
memory_lines = [] memory_lines = []
for memory in candidate_memories: for memory in candidate_memories:
@ -193,18 +185,12 @@ class MemoryActivator:
# 获取并格式化 prompt # 获取并格式化 prompt
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt") prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
formatted_prompt = prompt_template.format( formatted_prompt = prompt_template.format(
obs_info_text=obs_info_text, obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
target_message=target_message,
memory_info=memory_info
) )
# 调用 LLM # 调用 LLM
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async( response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
formatted_prompt, formatted_prompt, temperature=0.3, max_tokens=150
temperature=0.3,
max_tokens=150
) )
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
@ -221,11 +207,8 @@ class MemoryActivator:
# 解析为 Python 对象 # 解析为 Python 对象
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
# 提取 memory_ids 字段 # 提取 memory_ids 字段并解析逗号分隔的编号
memory_ids_str = result.get("memory_ids", "") if memory_ids_str := result.get("memory_ids", ""):
# 解析逗号分隔的编号
if memory_ids_str:
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()] memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
# 过滤掉空字符串和无效编号 # 过滤掉空字符串和无效编号
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3] valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
@ -240,10 +223,9 @@ class MemoryActivator:
selected_memories = [] selected_memories = []
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories} memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
for memory_id in selected_memory_ids: selected_memories = [
if memory_id in memory_id_to_memory: memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
selected_memories.append(memory_id_to_memory[memory_id]) ]
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}") logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
logger.info(f"最终选择的记忆数量: {len(selected_memories)}") logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
@ -256,5 +238,4 @@ class MemoryActivator:
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]] return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
init_prompt() init_prompt()

View File

@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@ -352,7 +353,7 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str: async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
"""构建记忆块 """构建记忆块
Args: Args:
@ -369,7 +370,7 @@ class DefaultReplyer:
instant_memory = None instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history( running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history target_message=target, chat_history=chat_history
) )
if global_config.memory.enable_instant_memory: if global_config.memory.enable_instant_memory:
@ -433,7 +434,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}") logger.error(f"工具信息获取失败: {e}")
return "" return ""
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]:
"""解析回复目标消息 """解析回复目标消息
Args: Args:
@ -514,7 +515,7 @@ class DefaultReplyer:
return name, result, duration return name, result, duration
def build_s4u_chat_history_prompts( def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
""" """
构建 s4u 风格的分离对话 prompt 构建 s4u 风格的分离对话 prompt
@ -530,16 +531,16 @@ class DefaultReplyer:
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话 # 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now: for msg in message_list_before_now:
try: try:
msg_user_id = str(msg_dict.get("user_id")) msg_user_id = str(msg.user_info.user_id)
reply_to = msg_dict.get("reply_to", "") reply_to = msg.reply_to
_platform, reply_to_user_id = self._parse_reply_target(reply_to) _platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话 # bot 和目标用户的对话
core_dialogue_list.append(msg_dict) core_dialogue_list.append(msg)
except Exception as e: except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
# 构建背景对话 prompt # 构建背景对话 prompt
all_dialogue_prompt = "" all_dialogue_prompt = ""
@ -574,7 +575,6 @@ class DefaultReplyer:
core_dialogue_prompt_str = build_readable_messages( core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list, core_dialogue_list,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,
truncate=True, truncate=True,
@ -712,25 +712,20 @@ class DefaultReplyer:
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
limit=global_config.chat.max_context_size * 1, limit=global_config.chat.max_context_size * 1,
) )
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
# TODO: 修复!
message_list_before_short = get_raw_msg_before_timestamp_with_chat( message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33), limit=int(global_config.chat.max_context_size * 0.33),
) )
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
chat_talking_prompt_short = build_readable_messages( chat_talking_prompt_short = build_readable_messages(
temp_msg_list_before_short, message_list_before_short,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="relative", timestamp_mode="relative",
read_mark=0.0, read_mark=0.0,
@ -743,7 +738,7 @@ class DefaultReplyer:
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
), ),
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"), self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"), self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
), ),
@ -827,7 +822,7 @@ class DefaultReplyer:
# 构建分离的对话 prompt # 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts( core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
temp_msg_list_before_long, user_id, sender message_list_before_now_long, user_id, sender
) )
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform: if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
@ -901,11 +896,8 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15), limit=min(int(global_config.chat.max_context_size * 0.33), 15),
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
chat_talking_prompt_half = build_readable_messages( chat_talking_prompt_half = build_readable_messages(
temp_msg_list_before_now_half, message_list_before_now_half,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="relative", timestamp_mode="relative",
read_mark=0.0, read_mark=0.0,
@ -913,7 +905,7 @@ class DefaultReplyer:
) )
# 并行执行2个构建任务 # 并行执行2个构建任务
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather( (expression_habits_block, _), relation_info = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target), self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(sender, target), self.build_relation_info(sender, target),
) )

View File

@ -1,5 +1,5 @@
import copy import copy
from typing import Dict, Any from typing import Any
class BaseDataModel: class BaseDataModel:
@ -7,6 +7,7 @@ class BaseDataModel:
return copy.deepcopy(self) return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any: def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
""" """
将对象或容器中的 BaseDataModel 子类类对象 BaseDataModel 实例 将对象或容器中的 BaseDataModel 子类类对象 BaseDataModel 实例
递归转换为普通 dict不修改原对象 递归转换为普通 dict不修改原对象

View File

@ -163,11 +163,8 @@ class ChatAction:
limit=15, limit=15,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,
@ -229,11 +226,8 @@ class ChatAction:
limit=10, limit=10,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,

View File

@ -166,11 +166,9 @@ class ChatMood:
limit=10, limit=10,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,
@ -247,11 +245,9 @@ class ChatMood:
limit=5, limit=5,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,

View File

@ -17,6 +17,10 @@ from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.chat.express.expression_selector import expression_selector from src.chat.express.expression_selector import expression_selector
from .s4u_mood_manager import mood_manager from .s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.data_models.database_data_model import DatabaseMessages
from typing import List
logger = get_logger("prompt") logger = get_logger("prompt")
@ -97,12 +101,11 @@ class PromptBuilder:
self.activate_messages = "" self.activate_messages = ""
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target): async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
style_habits = [] style_habits = []
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个 # LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm( selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
chat_stream.stream_id, chat_history, max_num=12, target_message=target chat_stream.stream_id, chat_history, max_num=12, target_message=target
) )
@ -122,7 +125,6 @@ class PromptBuilder:
if style_habits_str.strip(): if style_habits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n" expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
return expression_habits_block return expression_habits_block
async def build_relation_info(self, chat_stream) -> str: async def build_relation_info(self, chat_stream) -> str:
@ -148,9 +150,7 @@ class PromptBuilder:
person_ids.append(person_id) person_ids.append(person_id)
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为 # 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = [ relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
Person(person_id=person_id).build_relationship() for person_id in person_ids
]
if relation_info := "".join(relation_info_list): if relation_info := "".join(relation_info_list):
relation_prompt = await global_prompt_manager.format_prompt( relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info "relation_prompt", relation_info=relation_info
@ -176,38 +176,37 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300, limit=300,
) )
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list = [] core_dialogue_list: List[DatabaseMessages] = []
background_dialogue_list = [] background_dialogue_list: List[DatabaseMessages] = []
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id) target_user_id = str(message.chat_stream.user_info.user_id)
# TODO: 修复之!
for msg in message_list_before_now: for msg in message_list_before_now:
try: try:
msg_user_id = str(msg.user_info.user_id) msg_user_id = str(msg.user_info.user_id)
if msg_user_id == bot_id: if msg_user_id == bot_id:
if msg.reply_to and talk_type == msg.reply_to: if msg.reply_to and talk_type == msg.reply_to:
core_dialogue_list.append(msg.__dict__) core_dialogue_list.append(msg)
elif msg.reply_to and talk_type != msg.reply_to: elif msg.reply_to and talk_type != msg.reply_to:
background_dialogue_list.append(msg.__dict__) background_dialogue_list.append(msg)
# else: # else:
# background_dialogue_list.append(msg_dict) # background_dialogue_list.append(msg_dict)
elif msg_user_id == target_user_id: elif msg_user_id == target_user_id:
core_dialogue_list.append(msg.__dict__) core_dialogue_list.append(msg)
else: else:
background_dialogue_list.append(msg.__dict__) background_dialogue_list.append(msg)
except Exception as e: except Exception as e:
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}") logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
background_dialogue_prompt = "" background_dialogue_prompt = ""
if background_dialogue_list: if background_dialogue_list:
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:] context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
background_dialogue_prompt_str = build_readable_messages( background_dialogue_prompt_str = build_readable_messages(
context_msgs, context_msgs,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@ -217,10 +216,10 @@ class PromptBuilder:
core_msg_str = "" core_msg_str = ""
if core_dialogue_list: if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:] core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
first_msg = core_dialogue_list[0] first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get("user_id") start_speaking_user_id = first_msg.user_info.user_id
if start_speaking_user_id == bot_id: if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n" msg_seg_str = "你的发言:\n"
@ -229,13 +228,13 @@ class PromptBuilder:
last_speaking_user_id = start_speaking_user_id last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
all_msg_seg_list = [] all_msg_seg_list = []
for msg in core_dialogue_list[1:]: for msg in core_dialogue_list[1:]:
speaker = msg.get("user_id") speaker = msg.user_info.user_id
if speaker == last_speaking_user_id: if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
else: else:
msg_seg_str = f"{msg_seg_str}\n" msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)
@ -252,23 +251,19 @@ class PromptBuilder:
for msg in all_msg_seg_list: for msg in all_msg_seg_list:
core_msg_str += msg core_msg_str += msg
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
limit=20, limit=20,
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt]
all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt_str = build_readable_messages(
tmp_msgs, all_dialogue_history,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
show_pic=False, show_pic=False,
) )
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str
def build_gift_info(self, message: MessageRecvS4U): def build_gift_info(self, message: MessageRecvS4U):
if message.is_gift: if message.is_gift:
@ -283,13 +278,11 @@ class PromptBuilder:
super_chat_manager = get_super_chat_manager() super_chat_manager = get_super_chat_manager()
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id) return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
async def build_prompt_normal( async def build_prompt_normal(
self, self,
message: MessageRecvS4U, message: MessageRecvS4U,
message_txt: str, message_txt: str,
) -> str: ) -> str:
chat_stream = message.chat_stream chat_stream = message.chat_stream
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id) person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
@ -303,12 +296,15 @@ class PromptBuilder:
else: else:
sender_name = f"用户({message.chat_stream.user_info.user_id})" sender_name = f"用户({message.chat_stream.user_info.user_id})"
relation_info_block, memory_block, expression_habits_block = await asyncio.gather( relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
self.build_relation_info(chat_stream), self.build_memory_block(message_txt), self.build_expression_habits(chat_stream, message_txt, sender_name) self.build_relation_info(chat_stream),
self.build_memory_block(message_txt),
self.build_expression_habits(chat_stream, message_txt, sender_name),
) )
core_dialogue_prompt, background_dialogue_prompt,all_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message) core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
chat_stream, message
)
gift_info = self.build_gift_info(message) gift_info = self.build_gift_info(message)

View File

@ -99,11 +99,9 @@ class ChatMood:
limit=int(global_config.chat.max_context_size / 3), limit=int(global_config.chat.max_context_size / 3),
limit_mode="last", limit_mode="last",
) )
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,
@ -150,11 +148,9 @@ class ChatMood:
limit=15, limit=15,
limit_mode="last", limit_mode="last",
) )
# TODO: 修复
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
tmp_msgs, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
read_mark=0.0, read_mark=0.0,

View File

@ -1,18 +1,21 @@
import json
import traceback
from json_repair import repair_json
from datetime import datetime
from typing import List
from src.common.logger import get_logger from src.common.logger import get_logger
from .person_info import Person from src.common.data_models.database_data_model import DatabaseMessages
import random
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
import json
from json_repair import repair_json
from datetime import datetime
from typing import List, Dict, Any
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import traceback from .person_info import Person
logger = get_logger("relation") logger = get_logger("relation")
def init_prompt(): def init_prompt():
Prompt( Prompt(
""" """
@ -46,7 +49,6 @@ def init_prompt():
"attitude_to_me_prompt", "attitude_to_me_prompt",
) )
Prompt( Prompt(
""" """
你的名字是{bot_name}{bot_name}的别名是{alias_str} 你的名字是{bot_name}{bot_name}的别名是{alias_str}
@ -80,6 +82,7 @@ def init_prompt():
"neuroticism_prompt", "neuroticism_prompt",
) )
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationship_llm = LLMRequest( self.relationship_llm = LLMRequest(
@ -95,18 +98,16 @@ class RelationshipManager:
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"attitude_to_me_prompt", "attitude_to_me_prompt",
bot_name = global_config.bot.nickname, bot_name=global_config.bot.nickname,
alias_str = alias_str, alias_str=alias_str,
person_name = person.person_name, person_name=person.person_name,
nickname = person.nickname, nickname=person.nickname,
readable_messages = readable_messages, readable_messages=readable_messages,
current_time = current_time, current_time=current_time,
) )
attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt) attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
attitude = repair_json(attitude) attitude = repair_json(attitude)
attitude_data = json.loads(attitude) attitude_data = json.loads(attitude)
@ -119,10 +120,10 @@ class RelationshipManager:
return "" return ""
attitude_score = attitude_data["attitude"] attitude_score = attitude_data["attitude"]
confidence = pow(attitude_data["confidence"],2) confidence = pow(attitude_data["confidence"], 2)
new_confidence = total_confidence + confidence new_confidence = total_confidence + confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence
person.attitude_to_me = new_attitude_score person.attitude_to_me = new_attitude_score
person.attitude_to_me_confidence = new_confidence person.attitude_to_me_confidence = new_confidence
@ -138,21 +139,19 @@ class RelationshipManager:
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"neuroticism_prompt", "neuroticism_prompt",
bot_name = global_config.bot.nickname, bot_name=global_config.bot.nickname,
alias_str = alias_str, alias_str=alias_str,
person_name = person.person_name, person_name=person.person_name,
nickname = person.nickname, nickname=person.nickname,
readable_messages = readable_messages, readable_messages=readable_messages,
current_time = current_time, current_time=current_time,
) )
neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt) neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
# logger.info(f"prompt: {prompt}") # logger.info(f"prompt: {prompt}")
# logger.info(f"neuroticism: {neuroticism}") # logger.info(f"neuroticism: {neuroticism}")
neuroticism = repair_json(neuroticism) neuroticism = repair_json(neuroticism)
neuroticism_data = json.loads(neuroticism) neuroticism_data = json.loads(neuroticism)
@ -165,19 +164,20 @@ class RelationshipManager:
return "" return ""
neuroticism_score = neuroticism_data["neuroticism"] neuroticism_score = neuroticism_data["neuroticism"]
confidence = pow(neuroticism_data["confidence"],2) confidence = pow(neuroticism_data["confidence"], 2)
new_confidence = total_confidence + confidence new_confidence = total_confidence + confidence
new_neuroticism_score = (current_neuroticism_score * total_confidence + neuroticism_score * confidence)/new_confidence new_neuroticism_score = (
current_neuroticism_score * total_confidence + neuroticism_score * confidence
) / new_confidence
person.neuroticism = new_neuroticism_score person.neuroticism = new_neuroticism_score
person.neuroticism_confidence = new_confidence person.neuroticism_confidence = new_confidence
return person return person
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
"""更新用户印象 """更新用户印象
Args: Args:
@ -202,12 +202,11 @@ class RelationshipManager:
# 遍历消息,构建映射 # 遍历消息,构建映射
for msg in user_messages: for msg in user_messages:
if msg.get("user_id") == "system": if msg.user_info.user_id == "system":
continue continue
try: try:
user_id = msg.user_info.user_id
user_id = msg.get("user_id") platform = msg.chat_info.platform
platform = msg.get("chat_info_platform")
assert isinstance(user_id, str) and isinstance(platform, str) assert isinstance(user_id, str) and isinstance(platform, str)
msg_person = Person(user_id=user_id, platform=platform) msg_person = Person(user_id=user_id, platform=platform)
@ -244,7 +243,7 @@ class RelationshipManager:
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}") readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
# await self.get_points( # await self.get_points(
# readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person) # readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person) await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person)
await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person) await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person)
@ -253,9 +252,6 @@ class RelationshipManager:
person.sync_to_database() person.sync_to_database()
def calculate_time_weight(self, point_time: str, current_time: str) -> float: def calculate_time_weight(self, point_time: str, current_time: str) -> float:
"""计算基于时间的权重系数""" """计算基于时间的权重系数"""
try: try:
@ -280,6 +276,7 @@ class RelationshipManager:
logger.error(f"计算时间权重失败: {e}") logger.error(f"计算时间权重失败: {e}")
return 0.5 # 发生错误时返回中等权重 return 0.5 # 发生错误时返回中等权重
init_prompt() init_prompt()
relationship_manager = None relationship_manager = None
@ -290,4 +287,3 @@ def get_relationship_manager():
if relationship_manager is None: if relationship_manager is None:
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()
return relationship_manager return relationship_manager

View File

@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
def build_readable_messages_to_str( def build_readable_messages_to_str(
messages: List[Dict[str, Any]], messages: List[DatabaseMessages],
replace_bot_name: bool = True, replace_bot_name: bool = True,
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
read_mark: float = 0.0, read_mark: float = 0.0,
@ -440,7 +440,7 @@ def build_readable_messages_to_str(
async def build_readable_messages_with_details( async def build_readable_messages_with_details(
messages: List[Dict[str, Any]], messages: List[DatabaseMessages],
replace_bot_name: bool = True, replace_bot_name: bool = True,
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
truncate: bool = False, truncate: bool = False,

View File

@ -2,13 +2,14 @@ import random
from typing import Tuple from typing import Tuple
# 导入新插件系统 # 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode from src.plugin_system import BaseAction, ActionActivationType
# 导入依赖的系统组件 # 导入依赖的系统组件
from src.common.logger import get_logger from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式 # 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入 # NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config from src.config.config import global_config
@ -84,11 +85,8 @@ class EmojiAction(BaseAction):
messages_text = "" messages_text = ""
if recent_messages: if recent_messages:
# 使用message_api构建可读的消息字符串 # 使用message_api构建可读的消息字符串
# TODO: 修复
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages]
messages_text = message_api.build_readable_messages( messages_text = message_api.build_readable_messages(
messages=tmp_msgs, messages=recent_messages,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
truncate=False, truncate=False,
show_actions=False, show_actions=False,