feat: memory system re-optimized; memories are now built promptly and no longer built redundantly

pull/1182/head
SengokuCola 2025-08-14 13:13:13 +08:00
parent 3bf476c610
commit bf7419c693
22 changed files with 210 additions and 1314 deletions

View File

@ -18,7 +18,6 @@ from src.chat.chat_loop.hfc_utils import CycleDetail
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.person_info.group_relationship_manager import get_group_relationship_manager
from src.plugin_system.base.component_types import ChatMode, EventType
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
@ -27,6 +26,8 @@ import math
from src.mais4u.s4u_config import s4u_config
# no_reply逻辑已集成到heartFC_chat.py中不再需要导入
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
# 导入记忆系统
from src.chat.memory_system.Hippocampus import hippocampus_manager
ERROR_LOOP_INFO = {
"loop_plan_info": {
@ -90,7 +91,6 @@ class HeartFChatting:
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.group_relationship_manager = get_group_relationship_manager()
self.action_manager = ActionManager()
@ -386,20 +386,19 @@ class HeartFChatting:
await self.relationship_builder.build_relation()
await self.expression_learner.trigger_learning_for_chat()
# 群印象构建:仅在群聊中触发
# if self.chat_stream.group_info and getattr(self.chat_stream.group_info, "group_id", None):
# await self.group_relationship_manager.build_relation(
# chat_id=self.stream_id,
# platform=self.chat_stream.platform
# )
# 记忆构建为当前chat_id构建记忆
try:
await hippocampus_manager.build_memory_for_chat(self.stream_id)
except Exception as e:
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
if random.random() > global_config.chat.focus_value and mode == ChatMode.FOCUS:
# 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
actions = [
{
"action_type": "no_reply",
"reasoning": "选择不回复",
"reasoning": "专注不足",
"action_data": {},
}
]
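For reference, the focus_value gate above is a plain probability check: a random draw above the configured focus value while in FOCUS mode forces the no_reply action instead of running the planner. A minimal illustrative sketch (the 0.7 value is an assumption, not taken from this commit):

import random

def skips_planning(focus_value: float, in_focus_mode: bool) -> bool:
    # Mirrors the gate in HeartFChatting: a draw above focus_value in FOCUS
    # mode short-circuits planning and yields the forced no_reply action.
    return random.random() > focus_value and in_focus_mode

# With focus_value = 0.7, roughly 30% of FOCUS-mode cycles fall back to no_reply.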

View File

@ -254,7 +254,7 @@ class ExpressionSelector:
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
# logger.info(f"模型名称: {model_name}")
logger.info(f"LLM返回结果: {content}")
# logger.info(f"LLM返回结果: {content}")
# if reasoning_content:
# logger.info(f"LLM推理: {reasoning_content}")
# else:

View File

@ -3,7 +3,6 @@ from typing import Any, Optional, Dict
from src.common.logger import get_logger
from src.chat.heart_flow.sub_heartflow import SubHeartflow
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("heartflow")
@ -27,8 +26,6 @@ class Heartflow:
# 注册子心流
self.subheartflows[subheartflow_id] = new_subflow
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
logger.info(f"[{heartflow_name}] 开始接收消息")
return new_subflow
except Exception as e:

View File

@ -35,13 +35,13 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[s
interested_rate = 0.0
with Timer("记忆激活"):
interested_rate, keywords = await hippocampus_manager.get_activate_from_text(
interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
max_depth= 4,
fast_retrieval=False,
)
message.key_words = keywords
message.key_words_lite = keywords
message.key_words_lite = keywords_lite
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
text_len = len(message.processed_plain_text)

View File

@ -7,24 +7,21 @@ import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any
from typing import List, Tuple, Set, Coroutine, Any, Dict
from collections import Counter
from itertools import combinations
import traceback
from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages
from src.chat.utils.utils import translate_timestamp_to_human_readable
# 添加cosine_similarity函数
def cosine_similarity(v1, v2):
"""计算余弦相似度"""
@ -334,6 +331,9 @@ class Hippocampus:
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
f"如果确定找不出主题或者没有明显主题,返回<none>。"
)
return prompt
@staticmethod
@ -417,14 +417,17 @@ class Hippocampus:
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text)
topic_num: int | list[int] = 0
if text_length <= 6:
words = jieba.cut(text)
keywords = [word for word in words if len(word) > 1]
keywords = list(set(keywords))[:3] # 限制最多3个关键词
if keywords:
logger.debug(f"提取关键词: {keywords}")
return keywords
elif text_length <= 12:
words = jieba.cut(text)
keywords_lite = [word for word in words if len(word) > 1]
keywords_lite = list(set(keywords_lite))
if keywords_lite:
logger.debug(f"提取关键词极简版: {keywords_lite}")
if text_length <= 12:
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
elif text_length <= 20:
topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本)
@ -451,169 +454,7 @@ class Hippocampus:
if keywords:
logger.debug(f"提取关键词: {keywords}")
return keywords
async def get_memory_from_text(
self,
text: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3,
fast_retrieval: bool = False,
) -> list:
"""从文本中提取关键词并获取相关记忆。
Args:
text (str): 输入文本
max_memory_num (int, optional): 返回的记忆条目数量上限默认为3表示最多返回3条与输入文本相关度最高的记忆
max_memory_length (int, optional): 每个主题最多返回的记忆条目数量默认为2表示每个主题最多返回2条相似度最高的记忆
max_depth (int, optional): 记忆检索深度默认为3值越大检索范围越广可以获取更多间接相关的记忆但速度会变慢
fast_retrieval (bool, optional): 是否使用快速检索默认为False
如果为True使用jieba分词和TF-IDF提取关键词速度更快但可能不够准确
如果为False使用LLM提取关键词速度较慢但更准确
Returns:
list: 记忆列表每个元素是一个元组 (topic, memory_content)
- topic: str, 记忆主题
- memory_content: str, 该主题下的完整记忆内容
"""
keywords = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords:
logger.debug("没有找到有效的关键词节点")
return []
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
# 从每个关键词获取记忆
activate_map = {} # 存储每个词的累计激活值
# 对每个关键词进行扩散式检索
for keyword in valid_keywords:
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
# 初始化激活值
activation_values = {keyword: 1.0}
# 记录已访问的节点
visited_nodes = {keyword}
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
nodes_to_process = [(keyword, 1.0, 0)]
while nodes_to_process:
current_node, current_activation, current_depth = nodes_to_process.pop(0)
# 如果激活值小于0或超过最大深度停止扩散
if current_activation <= 0 or current_depth >= max_depth:
continue
# 获取当前节点的所有邻居
neighbors = list(self.memory_graph.G.neighbors(current_node))
for neighbor in neighbors:
if neighbor in visited_nodes:
continue
# 获取连接强度
edge_data = self.memory_graph.G[current_node][neighbor]
strength = edge_data.get("strength", 1)
# 计算新的激活值
new_activation = current_activation - (1 / strength)
if new_activation > 0:
activation_values[neighbor] = new_activation
visited_nodes.add(neighbor)
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
# logger.debug(
# f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
# ) # noqa: E501
# 更新激活映射
for node, activation_value in activation_values.items():
if activation_value > 0:
if node in activate_map:
activate_map[node] += activation_value
else:
activate_map[node] = activation_value
# 输出激活映射
# logger.info("激活映射统计:")
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
# 基于激活值平方的独立概率选择
remember_map = {}
# logger.info("基于激活值平方的归一化选择:")
# 计算所有激活值的平方和
total_squared_activation = sum(activation**2 for activation in activate_map.values())
if total_squared_activation > 0:
# 计算归一化的激活值
normalized_activations = {
node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
}
# 按归一化激活值排序并选择前max_memory_num个
sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]
# 将选中的节点添加到remember_map
for node, normalized_activation in sorted_nodes:
remember_map[node] = activate_map[node] # 使用原始激活值
logger.debug(
f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
)
else:
logger.info("没有有效的激活值")
# 从选中的节点中提取记忆
all_memories = []
# logger.info("开始从选中的节点中提取记忆:")
for node, activation in remember_map.items():
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
logger.debug("节点包含完整记忆")
# 计算记忆与输入文本的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(jieba.cut(text))
all_words = memory_words | text_words
if all_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2) # 计算但不使用用_表示
# 添加完整记忆到结果中
all_memories.append((node, memory_items, activation))
else:
logger.info("节点没有记忆")
# 去重(基于记忆内容)
logger.debug("开始记忆去重:")
seen_memories = set()
unique_memories = []
for topic, memory_items, activation_value in all_memories:
# memory_items现在是完整的字符串格式
memory = memory_items if memory_items else ""
if memory not in seen_memories:
seen_memories.add(memory)
unique_memories.append((topic, memory_items, activation_value))
logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})")
else:
logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})")
# 转换为(关键词, 记忆)格式
result = []
for topic, memory_items, _ in unique_memories:
# memory_items现在是完整的字符串格式
memory = memory_items if memory_items else ""
result.append((topic, memory))
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
return result
return keywords, keywords_lite
async def get_memory_from_topic(
self,
@ -771,7 +612,7 @@ class Hippocampus:
return result
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str], list[str]]:
"""从文本中提取关键词并获取相关记忆。
Args:
@ -785,13 +626,13 @@ class Hippocampus:
float: 激活节点数与总节点数的比值
list[str]: 有效的关键词
"""
keywords = await self.get_keywords_from_text(text)
keywords, keywords_lite = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords:
# logger.info("没有找到有效的关键词节点")
return 0, []
return 0, keywords, keywords_lite
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@ -858,7 +699,7 @@ class Hippocampus:
activation_ratio = activation_ratio * 50
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
return activation_ratio, keywords
return activation_ratio, keywords, keywords_lite
# 负责海马体与其他部分的交互
@ -867,92 +708,6 @@ class EntorhinalCortex:
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
def get_memory_sample(self):
"""从数据库获取记忆样本"""
# 硬编码:每条消息最大记忆次数
max_memorized_time_per_msg = 2
# 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler(
n_hours1=global_config.memory.memory_build_distribution[0],
std_hours1=global_config.memory.memory_build_distribution[1],
weight1=global_config.memory.memory_build_distribution[2],
n_hours2=global_config.memory.memory_build_distribution[3],
std_hours2=global_config.memory.memory_build_distribution[4],
weight2=global_config.memory.memory_build_distribution[5],
total_samples=global_config.memory.memory_build_sample_num,
)
timestamps = sample_scheduler.get_timestamp_array()
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False):
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
if messages := self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
):
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages)
else:
logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
return chat_samples
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
for _ in range(3):
# 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
if chosen_message := get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
):
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
if messages := get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=chat_size,
limit_mode="earliest",
chat_id=chat_id,
):
# 检查获取到的所有消息是否都未达到最大记忆次数
all_valid = True
for message in messages:
if message.get("memorized_times", 0) >= max_memorized_time_per_msg:
all_valid = False
break
# 如果所有消息都有效
if all_valid:
# 更新数据库中的记忆次数
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
# 使用 Peewee 更新记录
Messages.update(memorized_times=current_memorized_times + 1).where(
Messages.message_id == message["message_id"]
).execute()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
# 三次尝试都失败,返回 None
return None
async def sync_memory_to_db(self):
"""将记忆图同步到数据库"""
start_time = time.time()
@ -1407,81 +1162,14 @@ class ParahippocampalGyrus:
similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:3]
similar_topics_dict[topic] = similar_topics
if global_config.debug.show_prompt:
logger.info(f"prompt: {topic_what_prompt}")
logger.info(f"压缩后的记忆: {compressed_memory}")
logger.info(f"相似主题: {similar_topics_dict}")
return compressed_memory, similar_topics_dict
async def operation_build_memory(self):
# sourcery skip: merge-list-appends-into-extend
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
all_added_nodes = []
all_connected_nodes = []
all_added_edges = []
for i, messages in enumerate(memory_samples, 1):
all_topics = []
compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
logger.error(f"压缩记忆时发生错误: {e}")
continue
for topic, memory in compressed_memory:
logger.info(f"取得记忆: {topic} - {memory}")
for topic, similar_topics in similar_topics_dict.items():
logger.debug(f"相似话题: {topic} - {similar_topics}")
current_time = datetime.datetime.now().timestamp()
logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
all_added_nodes.extend(topic for topic, _ in compressed_memory)
for topic, memory in compressed_memory:
await self.memory_graph.add_dot(topic, memory, self.hippocampus)
all_topics.append(topic)
if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic]
for similar_topic, similarity in similar_topics:
if topic != similar_topic:
strength = int(similarity * 10)
logger.debug(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
all_added_edges.append(f"{topic}-{similar_topic}")
all_connected_nodes.append(topic)
all_connected_nodes.append(similar_topic)
self.memory_graph.G.add_edge(
topic,
similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time,
)
for topic1, topic2 in combinations(all_topics, 2):
logger.debug(f"连接同批次节点: {topic1}{topic2}")
all_added_edges.append(f"{topic1}-{topic2}")
self.memory_graph.connect_dot(topic1, topic2)
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
if all_added_nodes:
logger.info(f"更新记忆: {', '.join(all_added_nodes)}")
if all_added_edges:
logger.debug(f"强化连接: {', '.join(all_added_edges)}")
if all_connected_nodes:
logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
end_time = time.time()
logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")
async def operation_forget_topic(self, percentage=0.005):
start_time = time.time()
logger.info("[遗忘] 开始检查数据库...")
@ -1650,8 +1338,7 @@ class HippocampusManager:
logger.info(f"""
--------------------------------
记忆系统参数配置:
构建间隔: {global_config.memory.memory_build_interval}|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
记忆构建分布: {global_config.memory.memory_build_distribution}
构建频率: {global_config.memory.memory_build_frequency}|压缩率: {global_config.memory.memory_compress_rate}
遗忘间隔: {global_config.memory.forget_memory_interval}|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501
@ -1663,39 +1350,60 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus
async def build_memory(self):
"""构建记忆的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return await self._hippocampus.parahippocampal_gyrus.operation_build_memory()
async def forget_memory(self, percentage: float = 0.005):
"""遗忘记忆的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
async def get_memory_from_text(
self,
text: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3,
fast_retrieval: bool = False,
) -> list:
"""从文本中获取相关记忆的公共接口"""
async def build_memory_for_chat(self, chat_id: str):
"""为指定chat_id构建记忆在heartFC_chat.py中调用"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
response = await self._hippocampus.get_memory_from_text(
text, max_memory_num, max_memory_length, max_depth, fast_retrieval
)
# 检查是否需要构建记忆
logger.info(f"{chat_id} 构建记忆")
if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
logger.info(f"{chat_id} 构建记忆,需要构建记忆")
messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 30 / global_config.memory.memory_build_frequency)
if messages:
logger.info(f"{chat_id} 构建记忆,消息数量: {len(messages)}")
# 调用记忆压缩和构建
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
messages, global_config.memory.memory_compress_rate
)
# 添加记忆节点
current_time = time.time()
for topic, memory in compressed_memory:
await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)
# 连接相似主题
if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic]
for similar_topic, similarity in similar_topics:
if topic != similar_topic:
strength = int(similarity * 10)
self._hippocampus.memory_graph.G.add_edge(
topic, similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time
)
# 同步到数据库
await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
logger.info(f"{chat_id} 构建记忆完成")
return True
except Exception as e:
logger.error(f"文本激活记忆失败: {e}")
response = []
return response
logger.error(f"{chat_id} 构建记忆失败: {e}")
return False
return False
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
@ -1717,12 +1425,11 @@ class HippocampusManager:
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
response, keywords = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
except Exception as e:
logger.error(f"文本产生激活值失败: {e}")
response = 0.0
keywords = [] # 在异常情况下初始化 keywords 为空列表
return response, keywords
logger.error(traceback.format_exc())
return 0.0, [], []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆的公共接口"""
@ -1741,3 +1448,90 @@ class HippocampusManager:
hippocampus_manager = HippocampusManager()
# 在Hippocampus类中添加新的记忆构建管理器
class MemoryBuilder:
"""记忆构建器
为每个chat_id维护消息缓存和触发机制类似ExpressionLearner
"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.last_update_time: float = time.time()
self.last_processed_time: float = 0.0
def should_trigger_memory_build(self) -> bool:
"""检查是否应该触发记忆构建"""
current_time = time.time()
# 检查时间间隔
time_diff = current_time - self.last_update_time
if time_diff < 600 / global_config.memory.memory_build_frequency:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_update_time,
timestamp_end=current_time,
)
logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
return False
return True
def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
"""获取用于记忆构建的消息"""
current_time = time.time()
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_update_time,
timestamp_end=current_time,
limit=threshold,
)
if messages:
# 更新最后处理时间
self.last_processed_time = current_time
self.last_update_time = current_time
return messages or []
class MemorySegmentManager:
"""记忆段管理器
管理所有chat_id的MemoryBuilder实例自动检查和触发记忆构建
"""
def __init__(self):
self.builders: Dict[str, MemoryBuilder] = {}
def get_or_create_builder(self, chat_id: str) -> MemoryBuilder:
"""获取或创建指定chat_id的MemoryBuilder"""
if chat_id not in self.builders:
self.builders[chat_id] = MemoryBuilder(chat_id)
return self.builders[chat_id]
def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
"""检查指定chat_id是否需要构建记忆如果需要则返回True"""
builder = self.get_or_create_builder(chat_id)
return builder.should_trigger_memory_build()
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
"""获取指定chat_id用于记忆构建的消息"""
if chat_id not in self.builders:
return []
return self.builders[chat_id].get_messages_for_memory_build(threshold)
# 创建全局实例
memory_segment_manager = MemorySegmentManager()
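To make the memory_build_frequency scaling concrete: MemoryBuilder divides both the time threshold (600 s) and the message threshold (30 messages) by the configured value, so a higher frequency makes per-chat builds trigger sooner and on fewer messages. A small sketch restating that arithmetic (the helper name is illustrative, not part of this commit):

def trigger_thresholds(memory_build_frequency: int) -> tuple[float, float]:
    # Mirrors should_trigger_memory_build: baseline 600 seconds and 30 messages,
    # both scaled down by the configured build frequency.
    return 600 / memory_build_frequency, 30 / memory_build_frequency

# memory_build_frequency = 1 -> at most one build per 600 s, needs >= 30 new messages
# memory_build_frequency = 2 -> at most one build per 300 s, needs >= 15 new messages
print(trigger_thresholds(1), trigger_thresholds(2))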

View File

@ -105,8 +105,8 @@ class MemoryActivator:
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
)
logger.info(f"当前记忆关键词: {keywords_list}")
logger.info(f"获取到的记忆: {related_memory}")
# logger.info(f"当前记忆关键词: {keywords_list}")
logger.debug(f"获取到的记忆: {related_memory}")
if not related_memory:
logger.debug("海马体没有返回相关记忆")
@ -141,7 +141,7 @@ class MemoryActivator:
# 如果只有少量记忆,直接返回
if len(candidate_memories) <= 2:
logger.info(f"候选记忆较少({len(candidate_memories)}个),直接返回")
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]

View File

@ -1,126 +0,0 @@
import numpy as np
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class MemoryBuildScheduler:
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
"""
初始化记忆构建调度器
参数:
n_hours1 (float): 第一个分布的均值距离现在的小时数
std_hours1 (float): 第一个分布的标准差小时
weight1 (float): 第一个分布的权重
n_hours2 (float): 第二个分布的均值距离现在的小时数
std_hours2 (float): 第二个分布的标准差小时
weight2 (float): 第二个分布的权重
total_samples (int): 要生成的总时间点数量
"""
# 验证参数
if total_samples <= 0:
raise ValueError("total_samples 必须大于0")
if weight1 < 0 or weight2 < 0:
raise ValueError("权重必须为非负数")
if std_hours1 < 0 or std_hours2 < 0:
raise ValueError("标准差必须为非负数")
# 归一化权重
total_weight = weight1 + weight2
if total_weight == 0:
raise ValueError("权重总和不能为0")
self.weight1 = weight1 / total_weight
self.weight2 = weight2 / total_weight
self.n_hours1 = n_hours1
self.std_hours1 = std_hours1
self.n_hours2 = n_hours2
self.std_hours2 = std_hours2
self.total_samples = total_samples
self.base_time = datetime.now()
def generate_time_samples(self):
"""生成混合分布的时间采样点"""
# 根据权重计算每个分布的样本数
samples1 = max(1, int(self.total_samples * self.weight1))
samples2 = max(1, self.total_samples - samples1) # 确保 samples2 至少为1
# 生成两个正态分布的小时偏移
hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)
# 合并两个分布的偏移
hours_offset = np.concatenate([hours_offset1, hours_offset2])
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
# 按时间排序(从最早到最近)
return sorted(timestamps)
def get_timestamp_array(self):
"""返回时间戳数组"""
timestamps = self.generate_time_samples()
return [int(t.timestamp()) for t in timestamps]
# def print_time_samples(timestamps, show_distribution=True):
# """打印时间样本和分布信息"""
# print(f"\n生成的{len(timestamps)}个时间点分布:")
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
# print("-" * 50)
# now = datetime.now()
# time_diffs = []
# for i, timestamp in enumerate(timestamps, 1):
# hours_diff = (now - timestamp).total_seconds() / 3600
# time_diffs.append(hours_diff)
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# # 打印统计信息
# print("\n统计信息")
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
# print(f"标准差:{np.std(time_diffs):.2f}小时")
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
# if show_distribution:
# # 计算时间分布的直方图
# hist, bins = np.histogram(time_diffs, bins=40)
# print("\n时间分布每个*代表一个时间点):")
# for i in range(len(hist)):
# if hist[i] > 0:
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
# # 使用示例
# if __name__ == "__main__":
# # 创建一个双峰分布的记忆调度器
# scheduler = MemoryBuildScheduler(
# n_hours1=12, # 第一个分布均值12小时前
# std_hours1=8, # 第一个分布标准差
# weight1=0.7, # 第一个分布权重 70%
# n_hours2=36, # 第二个分布均值36小时前
# std_hours2=24, # 第二个分布标准差
# weight2=0.3, # 第二个分布权重 30%
# total_samples=50, # 总共生成50个时间点
# )
# # 生成时间分布
# timestamps = scheduler.generate_time_samples()
# # 打印结果,包含分布可视化
# print_time_samples(timestamps, show_distribution=True)
# # 打印时间戳数组
# timestamp_array = scheduler.get_timestamp_array()
# print("\n时间戳数组Unix时间戳")
# print("[", end="")
# for i, ts in enumerate(timestamp_array):
# if i > 0:
# print(", ", end="")
# print(ts, end="")
# print("]")

View File

@ -29,7 +29,6 @@ class Message(MessageBase):
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
memorized_times: int = 0
def __init__(
self,

View File

@ -119,7 +119,6 @@ class MessageStorage:
# Text content
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,

View File

@ -294,6 +294,9 @@ class DefaultReplyer:
async def build_relation_info(self, sender: str, target: str):
if not global_config.relationship.enable_relationship:
return ""
if sender == global_config.bot.nickname:
return ""
# 获取用户ID
person = Person(person_name = sender)
@ -757,13 +760,19 @@ class DefaultReplyer:
# 处理结果
timing_logs = []
results_dict = {}
almost_zero_str = ""
for name, result, duration in task_results:
results_dict[name] = result
chinese_name = task_name_mapping.get(name, name)
if duration < 0.01:
almost_zero_str += f"{chinese_name},"
continue
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
if duration > 8:
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s请使用更快的模型")
logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")
expression_habits_block, selected_expressions = results_dict["expression_habits"]
relation_info = results_dict["relation_info"]

View File

@ -642,6 +642,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
person = Person(platform=platform, user_id=user_id)
if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None
return False, None
person_id = person.person_id
person_name = None

View File

@ -159,7 +159,6 @@ class Messages(BaseModel):
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
display_message = TextField(null=True) # 显示的消息
memorized_times = IntegerField(default=0) # 被记忆的次数
priority_mode = TextField(null=True)
priority_info = TextField(null=True)

View File

@ -598,25 +598,10 @@ class MemoryConfig(ConfigBase):
"""记忆配置类"""
enable_memory: bool = True
memory_build_interval: int = 600
"""记忆构建间隔(秒)"""
memory_build_distribution: tuple[
float,
float,
float,
float,
float,
float,
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
"""记忆构建分布参数分布1均值标准差权重分布2均值标准差权重"""
memory_build_sample_num: int = 8
"""记忆构建采样数量"""
memory_build_sample_length: int = 40
"""记忆构建采样长度"""
"""是否启用记忆系统"""
memory_build_frequency: int = 1
"""记忆构建频率(秒)"""
memory_compress_rate: float = 0.1
"""记忆压缩率"""
@ -630,15 +615,6 @@ class MemoryConfig(ConfigBase):
memory_forget_percentage: float = 0.01
"""记忆遗忘比例"""
consolidate_memory_interval: int = 1000
"""记忆整合间隔(秒)"""
consolidation_similarity_threshold: float = 0.7
"""整合相似度阈值"""
consolidate_memory_percentage: float = 0.01
"""整合检查节点比例"""
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
"""不允许记忆的词列表"""

View File

@ -141,20 +141,14 @@ class MainSystem:
if global_config.memory.enable_memory and self.hippocampus_manager:
tasks.extend(
[
self.build_memory_task(),
# 移除记忆构建的定期调用,改为在heartFC_chat.py中调用
# self.build_memory_task(),
self.forget_memory_task(),
]
)
await asyncio.gather(*tasks)
async def build_memory_task(self):
"""记忆构建任务"""
while True:
await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self):
"""记忆遗忘任务"""
while True:

View File

@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
if global_config.memory.enable_memory:
with Timer("记忆激活"):
interested_rate,_ = await hippocampus_manager.get_activate_from_text(
interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)

View File

@ -158,6 +158,9 @@ class PromptBuilder:
return relation_prompt
async def build_memory_block(self, text: str) -> str:
# 待更新记忆系统
return ""
related_memory = await hippocampus_manager.get_memory_from_text(
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)

View File

@ -1,557 +0,0 @@
import copy
import hashlib
import datetime
import asyncio
import json
from typing import Dict, Union, Optional, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import GroupInfo
"""
GroupInfoManager 类方法功能摘要
1. get_group_id - 根据平台和群号生成MD5哈希的唯一group_id
2. create_group_info - 创建新群组信息文档自动合并默认值
3. update_one_field - 更新单个字段值若文档不存在则创建
4. del_one_document - 删除指定group_id的文档
5. get_value - 获取单个字段值返回实际值或默认值
6. get_values - 批量获取字段值任一字段无效则返回空字典
7. add_member - 添加群成员
8. remove_member - 移除群成员
9. get_member_list - 获取群成员列表
"""
logger = get_logger("group_info")
JSON_SERIALIZED_FIELDS = ["member_list", "topic"]
group_info_default = {
"group_id": None,
"group_name": None,
"platform": "unknown",
"group_impression": None,
"member_list": [],
"topic":[],
"create_time": None,
"last_active": None,
"member_count": 0,
}
class GroupInfoManager:
def __init__(self):
self.group_name_list = {}
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([GroupInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 GroupInfo 表创建失败: {e}")
# 初始化时读取所有group_name
try:
for record in GroupInfo.select(GroupInfo.group_id, GroupInfo.group_name).where(
GroupInfo.group_name.is_null(False)
):
if record.group_name:
self.group_name_list[record.group_id] = record.group_name
logger.debug(f"已加载 {len(self.group_name_list)} 个群组名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 group_name_list 失败: {e}")
@staticmethod
def get_group_id(platform: str, group_number: Union[int, str]) -> str:
"""获取群组唯一id"""
# 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
elif "-" in platform:
platform = platform.split("-")[1]
components = [platform, str(group_number)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def is_group_known(self, platform: str, group_number: int):
"""判断是否知道某个群组"""
group_id = self.get_group_id(platform, group_number)
def _db_check_known_sync(g_id: str):
return GroupInfo.get_or_none(GroupInfo.group_id == g_id) is not None
try:
return await asyncio.to_thread(_db_check_known_sync, group_id)
except Exception as e:
logger.error(f"检查群组 {group_id} 是否已知时出错 (Peewee): {e}")
return False
@staticmethod
async def create_group_info(group_id: str, data: Optional[dict] = None):
"""创建一个群组信息项"""
if not group_id:
logger.debug("创建失败group_id不存在")
return
_group_info_default = copy.deepcopy(group_info_default)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
final_data = {"group_id": group_id}
# Start with defaults for all model fields
for key, default_value in _group_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure group_id is correctly set from the argument
final_data["group_id"] = group_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_create_sync(g_data: dict):
try:
GroupInfo.create(**g_data)
return True
except Exception as e:
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def _safe_create_group_info(self, group_id: str, data: Optional[dict] = None):
"""安全地创建群组信息,处理竞态条件"""
if not group_id:
logger.debug("创建失败group_id不存在")
return
_group_info_default = copy.deepcopy(group_info_default)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
final_data = {"group_id": group_id}
# Start with defaults for all model fields
for key, default_value in _group_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure group_id is correctly set from the argument
final_data["group_id"] = group_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(g_data: dict):
try:
# 首先检查是否已存在
existing = GroupInfo.get_or_none(GroupInfo.group_id == g_data["group_id"])
if existing:
logger.debug(f"群组 {g_data['group_id']} 已存在,跳过创建")
return True
# 尝试创建
GroupInfo.create(**g_data)
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建群组 {g_data.get('group_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
async def update_one_field(self, group_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
if field_name not in GroupInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 GroupInfo Peewee 模型中定义的字段。")
return
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
def _db_update_sync(g_id: str, f_name: str, val_to_set):
import time
start_time = time.time()
try:
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
record.save()
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) group_id={g_id}, field={f_name}"
)
return True, False # Found and updated, no creation needed
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 group_id={g_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, group_id, field_name, processed_value)
if needs_creation:
logger.info(f"{group_id} 不存在,将新建。")
creation_data = data if data is not None else {}
# Ensure platform and group_number are present for context if available from 'data'
# but primarily, set the field that triggered the update.
# The create_group_info will handle defaults and serialization.
creation_data[field_name] = value # Pass original value to create_group_info
# Ensure platform and group_number are in creation_data if available,
# otherwise create_group_info will use defaults.
if data and "platform" in data:
creation_data["platform"] = data["platform"]
if data and "group_number" in data:
creation_data["group_number"] = data["group_number"]
# 使用安全的创建方法,处理竞态条件
await self._safe_create_group_info(group_id, creation_data)
@staticmethod
async def del_one_document(group_id: str):
"""删除指定 group_id 的文档"""
if not group_id:
logger.debug("删除失败group_id 不能为空")
return
def _db_delete_sync(g_id: str):
try:
query = GroupInfo.delete().where(GroupInfo.group_id == g_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 GroupInfo {g_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, group_id)
if deleted_count > 0:
logger.debug(f"删除成功group_id={group_id} (Peewee)")
else:
logger.debug(f"删除失败:未找到 group_id={group_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(group_id: str, field_name: str):
"""获取指定群组指定字段的值"""
default_value_for_field = group_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(g_id: str, f_name: str):
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {f_name} for {g_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, group_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in group_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 group_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {group_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in group_info_default else None
@staticmethod
async def get_values(group_id: str, field_names: list) -> dict:
"""获取指定group_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not group_id:
logger.debug("get_values获取失败group_id不能为空")
return {}
result = {}
def _db_get_record_sync(g_id: str):
return GroupInfo.get_or_none(GroupInfo.group_id == g_id)
record = await asyncio.to_thread(_db_get_record_sync, group_id)
for field_name in field_names:
if field_name not in GroupInfo._meta.fields: # type: ignore
if field_name in group_info_default:
result[field_name] = copy.deepcopy(group_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
return result
async def add_member(self, group_id: str, member_info: dict):
"""添加群成员(使用 last_active_time不使用 join_time"""
if not group_id or not member_info:
logger.debug("添加成员失败group_id或member_info不能为空")
return
# 规范化成员字段
normalized_member = dict(member_info)
normalized_member.pop("join_time", None)
if "last_active_time" not in normalized_member:
normalized_member["last_active_time"] = datetime.datetime.now().timestamp()
member_id = normalized_member.get("user_id")
if not member_id:
logger.debug("添加成员失败:缺少 user_id")
return
# 获取当前成员列表
current_members = await self.get_value(group_id, "member_list")
if not isinstance(current_members, list):
current_members = []
# 移除已存在的同 user_id 成员
current_members = [m for m in current_members if m.get("user_id") != member_id]
# 添加新成员
current_members.append(normalized_member)
# 更新成员列表和成员数量
await self.update_one_field(group_id, "member_list", current_members)
await self.update_one_field(group_id, "member_count", len(current_members))
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
logger.info(f"群组 {group_id} 添加/更新成员 {normalized_member.get('nickname', member_id)} 成功")
async def remove_member(self, group_id: str, user_id: str):
"""移除群成员"""
if not group_id or not user_id:
logger.debug("移除成员失败group_id或user_id不能为空")
return
# 获取当前成员列表
current_members = await self.get_value(group_id, "member_list")
if not isinstance(current_members, list):
logger.debug(f"群组 {group_id} 成员列表为空或格式错误")
return
# 移除指定成员
original_count = len(current_members)
current_members = [m for m in current_members if m.get("user_id") != user_id]
new_count = len(current_members)
if new_count < original_count:
# 更新成员列表和成员数量
await self.update_one_field(group_id, "member_list", current_members)
await self.update_one_field(group_id, "member_count", new_count)
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
logger.info(f"群组 {group_id} 移除成员 {user_id} 成功")
else:
logger.debug(f"群组 {group_id} 中未找到成员 {user_id}")
async def get_member_list(self, group_id: str) -> List[dict]:
"""获取群成员列表"""
if not group_id:
logger.debug("获取成员列表失败group_id不能为空")
return []
members = await self.get_value(group_id, "member_list")
if isinstance(members, list):
return members
return []
async def get_or_create_group(
self, platform: str, group_number: int, group_name: str = None
) -> str:
"""
根据 platform group_number 获取 group_id
如果对应的群组不存在则使用提供的信息创建新群组
使用try-except处理竞态条件避免重复创建错误
"""
group_id = self.get_group_id(platform, group_number)
def _db_get_or_create_sync(g_id: str, init_data: dict):
"""原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
GroupInfo.create(**init_data)
return GroupInfo.get(GroupInfo.group_id == g_id), True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建群组 {g_id},获取现有记录")
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
initial_data = {
"group_id": group_id,
"platform": platform,
"group_number": str(group_number),
"group_name": group_name,
"create_time": datetime.datetime.now().timestamp(),
"last_active": datetime.datetime.now().timestamp(),
"member_count": 0,
"member_list": [],
"group_info": {},
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, group_id, filtered_initial_data)
if was_created:
logger.info(f"群组 {platform}:{group_number} (group_id: {group_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {group_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
else:
logger.debug(f"群组 {platform}:{group_number} (group_id: {group_id}) 已存在,返回现有记录。")
return group_id
async def get_group_info_by_name(self, group_name: str) -> dict | None:
"""根据 group_name 查找群组并返回基本信息 (如果找到)"""
if not group_name:
logger.debug("get_group_info_by_name 获取失败group_name 不能为空")
return None
found_group_id = None
for gid, name_in_cache in self.group_name_list.items():
if name_in_cache == group_name:
found_group_id = gid
break
if not found_group_id:
def _db_find_by_name_sync(g_name_to_find: str):
return GroupInfo.get_or_none(GroupInfo.group_name == g_name_to_find)
record = await asyncio.to_thread(_db_find_by_name_sync, group_name)
if record:
found_group_id = record.group_id
if (
found_group_id not in self.group_name_list
or self.group_name_list[found_group_id] != group_name
):
self.group_name_list[found_group_id] = group_name
else:
logger.debug(f"数据库中也未找到名为 '{group_name}' 的群组 (Peewee)")
return None
if found_group_id:
required_fields = [
"group_id",
"platform",
"group_number",
"group_name",
"group_impression",
"short_impression",
"member_count",
"create_time",
"last_active",
]
valid_fields_to_get = [
f
for f in required_fields
if f in GroupInfo._meta.fields or f in group_info_default # type: ignore
]
group_data = await self.get_values(found_group_id, valid_fields_to_get)
if group_data:
final_result = {key: group_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 group_id '{found_group_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{group_name}' 确定 group_id (Peewee)")
return None
group_info_manager = None
def get_group_info_manager():
global group_info_manager
if group_info_manager is None:
group_info_manager = GroupInfoManager()
return group_info_manager

View File

@ -1,183 +0,0 @@
import time
import json
import re
import asyncio
from typing import Any, Optional
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_readable_messages,
)
from src.person_info.group_info import get_group_info_manager
from src.plugin_system.apis import message_api
from json_repair import repair_json
logger = get_logger("group_relationship_manager")
class GroupRelationshipManager:
def __init__(self):
self.group_llm = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="relationship.group"
)
self.last_group_impression_time = 0.0
self.last_group_impression_message_count = 0
async def build_relation(self, chat_id: str, platform: str) -> None:
"""构建群关系,类似 relationship_builder.build_relation() 的调用方式"""
current_time = time.time()
talk_frequency = global_config.chat.get_current_talk_frequency(chat_id)
# 计算间隔时间基于活跃度动态调整最小10分钟最大30分钟
interval_seconds = max(600, int(1800 / max(0.5, talk_frequency)))
# 统计新消息数量
# 先获取所有新消息,然后过滤掉麦麦的消息和命令消息
all_new_messages = message_api.get_messages_by_time_in_chat(
chat_id=chat_id,
start_time=self.last_group_impression_time,
end_time=current_time,
filter_mai=True,
filter_command=True,
)
new_messages_since_last_impression = len(all_new_messages)
# 触发条件:时间间隔 OR 消息数量阈值
if (current_time - self.last_group_impression_time >= interval_seconds) or \
(new_messages_since_last_impression >= 100):
logger.info(f"[{chat_id}] 触发群印象构建 (时间间隔: {current_time - self.last_group_impression_time:.0f}s, 消息数: {new_messages_since_last_impression})")
# 异步执行群印象构建
asyncio.create_task(
self.build_group_impression(
chat_id=chat_id,
platform=platform,
lookback_hours=12,
max_messages=300
)
)
self.last_group_impression_time = current_time
self.last_group_impression_message_count = 0
else:
# 更新消息计数
self.last_group_impression_message_count = new_messages_since_last_impression
logger.debug(f"[{chat_id}] 群印象构建等待中 (时间: {current_time - self.last_group_impression_time:.0f}s/{interval_seconds}s, 消息: {new_messages_since_last_impression}/100)")
async def build_group_impression(
self,
chat_id: str,
platform: str,
lookback_hours: int = 24,
max_messages: int = 300,
) -> Optional[str]:
"""基于最近聊天记录构建群印象并存储
返回生成的topic
"""
now = time.time()
start_ts = now - lookback_hours * 3600
# 拉取最近消息(包含边界)
messages = get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_ts, now)
if not messages:
logger.info(f"[{chat_id}] 无近期消息,跳过群印象构建")
return None
# 限制数量,优先最新
messages = sorted(messages, key=lambda m: m.get("time", 0))[-max_messages:]
# 构建可读文本
readable = build_readable_messages(
messages=messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
)
if not readable:
logger.info(f"[{chat_id}] 构建可读消息文本为空,跳过")
return None
# 确保群存在
group_info_manager = get_group_info_manager()
group_id = await group_info_manager.get_or_create_group(platform, chat_id)
group_name = await group_info_manager.get_value(group_id, "group_name") or chat_id
alias_str = ", ".join(global_config.bot.alias_names)
prompt = f"""
你的名字是{global_config.bot.nickname}{global_config.bot.nickname}的别名是{alias_str}
你现在在群{group_name}平台{platform}
请你根据以下群内最近的聊天记录总结这个群给你的印象
要求
- 关注群的氛围友好/活跃/娱乐/学习/严肃等常见话题互动风格活跃时段或频率是否有显著文化/
- 用白话表达避免夸张或浮夸的词汇语气自然接地气
- 不要暴露任何个人隐私信息
- 请严格按照json格式输出不要有其他多余内容
{{
"impression": "不超过200字的群印象长描述白话、自然",
"topic": "一句话概括群主要聊什么,白话"
}}
群内聊天节选
{readable}
"""
# 生成印象
content, _ = await self.group_llm.generate_response_async(prompt=prompt)
raw_text = (content or "").strip()
def _strip_code_fences(text: str) -> str:
if text.startswith("```") and text.endswith("```"):
# 去除首尾围栏
return re.sub(r"^```[a-zA-Z0-9_\-]*\n|\n```$", "", text, flags=re.S)
# 提取围栏中的主体
match = re.search(r"```[a-zA-Z0-9_\-]*\n([\s\S]*?)\n```", text)
return match.group(1) if match else text
parsed_text = _strip_code_fences(raw_text)
long_impression: str = ""
topic_val: Any = ""
# 参考关系模块先repair_json再loads兼容返回列表/字典/字符串
try:
fixed = repair_json(parsed_text)
data = json.loads(fixed) if isinstance(fixed, str) else fixed
if isinstance(data, list) and data and isinstance(data[0], dict):
data = data[0]
if isinstance(data, dict):
long_impression = str(data.get("impression") or "").strip()
topic_val = data.get("topic", "")
else:
# 不是字典,直接作为文本
text_fallback = str(data)
long_impression = text_fallback[:400].strip()
topic_val = ""
except Exception:
long_impression = parsed_text[:400].strip()
topic_val = ""
# 兜底
if not long_impression and not topic_val:
logger.info(f"[{chat_id}] LLM未产生有效群印象跳过")
return None
# 写入数据库
await group_info_manager.update_one_field(group_id, "group_impression", long_impression)
if topic_val:
await group_info_manager.update_one_field(group_id, "topic", topic_val)
await group_info_manager.update_one_field(group_id, "last_active", now)
logger.info(f"[{chat_id}] 群印象更新完成: topic={topic_val}")
return str(topic_val) if topic_val else ""
group_relationship_manager: Optional[GroupRelationshipManager] = None
def get_group_relationship_manager() -> GroupRelationshipManager:
global group_relationship_manager
if group_relationship_manager is None:
group_relationship_manager = GroupRelationshipManager()
return group_relationship_manager

View File

@ -71,7 +71,7 @@ class Person:
person_id = get_person_id(platform, user_id)
if is_person_known(person_id=person_id):
logger.info(f"用户 {nickname} 已存在")
logger.debug(f"用户 {nickname} 已存在")
return Person(person_id=person_id)
# 创建Person实例
@ -148,9 +148,13 @@ class Person:
if not is_person_known(person_id=self.person_id):
self.is_known = False
logger.warning(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.person_name = f"未知用户{self.person_id[:4]}"
return
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.is_known = False

View File

@ -300,15 +300,6 @@ class RelationshipBuilder:
return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0
def force_cleanup_user_segments(self, person_id: str) -> bool:
"""强制清理指定用户的所有消息段"""
if person_id in self.person_engaged_cache:
segments_count = len(self.person_engaged_cache[person_id])
del self.person_engaged_cache[person_id]
self._save_cache()
logger.info(f"{self.log_prefix} 强制清理用户 {person_id}{segments_count} 个消息段")
return True
return False
def get_cache_status(self) -> str:
# sourcery skip: merge-list-append, merge-list-appends-into-extend

View File

@ -1,5 +1,5 @@
from src.common.logger import get_logger
from .person_info import Person,is_person_known
from .person_info import Person
import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@ -272,7 +272,7 @@ class RelationshipManager:
return ""
attitude_score = attitude_data["attitude"]
confidence = attitude_data["confidence"]
confidence = pow(attitude_data["confidence"], 2)
new_confidence = total_confidence + confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
@ -318,7 +318,7 @@ class RelationshipManager:
return ""
neuroticism_score = neuroticism_data["neuroticism"]
confidence = neuroticism_data["confidence"]
confidence = pow(neuroticism_data["confidence"], 2)
new_confidence = total_confidence + confidence
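Squaring the confidence before it enters the running weighted average down-weights tentative judgments much more aggressively than confident ones. A small illustrative calculation (the numeric values are assumptions, not from this commit):

def blend(current_score: float, total_conf: float, new_score: float, raw_conf: float) -> tuple[float, float]:
    # Mirrors the update above: confidence is squared before weighting.
    conf = raw_conf ** 2
    new_total = total_conf + conf
    return (current_score * total_conf + new_score * conf) / new_total, new_total

# raw_conf = 0.9 -> weight 0.81: score moves from 0.2 to about 0.47
# raw_conf = 0.4 -> weight 0.16: score only moves from 0.2 to about 0.28
print(blend(0.2, 1.0, 0.8, 0.9), blend(0.2, 1.0, 0.8, 0.4))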

View File

@ -1,5 +1,5 @@
[inner]
version = "6.4.2"
version = "6.4.5"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@ -130,10 +130,7 @@ filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合
[memory]
enable_memory = true # 是否启用记忆系统
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
memory_build_frequency = 1 # 记忆构建频率 越高,麦麦学习越多
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习