mirror of https://github.com/Mai-with-u/MaiBot.git
feat: further optimize the memory system: memory is now built just-in-time, and the same messages are never built into memory twice
parent 3bf476c610
commit bf7419c693
@@ -18,7 +18,6 @@ from src.chat.chat_loop.hfc_utils import CycleDetail
 from src.person_info.relationship_builder_manager import relationship_builder_manager
 from src.chat.express.expression_learner import expression_learner_manager
 from src.person_info.person_info import Person
-from src.person_info.group_relationship_manager import get_group_relationship_manager
 from src.plugin_system.base.component_types import ChatMode, EventType
 from src.plugin_system.core import events_manager
 from src.plugin_system.apis import generator_api, send_api, message_api, database_api
@@ -27,6 +26,8 @@ import math
 from src.mais4u.s4u_config import s4u_config
 # The no_reply logic is now integrated into heartFC_chat.py; no separate import is needed
 from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
+# Import the memory system
+from src.chat.memory_system.Hippocampus import hippocampus_manager

 ERROR_LOOP_INFO = {
     "loop_plan_info": {
@@ -90,7 +91,6 @@ class HeartFChatting:

         self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
         self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
-        self.group_relationship_manager = get_group_relationship_manager()

         self.action_manager = ActionManager()
@@ -386,20 +386,19 @@ class HeartFChatting:
         await self.relationship_builder.build_relation()
         await self.expression_learner.trigger_learning_for_chat()

-        # Group-impression building: triggered only in group chats
-        # if self.chat_stream.group_info and getattr(self.chat_stream.group_info, "group_id", None):
-        #     await self.group_relationship_manager.build_relation(
-        #         chat_id=self.stream_id,
-        #         platform=self.chat_stream.platform
-        #     )
+        # Memory building: build memory for the current chat_id
+        try:
+            await hippocampus_manager.build_memory_for_chat(self.stream_id)
+        except Exception as e:
+            logger.error(f"{self.log_prefix} 记忆构建失败: {e}")

         if random.random() > global_config.chat.focus_value and mode == ChatMode.FOCUS:
             # If activation is low and chat activity is low, planning may be skipped,
             # as if the bot were away from the keyboard rather than thinking carefully
             actions = [
                 {
                     "action_type": "no_reply",
-                    "reasoning": "选择不回复",
+                    "reasoning": "专注不足",
                     "action_data": {},
                 }
             ]
@@ -254,7 +254,7 @@ class ExpressionSelector:
         # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")

         # logger.info(f"模型名称: {model_name}")
-        logger.info(f"LLM返回结果: {content}")
+        # logger.info(f"LLM返回结果: {content}")
         # if reasoning_content:
         #     logger.info(f"LLM推理: {reasoning_content}")
         # else:
@@ -3,7 +3,6 @@ from typing import Any, Optional, Dict

from src.common.logger import get_logger
from src.chat.heart_flow.sub_heartflow import SubHeartflow
from src.chat.message_receive.chat_stream import get_chat_manager

logger = get_logger("heartflow")
@@ -27,8 +26,6 @@ class Heartflow:

            # Register the sub-heartflow
            self.subheartflows[subheartflow_id] = new_subflow
            heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
            logger.info(f"[{heartflow_name}] 开始接收消息")

            return new_subflow
        except Exception as e:
@@ -35,13 +35,13 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]:
     interested_rate = 0.0

     with Timer("记忆激活"):
-        interested_rate, keywords = await hippocampus_manager.get_activate_from_text(
+        interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
             message.processed_plain_text,
             max_depth=4,
             fast_retrieval=False,
         )
     message.key_words = keywords
-    message.key_words_lite = keywords
+    message.key_words_lite = keywords_lite
     logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")

     text_len = len(message.processed_plain_text)
@@ -7,24 +7,21 @@ import re
 import jieba
 import networkx as nx
 import numpy as np
-from typing import List, Tuple, Set, Coroutine, Any
+from typing import List, Tuple, Set, Coroutine, Any, Dict
 from collections import Counter
 from itertools import combinations

 import traceback

 from rich.traceback import install

 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config, model_config
-from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee model imports
+from src.common.database.database_model import GraphNodes, GraphEdges  # Peewee model imports
 from src.common.logger import get_logger
-from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  # distribution generator
 from src.chat.utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
     build_readable_messages,
     get_raw_msg_by_timestamp_with_chat,
+    get_raw_msg_by_timestamp_with_chat_inclusive,
 )  # import build_readable_messages
 from src.chat.utils.utils import translate_timestamp_to_human_readable

# Add a cosine_similarity helper
def cosine_similarity(v1, v2):
    """Compute cosine similarity"""
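The body of cosine_similarity is cut off by the diff view. Judging from its later call sites (two equal-length 0/1 word-membership vectors), a standard implementation would look like the sketch below; the numpy form is an editorial assumption, not necessarily the repository's exact code:

import numpy as np

def cosine_similarity(v1, v2):
    # Cosine similarity of two equal-length numeric vectors; 0.0 when either is all zeros
    a = np.asarray(v1, dtype=float)
    b = np.asarray(v2, dtype=float)
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.dot(a, b)) / denom if denom else 0.0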
@@ -334,6 +331,9 @@ class Hippocampus:
             f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
+            f"如果确定找不出主题或者没有明显主题,返回<none>。"
         )

         return prompt

     @staticmethod
@@ -417,14 +417,17 @@ class Hippocampus:
         # Use the LLM to extract keywords; topic_num is tuned to the observed text-length distribution
         text_length = len(text)
         topic_num: int | list[int] = 0
         if text_length <= 6:
             words = jieba.cut(text)
             keywords = [word for word in words if len(word) > 1]
             keywords = list(set(keywords))[:3]  # cap at 3 keywords
             if keywords:
                 logger.debug(f"提取关键词: {keywords}")
-            return keywords
-        elif text_length <= 12:
+            # Callers unpack two lists, so the short-text path must also return a pair;
+            # for texts this short the lite list is the same jieba keyword list.
+            return keywords, keywords
+
+        words = jieba.cut(text)
+        keywords_lite = [word for word in words if len(word) > 1]
+        keywords_lite = list(set(keywords_lite))
+        if keywords_lite:
+            logger.debug(f"提取关键词极简版: {keywords_lite}")
+
+        if text_length <= 12:
             topic_num = [1, 3]  # 6-10 chars: 1 keyword (27.18% of the texts)
         elif text_length <= 20:
             topic_num = [2, 4]  # 11-20 chars: 2 keywords (22.76% of the texts)
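After this hunk every path of get_keywords_from_text returns a (keywords, keywords_lite) pair. A hedged usage sketch; the input string and the commented values are illustrative only:

# Inside an async method of Hippocampus:
keywords, keywords_lite = await self.get_keywords_from_text("明天下午一起去打羽毛球吗")
# keywords: up to topic_num LLM-extracted topics, e.g. ["羽毛球"]
# keywords_lite: every jieba token longer than one character, e.g. ["明天", "下午", "一起", "羽毛球"]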
@@ -451,169 +454,7 @@ class Hippocampus:
         if keywords:
             logger.debug(f"提取关键词: {keywords}")

-        return keywords
-
-    async def get_memory_from_text(
-        self,
-        text: str,
-        max_memory_num: int = 3,
-        max_memory_length: int = 2,
-        max_depth: int = 3,
-        fast_retrieval: bool = False,
-    ) -> list:
-        """Extract keywords from the text and retrieve the related memories.
-
-        Args:
-            text (str): input text
-            max_memory_num (int, optional): upper bound on returned memory entries. Defaults to 3, i.e. at most the 3 memories most relevant to the input text.
-            max_memory_length (int, optional): maximum memory entries per topic. Defaults to 2, i.e. at most the 2 most similar memories per topic.
-            max_depth (int, optional): retrieval depth. Defaults to 3. Larger values widen the search and surface more indirectly related memories, at the cost of speed.
-            fast_retrieval (bool, optional): whether to use fast retrieval. Defaults to False.
-                If True, keywords are extracted with jieba and TF-IDF: faster but less accurate.
-                If False, keywords are extracted with the LLM: slower but more accurate.
-
-        Returns:
-            list: memory list; each element is a tuple (topic, memory_content)
-                - topic: str, the memory topic
-                - memory_content: str, the full memory content under that topic
-        """
-        keywords = await self.get_keywords_from_text(text)
-
-        # Filter out keywords that do not exist in the memory graph
-        valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
-        if not valid_keywords:
-            logger.debug("没有找到有效的关键词节点")
-            return []
-
-        logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
-
-        # Retrieve memories from each keyword
-        activate_map = {}  # accumulated activation per node
-
-        # Spreading-activation search from each keyword
-        for keyword in valid_keywords:
-            logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
-            # Initialize the activation
-            activation_values = {keyword: 1.0}
-            # Track visited nodes
-            visited_nodes = {keyword}
-            # Queue of nodes to process: (node, activation, current depth)
-            nodes_to_process = [(keyword, 1.0, 0)]
-
-            while nodes_to_process:
-                current_node, current_activation, current_depth = nodes_to_process.pop(0)
-
-                # Stop spreading once activation is exhausted or the depth limit is reached
-                if current_activation <= 0 or current_depth >= max_depth:
-                    continue
-
-                # Collect the current node's neighbors
-                neighbors = list(self.memory_graph.G.neighbors(current_node))
-
-                for neighbor in neighbors:
-                    if neighbor in visited_nodes:
-                        continue
-
-                    # Edge strength
-                    edge_data = self.memory_graph.G[current_node][neighbor]
-                    strength = edge_data.get("strength", 1)
-
-                    # New activation: decays by 1/strength per hop
-                    new_activation = current_activation - (1 / strength)
-
-                    if new_activation > 0:
-                        activation_values[neighbor] = new_activation
-                        visited_nodes.add(neighbor)
-                        nodes_to_process.append((neighbor, new_activation, current_depth + 1))
-                        # logger.debug(
-                        #     f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
-                        # )  # noqa: E501
-
-            # Merge into the accumulated activation map
-            for node, activation_value in activation_values.items():
-                if activation_value > 0:
-                    if node in activate_map:
-                        activate_map[node] += activation_value
-                    else:
-                        activate_map[node] = activation_value

-        # Independent selection weighted by squared activation
-        remember_map = {}
-
-        # Sum of all squared activations
-        total_squared_activation = sum(activation**2 for activation in activate_map.values())
-        if total_squared_activation > 0:
-            # Normalized activations
-            normalized_activations = {
-                node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
-            }
-
-            # Sort by normalized activation and keep the top max_memory_num nodes
-            sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]
-
-            # Add the selected nodes to remember_map
-            for node, normalized_activation in sorted_nodes:
-                remember_map[node] = activate_map[node]  # keep the raw activation
-                logger.debug(
-                    f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
-                )
-        else:
-            logger.info("没有有效的激活值")
-
-        # Extract memories from the selected nodes
-        all_memories = []
-        for node, activation in remember_map.items():
-            logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
-            node_data = self.memory_graph.G.nodes[node]
-            memory_items = node_data.get("memory_items", "")
-            # Use the full memory content directly
-            if memory_items:
-                logger.debug("节点包含完整记忆")
-                # Compute the similarity between the memory and the input text
-                memory_words = set(jieba.cut(memory_items))
-                text_words = set(jieba.cut(text))
-                all_words = memory_words | text_words
-                if all_words:
-                    # Computed for logical consistency even though the value is unused
-                    v1 = [1 if word in memory_words else 0 for word in all_words]
-                    v2 = [1 if word in text_words else 0 for word in all_words]
-                    _ = cosine_similarity(v1, v2)
-
-                # Append the full memory to the results
-                all_memories.append((node, memory_items, activation))
-            else:
-                logger.info("节点没有记忆")
-
-        # Deduplicate by memory content
-        logger.debug("开始记忆去重:")
-        seen_memories = set()
-        unique_memories = []
-        for topic, memory_items, activation_value in all_memories:
-            # memory_items is now a full string
-            memory = memory_items if memory_items else ""
-            if memory not in seen_memories:
-                seen_memories.add(memory)
-                unique_memories.append((topic, memory_items, activation_value))
-                logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})")
-            else:
-                logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})")
-
-        # Convert to (keyword, memory) pairs
-        result = []
-        for topic, memory_items, _ in unique_memories:
-            memory = memory_items if memory_items else ""
-            result.append((topic, memory))
-            logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
-
-        return result
+        return keywords, keywords_lite

     async def get_memory_from_topic(
         self,
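The retriever deleted above (its surviving sibling get_activate_from_text keeps the same mechanics) walks the memory graph by spreading activation: a seed keyword starts at 1.0, each hop subtracts 1/strength of the traversed edge, and surviving nodes are later weighted by squared activation. A self-contained toy reproduction of just that walk; the graph contents are illustrative, not from the repository:

import networkx as nx

def spread_activation(G: nx.Graph, seed: str, max_depth: int = 3) -> dict:
    # BFS from seed; each hop subtracts 1/strength of the edge from the parent's activation
    activation = {seed: 1.0}
    queue = [(seed, 1.0, 0)]
    while queue:
        node, act, depth = queue.pop(0)
        if act <= 0 or depth >= max_depth:
            continue
        for nbr in G.neighbors(node):
            if nbr in activation:
                continue
            new_act = act - 1 / G[node][nbr].get("strength", 1)
            if new_act > 0:
                activation[nbr] = new_act
                queue.append((nbr, new_act, depth + 1))
    return activation

G = nx.Graph()
G.add_edge("羽毛球", "运动", strength=5)   # strong link: small decay (1/5)
G.add_edge("运动", "健身房", strength=2)   # weaker link: larger decay (1/2)
acts = spread_activation(G, "羽毛球")
print(acts)  # {'羽毛球': 1.0, '运动': 0.8, '健身房': 0.3}

total = sum(a * a for a in acts.values())
weights = {n: a * a / total for n, a in acts.items()}  # selection weight is proportional to activation squared

The 1/strength decay means an edge of strength 10 costs only 0.1 activation per hop, so tightly linked topics stay reachable several hops out.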
@@ -771,7 +612,7 @@ class Hippocampus:

         return result

-    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
+    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str], list[str]]:
         """Extract keywords from the text and compute how strongly they activate the memory graph.

         Args:
@@ -785,13 +626,13 @@ class Hippocampus:
             float: ratio of activated nodes to total nodes
             list[str]: valid keywords
             list[str]: lite keywords (jieba tokens), returned even when no node matches
         """
-        keywords = await self.get_keywords_from_text(text)
+        keywords, keywords_lite = await self.get_keywords_from_text(text)

         # Filter out keywords that do not exist in the memory graph
         valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
         if not valid_keywords:
             # logger.info("没有找到有效的关键词节点")
-            return 0, []
+            return 0.0, keywords, keywords_lite

         logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@@ -858,7 +699,7 @@ class Hippocampus:
         activation_ratio = activation_ratio * 50
         logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")

-        return activation_ratio, keywords
+        return activation_ratio, keywords, keywords_lite


# Coordinates the hippocampus's interaction with the rest of the system
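Since get_activate_from_text now returns a triple on every path, callers unpack three values. A hedged usage sketch, assuming the manager has already been initialized elsewhere during startup:

import asyncio
from src.chat.memory_system.Hippocampus import hippocampus_manager

async def demo():
    rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
        "明天要不要去打羽毛球?", max_depth=4, fast_retrieval=False
    )
    # rate: activated-node ratio scaled by 50; keywords: LLM topics; keywords_lite: jieba tokens
    print(rate, keywords, keywords_lite)

asyncio.run(demo())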
@@ -867,92 +708,6 @@ class EntorhinalCortex:
         self.hippocampus = hippocampus
         self.memory_graph = hippocampus.memory_graph

-    def get_memory_sample(self):
-        """Fetch memory samples from the database"""
-        # Hard-coded: maximum number of times a message may be memorized
-        max_memorized_time_per_msg = 2
-
-        # Build a bimodal-distribution memory scheduler
-        sample_scheduler = MemoryBuildScheduler(
-            n_hours1=global_config.memory.memory_build_distribution[0],
-            std_hours1=global_config.memory.memory_build_distribution[1],
-            weight1=global_config.memory.memory_build_distribution[2],
-            n_hours2=global_config.memory.memory_build_distribution[3],
-            std_hours2=global_config.memory.memory_build_distribution[4],
-            weight2=global_config.memory.memory_build_distribution[5],
-            total_samples=global_config.memory.memory_build_sample_num,
-        )
-
-        timestamps = sample_scheduler.get_timestamp_array()
-        # Use translate_timestamp_to_human_readable with mode="normal"
-        readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
-        for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False):
-            logger.debug(f"回忆往事: {readable_timestamp}")
-        chat_samples = []
-        for timestamp in timestamps:
-            if messages := self.random_get_msg_snippet(
-                timestamp,
-                global_config.memory.memory_build_sample_length,
-                max_memorized_time_per_msg,
-            ):
-                time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
-                logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
-                chat_samples.append(messages)
-            else:
-                logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
-
-        return chat_samples
-
-    @staticmethod
-    def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
-        # sourcery skip: invert-any-all, use-any, use-named-expression, use-next
-        """Randomly fetch a message snippet near the given timestamp (via chat_message_builder)"""
-        time_window_seconds = random.randint(300, 1800)  # random window of 5 to 30 minutes
-
-        for _ in range(3):
-            # Time range: from the target timestamp forward by time_window_seconds
-            timestamp_start = target_timestamp
-            timestamp_end = target_timestamp + time_window_seconds
-
-            if chosen_message := get_raw_msg_by_timestamp(
-                timestamp_start=timestamp_start,
-                timestamp_end=timestamp_end,
-                limit=1,
-                limit_mode="earliest",
-            ):
-                chat_id: str = chosen_message[0].get("chat_id")  # type: ignore
-
-                if messages := get_raw_msg_by_timestamp_with_chat(
-                    timestamp_start=timestamp_start,
-                    timestamp_end=timestamp_end,
-                    limit=chat_size,
-                    limit_mode="earliest",
-                    chat_id=chat_id,
-                ):
-                    # Check that none of the fetched messages has hit the memorization cap
-                    all_valid = True
-                    for message in messages:
-                        if message.get("memorized_times", 0) >= max_memorized_time_per_msg:
-                            all_valid = False
-                            break
-
-                    # If every message is valid
-                    if all_valid:
-                        # Bump the memorization count in the database
-                        for message in messages:
-                            # Re-read memorized_times right before updating
-                            current_memorized_times = message.get("memorized_times", 0)
-                            # Update the record via Peewee
-                            Messages.update(memorized_times=current_memorized_times + 1).where(
-                                Messages.message_id == message["message_id"]
-                            ).execute()
-                        return messages  # return the raw message list as-is
-
-            target_timestamp -= 120  # nudge the timestamp earlier and retry
-
-        # All three attempts failed
-        return None

     async def sync_memory_to_db(self):
         """Sync the memory graph to the database"""
         start_time = time.time()
@@ -1407,81 +1162,14 @@ class ParahippocampalGyrus:
             similar_topics.sort(key=lambda x: x[1], reverse=True)
             similar_topics = similar_topics[:3]
             similar_topics_dict[topic] = similar_topics

         if global_config.debug.show_prompt:
             logger.info(f"prompt: {topic_what_prompt}")
             logger.info(f"压缩后的记忆: {compressed_memory}")
             logger.info(f"相似主题: {similar_topics_dict}")

         return compressed_memory, similar_topics_dict

-    async def operation_build_memory(self):
-        # sourcery skip: merge-list-appends-into-extend
-        logger.info("------------------------------------开始构建记忆--------------------------------------")
-        start_time = time.time()
-        memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
-        all_added_nodes = []
-        all_connected_nodes = []
-        all_added_edges = []
-        for i, messages in enumerate(memory_samples, 1):
-            all_topics = []
-            compress_rate = global_config.memory.memory_compress_rate
-            try:
-                compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
-            except Exception as e:
-                logger.error(f"压缩记忆时发生错误: {e}")
-                continue
-            for topic, memory in compressed_memory:
-                logger.info(f"取得记忆: {topic} - {memory}")
-            for topic, similar_topics in similar_topics_dict.items():
-                logger.debug(f"相似话题: {topic} - {similar_topics}")
-
-            current_time = datetime.datetime.now().timestamp()
-            logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
-            all_added_nodes.extend(topic for topic, _ in compressed_memory)
-
-            for topic, memory in compressed_memory:
-                await self.memory_graph.add_dot(topic, memory, self.hippocampus)
-                all_topics.append(topic)
-
-                if topic in similar_topics_dict:
-                    similar_topics = similar_topics_dict[topic]
-                    for similar_topic, similarity in similar_topics:
-                        if topic != similar_topic:
-                            strength = int(similarity * 10)
-
-                            logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})")
-                            all_added_edges.append(f"{topic}-{similar_topic}")
-
-                            all_connected_nodes.append(topic)
-                            all_connected_nodes.append(similar_topic)
-
-                            self.memory_graph.G.add_edge(
-                                topic,
-                                similar_topic,
-                                strength=strength,
-                                created_time=current_time,
-                                last_modified=current_time,
-                            )
-
-            for topic1, topic2 in combinations(all_topics, 2):
-                logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
-                all_added_edges.append(f"{topic1}-{topic2}")
-                self.memory_graph.connect_dot(topic1, topic2)
-
-            progress = (i / len(memory_samples)) * 100
-            bar_length = 30
-            filled_length = int(bar_length * i // len(memory_samples))
-            bar = "█" * filled_length + "-" * (bar_length - filled_length)
-            logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
-
-        if all_added_nodes:
-            logger.info(f"更新记忆: {', '.join(all_added_nodes)}")
-        if all_added_edges:
-            logger.debug(f"强化连接: {', '.join(all_added_edges)}")
-        if all_connected_nodes:
-            logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
-
-        await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
-
-        end_time = time.time()
-        logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")

     async def operation_forget_topic(self, percentage=0.005):
         start_time = time.time()
         logger.info("[遗忘] 开始检查数据库...")
@@ -1650,8 +1338,7 @@ class HippocampusManager:
         logger.info(f"""
--------------------------------
记忆系统参数配置:
-构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
-记忆构建分布: {global_config.memory.memory_build_distribution}
+构建频率: {global_config.memory.memory_build_frequency}秒|压缩率: {global_config.memory.memory_compress_rate}
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""")  # noqa: E501
@@ -1663,39 +1350,60 @@ class HippocampusManager:
             raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
         return self._hippocampus

-    async def build_memory(self):
-        """Public interface for building memory"""
-        if not self._initialized:
-            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
-        return await self._hippocampus.parahippocampal_gyrus.operation_build_memory()
-
     async def forget_memory(self, percentage: float = 0.005):
         """Public interface for forgetting memory"""
         if not self._initialized:
             raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
         return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)

-    async def get_memory_from_text(
-        self,
-        text: str,
-        max_memory_num: int = 3,
-        max_memory_length: int = 2,
-        max_depth: int = 3,
-        fast_retrieval: bool = False,
-    ) -> list:
-        """Public interface for retrieving memories related to a text"""
-        if not self._initialized:
-            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
-        try:
-            response = await self._hippocampus.get_memory_from_text(
-                text, max_memory_num, max_memory_length, max_depth, fast_retrieval
-            )
-        except Exception as e:
-            logger.error(f"文本激活记忆失败: {e}")
-            response = []
-        return response
+    async def build_memory_for_chat(self, chat_id: str):
+        """Build memory for the given chat_id (called from heartFC_chat.py)"""
+        if not self._initialized:
+            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
+
+        try:
+            # Check whether a memory build is due
+            logger.info(f"为 {chat_id} 构建记忆")
+            if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
+                logger.info(f"为 {chat_id} 构建记忆,需要构建记忆")
+                messages = memory_segment_manager.get_messages_for_memory_build(
+                    chat_id, int(30 / global_config.memory.memory_build_frequency)  # integer threshold for the query limit
+                )
+                if messages:
+                    logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}")
+
+                    # Compress the messages and build the memory
+                    compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
+                        messages, global_config.memory.memory_compress_rate
+                    )
+
+                    # Add the memory nodes
+                    current_time = time.time()
+                    for topic, memory in compressed_memory:
+                        await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)
+
+                        # Link similar topics
+                        if topic in similar_topics_dict:
+                            similar_topics = similar_topics_dict[topic]
+                            for similar_topic, similarity in similar_topics:
+                                if topic != similar_topic:
+                                    strength = int(similarity * 10)
+                                    self._hippocampus.memory_graph.G.add_edge(
+                                        topic,
+                                        similar_topic,
+                                        strength=strength,
+                                        created_time=current_time,
+                                        last_modified=current_time,
+                                    )
+
+                    # Sync to the database
+                    await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
+                    logger.info(f"为 {chat_id} 构建记忆完成")
+                    return True
+
+        except Exception as e:
+            logger.error(f"为 {chat_id} 构建记忆失败: {e}")
+            return False
+
+        return False

     async def get_memory_from_topic(
         self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
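build_memory_for_chat maps topic similarity to an integer edge strength via int(similarity * 10); because retrieval decays activation by 1/strength per hop, stronger edges leak less activation. A two-line illustration:

for similarity in (0.95, 0.72, 0.30):
    strength = int(similarity * 10)
    print(strength, round(1 / strength, 2))  # (9, 0.11), (7, 0.14), (3, 0.33): activation lost per hop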
@@ -1717,12 +1425,11 @@ class HippocampusManager:
         if not self._initialized:
             raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
         try:
-            response, keywords = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
+            response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
         except Exception as e:
             logger.error(f"文本产生激活值失败: {e}")
-            response = 0.0
-            keywords = []  # initialize keywords to an empty list on error
-            return response, keywords
+            logger.error(traceback.format_exc())
+            return 0.0, [], []
+        return response, keywords, keywords_lite

     def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
         """Public interface for retrieving memories by keyword"""
@@ -1741,3 +1448,90 @@
 hippocampus_manager = HippocampusManager()
+
+
+# A new memory-build manager added alongside the Hippocampus classes
+class MemoryBuilder:
+    """Memory builder
+
+    Keeps a per-chat_id message cursor and trigger check, similar to ExpressionLearner
+    """
+
+    def __init__(self, chat_id: str):
+        self.chat_id = chat_id
+        self.last_update_time: float = time.time()
+        self.last_processed_time: float = 0.0
+
+    def should_trigger_memory_build(self) -> bool:
+        """Check whether a memory build should be triggered"""
+        current_time = time.time()
+
+        # Check the elapsed time
+        time_diff = current_time - self.last_update_time
+        if time_diff < 600 / global_config.memory.memory_build_frequency:
+            return False
+
+        # Check the message count
+        recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
+            chat_id=self.chat_id,
+            timestamp_start=self.last_update_time,
+            timestamp_end=current_time,
+        )
+
+        logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
+
+        if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
+            return False
+
+        return True
+
+    def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
+        """Fetch the messages used for the memory build"""
+        current_time = time.time()
+
+        messages = get_raw_msg_by_timestamp_with_chat_inclusive(
+            chat_id=self.chat_id,
+            timestamp_start=self.last_update_time,
+            timestamp_end=current_time,
+            limit=threshold,
+        )
+
+        if messages:
+            # Advance the cursor so the same messages are never built into memory twice
+            self.last_processed_time = current_time
+            self.last_update_time = current_time
+
+        return messages or []
+
+
+class MemorySegmentManager:
+    """Memory segment manager
+
+    Holds one MemoryBuilder per chat_id and checks/triggers memory builds
+    """
+
+    def __init__(self):
+        self.builders: Dict[str, MemoryBuilder] = {}
+
+    def get_or_create_builder(self, chat_id: str) -> MemoryBuilder:
+        """Get or create the MemoryBuilder for the given chat_id"""
+        if chat_id not in self.builders:
+            self.builders[chat_id] = MemoryBuilder(chat_id)
+        return self.builders[chat_id]
+
+    def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
+        """Return True when the given chat_id is due for a memory build"""
+        builder = self.get_or_create_builder(chat_id)
+        return builder.should_trigger_memory_build()
+
+    def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
+        """Fetch the messages used for the given chat_id's memory build"""
+        if chat_id not in self.builders:
+            return []
+        return self.builders[chat_id].get_messages_for_memory_build(threshold)
+
+
+# Global instance
+memory_segment_manager = MemorySegmentManager()
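MemoryBuilder debounces builds on both elapsed time and message volume, each scaled by memory_build_frequency; this is what makes builds just-in-time yet non-repeating, since the cursor advances once messages are consumed. The rule isolated as a sketch with worked values:

def should_build(seconds_since_last: float, new_messages: int, frequency: int = 1) -> bool:
    # Mirrors should_trigger_memory_build: both thresholds shrink as frequency grows
    return (
        seconds_since_last >= 600 / frequency
        and new_messages >= 30 / frequency
    )

print(should_build(700, 35))      # True: both thresholds met
print(should_build(700, 10))      # False: too few new messages
print(should_build(200, 50, 2))   # False: 200 s is under the 300 s window at frequency 2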
@@ -105,8 +105,8 @@ class MemoryActivator:
             valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
         )

-        logger.info(f"当前记忆关键词: {keywords_list}")
-        logger.info(f"获取到的记忆: {related_memory}")
+        # logger.info(f"当前记忆关键词: {keywords_list}")
+        logger.debug(f"获取到的记忆: {related_memory}")

         if not related_memory:
             logger.debug("海马体没有返回相关记忆")

@@ -141,7 +141,7 @@ class MemoryActivator:

         # Return directly when only a few memories are available
         if len(candidate_memories) <= 2:
-            logger.info(f"候选记忆较少({len(candidate_memories)}个),直接返回")
+            logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
             # Convert to (keyword, content) pairs
             return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
@@ -1,126 +0,0 @@
import numpy as np
from datetime import datetime, timedelta
from rich.traceback import install

install(extra_lines=3)


class MemoryBuildScheduler:
    def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
        """
        Initialize the memory-build scheduler.

        Args:
            n_hours1 (float): mean of the first distribution (hours before now)
            std_hours1 (float): standard deviation of the first distribution (hours)
            weight1 (float): weight of the first distribution
            n_hours2 (float): mean of the second distribution (hours before now)
            std_hours2 (float): standard deviation of the second distribution (hours)
            weight2 (float): weight of the second distribution
            total_samples (int): total number of time points to generate
        """
        # Validate the parameters
        if total_samples <= 0:
            raise ValueError("total_samples 必须大于0")
        if weight1 < 0 or weight2 < 0:
            raise ValueError("权重必须为非负数")
        if std_hours1 < 0 or std_hours2 < 0:
            raise ValueError("标准差必须为非负数")

        # Normalize the weights
        total_weight = weight1 + weight2
        if total_weight == 0:
            raise ValueError("权重总和不能为0")
        self.weight1 = weight1 / total_weight
        self.weight2 = weight2 / total_weight

        self.n_hours1 = n_hours1
        self.std_hours1 = std_hours1
        self.n_hours2 = n_hours2
        self.std_hours2 = std_hours2
        self.total_samples = total_samples
        self.base_time = datetime.now()

    def generate_time_samples(self):
        """Generate time samples from the mixture distribution"""
        # Split the sample count according to the weights
        samples1 = max(1, int(self.total_samples * self.weight1))
        samples2 = max(1, self.total_samples - samples1)  # keep samples2 at least 1

        # Draw hour offsets from the two normal distributions
        hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
        hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)

        # Merge the offsets from both distributions
        hours_offset = np.concatenate([hours_offset1, hours_offset2])

        # Convert offsets to timestamps (absolute value keeps every point in the past)
        timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]

        # Sort from earliest to latest
        return sorted(timestamps)

    def get_timestamp_array(self):
        """Return the timestamps as Unix epoch integers"""
        timestamps = self.generate_time_samples()
        return [int(t.timestamp()) for t in timestamps]


# def print_time_samples(timestamps, show_distribution=True):
#     """打印时间样本和分布信息"""
#     print(f"\n生成的{len(timestamps)}个时间点分布:")
#     print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
#     print("-" * 50)

#     now = datetime.now()
#     time_diffs = []

#     for i, timestamp in enumerate(timestamps, 1):
#         hours_diff = (now - timestamp).total_seconds() / 3600
#         time_diffs.append(hours_diff)
#         print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")

#     # 打印统计信息
#     print("\n统计信息:")
#     print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
#     print(f"标准差:{np.std(time_diffs):.2f}小时")
#     print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
#     print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")

#     if show_distribution:
#         # 计算时间分布的直方图
#         hist, bins = np.histogram(time_diffs, bins=40)
#         print("\n时间分布(每个*代表一个时间点):")
#         for i in range(len(hist)):
#             if hist[i] > 0:
#                 print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")


# Usage example
# if __name__ == "__main__":
#     # Build a bimodal memory scheduler
#     scheduler = MemoryBuildScheduler(
#         n_hours1=12,       # mean of the first peak (12 hours ago)
#         std_hours1=8,      # std of the first peak
#         weight1=0.7,       # weight of the first peak: 70%
#         n_hours2=36,       # mean of the second peak (36 hours ago)
#         std_hours2=24,     # std of the second peak
#         weight2=0.3,       # weight of the second peak: 30%
#         total_samples=50,  # 50 time points in total
#     )

#     # Generate the time distribution
#     timestamps = scheduler.generate_time_samples()

#     # Print the results with the distribution visualization
#     print_time_samples(timestamps, show_distribution=True)

#     # Print the timestamp array
#     timestamp_array = scheduler.get_timestamp_array()
#     print("\n时间戳数组(Unix时间戳):")
#     print("[", end="")
#     for i, ts in enumerate(timestamp_array):
#         if i > 0:
#             print(", ", end="")
#         print(ts, end="")
#     print("]")
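For reference, the core of the deleted scheduler in a few lines: draw "hours ago" offsets from a weighted pair of normal distributions, then fold them into the past. A minimal sketch, not the original implementation:

import numpy as np

rng = np.random.default_rng(0)
n, w1 = 50, 0.7
k1 = int(n * w1)
hours_ago = np.concatenate([
    rng.normal(12, 8, k1),        # recent peak: mean 12 h, std 8 h
    rng.normal(36, 24, n - k1),   # older peak: mean 36 h, std 24 h
])
hours_ago = np.abs(hours_ago)     # keep every sample in the past, as the original did
print(np.round(np.sort(hours_ago)[:5], 1))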
@@ -29,7 +29,6 @@ class Message(MessageBase):
     chat_stream: "ChatStream" = None  # type: ignore
     reply: Optional["Message"] = None
     processed_plain_text: str = ""
-    memorized_times: int = 0

     def __init__(
         self,
@@ -119,7 +119,6 @@ class MessageStorage:
                 # Text content
                 processed_plain_text=filtered_processed_plain_text,
                 display_message=filtered_display_message,
-                memorized_times=message.memorized_times,
                 interest_value=interest_value,
                 priority_mode=priority_mode,
                 priority_info=priority_info,
@@ -294,6 +294,9 @@ class DefaultReplyer:
     async def build_relation_info(self, sender: str, target: str):
         if not global_config.relationship.enable_relationship:
             return ""

+        if sender == global_config.bot.nickname:
+            return ""
+
         # Get the user ID
         person = Person(person_name=sender)
@@ -757,13 +760,19 @@ class DefaultReplyer:
         # Process the results
         timing_logs = []
         results_dict = {}

+        almost_zero_str = ""
         for name, result, duration in task_results:
             results_dict[name] = result
             chinese_name = task_name_mapping.get(name, name)
+            if duration < 0.01:
+                almost_zero_str += f"{chinese_name},"
+                continue

             timing_logs.append(f"{chinese_name}: {duration:.1f}s")
             if duration > 8:
                 logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
-        logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
+        logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")

         expression_habits_block, selected_expressions = results_dict["expression_habits"]
         relation_info = results_dict["relation_info"]
@@ -642,6 +642,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
         person = Person(platform=platform, user_id=user_id)
         if not person.is_known:
             logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
+            # The user is not yet known: return False and None
             return False, None
         person_id = person.person_id
         person_name = None
@@ -159,7 +159,6 @@ class Messages(BaseModel):

     processed_plain_text = TextField(null=True)  # processed plain-text content
     display_message = TextField(null=True)  # displayed message
-    memorized_times = IntegerField(default=0)  # number of times the message has been memorized

     priority_mode = TextField(null=True)
     priority_info = TextField(null=True)
@@ -598,25 +598,10 @@ class MemoryConfig(ConfigBase):
     """Memory configuration"""

     enable_memory: bool = True
-
-    memory_build_interval: int = 600
-    """Memory build interval in seconds"""
-
-    memory_build_distribution: tuple[
-        float,
-        float,
-        float,
-        float,
-        float,
-        float,
-    ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
-    """Memory build distribution: mean, std, and weight of distribution 1, then of distribution 2"""
-
-    memory_build_sample_num: int = 8
-    """Number of memory build samples"""
-
-    memory_build_sample_length: int = 40
-    """Length of each memory build sample"""
+    """Whether the memory system is enabled"""
+
+    memory_build_frequency: int = 1
+    """Memory build frequency multiplier: both the 600 s debounce window and the 30-message threshold are divided by it"""

     memory_compress_rate: float = 0.1
     """Memory compression rate"""

@@ -630,15 +615,6 @@ class MemoryConfig(ConfigBase):
     memory_forget_percentage: float = 0.01
     """Fraction of memory to forget"""

-    consolidate_memory_interval: int = 1000
-    """Memory consolidation interval in seconds"""
-
-    consolidation_similarity_threshold: float = 0.7
-    """Similarity threshold for consolidation"""
-
-    consolidate_memory_percentage: float = 0.01
-    """Fraction of nodes checked for consolidation"""
-
     memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
     """Words that must not be memorized"""
src/main.py

@@ -141,20 +141,14 @@ class MainSystem:
             if global_config.memory.enable_memory and self.hippocampus_manager:
                 tasks.extend(
                     [
-                        self.build_memory_task(),
+                        # The periodic memory-build task is removed; building is now triggered from heartFC_chat.py
+                        # self.build_memory_task(),
                         self.forget_memory_task(),
                     ]
                 )

         await asyncio.gather(*tasks)

-    async def build_memory_task(self):
-        """Memory-build task"""
-        while True:
-            await asyncio.sleep(global_config.memory.memory_build_interval)
-            logger.info("正在进行记忆构建")
-            await self.hippocampus_manager.build_memory()  # type: ignore
-
     async def forget_memory_task(self):
         """Memory-forget task"""
         while True:
@@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:

     if global_config.memory.enable_memory:
         with Timer("记忆激活"):
-            interested_rate, _ = await hippocampus_manager.get_activate_from_text(
+            interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
                 message.processed_plain_text,
                 fast_retrieval=True,
             )
@@ -158,6 +158,9 @@ class PromptBuilder:
         return relation_prompt

     async def build_memory_block(self, text: str) -> str:
+        # Pending the memory-system update; the retrieval below is disabled for now
+        return ""
+
         related_memory = await hippocampus_manager.get_memory_from_text(
             text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
         )
@ -1,557 +0,0 @@
|
|||
import copy
|
||||
import hashlib
|
||||
import datetime
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import Dict, Union, Optional, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import GroupInfo
|
||||
|
||||
|
||||
"""
|
||||
GroupInfoManager 类方法功能摘要:
|
||||
1. get_group_id - 根据平台和群号生成MD5哈希的唯一group_id
|
||||
2. create_group_info - 创建新群组信息文档(自动合并默认值)
|
||||
3. update_one_field - 更新单个字段值(若文档不存在则创建)
|
||||
4. del_one_document - 删除指定group_id的文档
|
||||
5. get_value - 获取单个字段值(返回实际值或默认值)
|
||||
6. get_values - 批量获取字段值(任一字段无效则返回空字典)
|
||||
7. add_member - 添加群成员
|
||||
8. remove_member - 移除群成员
|
||||
9. get_member_list - 获取群成员列表
|
||||
"""
|
||||
|
||||
|
||||
logger = get_logger("group_info")
|
||||
|
||||
JSON_SERIALIZED_FIELDS = ["member_list", "topic"]
|
||||
|
||||
group_info_default = {
|
||||
"group_id": None,
|
||||
"group_name": None,
|
||||
"platform": "unknown",
|
||||
"group_impression": None,
|
||||
"member_list": [],
|
||||
"topic":[],
|
||||
"create_time": None,
|
||||
"last_active": None,
|
||||
"member_count": 0,
|
||||
}
|
||||
|
||||
|
||||
class GroupInfoManager:
|
||||
def __init__(self):
|
||||
self.group_name_list = {}
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 设置连接池参数
|
||||
if hasattr(db, "execute_sql"):
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
db.create_tables([GroupInfo], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 GroupInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有group_name
|
||||
try:
|
||||
for record in GroupInfo.select(GroupInfo.group_id, GroupInfo.group_name).where(
|
||||
GroupInfo.group_name.is_null(False)
|
||||
):
|
||||
if record.group_name:
|
||||
self.group_name_list[record.group_id] = record.group_name
|
||||
logger.debug(f"已加载 {len(self.group_name_list)} 个群组名称 (Peewee)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 group_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_group_id(platform: str, group_number: Union[int, str]) -> str:
|
||||
"""获取群组唯一id"""
|
||||
# 添加空值检查,防止 platform 为 None 时出错
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
elif "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
components = [platform, str(group_number)]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def is_group_known(self, platform: str, group_number: int):
|
||||
"""判断是否知道某个群组"""
|
||||
group_id = self.get_group_id(platform, group_number)
|
||||
|
||||
def _db_check_known_sync(g_id: str):
|
||||
return GroupInfo.get_or_none(GroupInfo.group_id == g_id) is not None
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_check_known_sync, group_id)
|
||||
except Exception as e:
|
||||
logger.error(f"检查群组 {group_id} 是否已知时出错 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def create_group_info(group_id: str, data: Optional[dict] = None):
|
||||
"""创建一个群组信息项"""
|
||||
if not group_id:
|
||||
logger.debug("创建失败,group_id不存在")
|
||||
return
|
||||
|
||||
_group_info_default = copy.deepcopy(group_info_default)
|
||||
model_fields = GroupInfo._meta.fields.keys() # type: ignore
|
||||
|
||||
final_data = {"group_id": group_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _group_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure group_id is correctly set from the argument
|
||||
final_data["group_id"] = group_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
def _db_create_sync(g_data: dict):
|
||||
try:
|
||||
GroupInfo.create(**g_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
|
||||
async def _safe_create_group_info(self, group_id: str, data: Optional[dict] = None):
|
||||
"""安全地创建群组信息,处理竞态条件"""
|
||||
if not group_id:
|
||||
logger.debug("创建失败,group_id不存在")
|
||||
return
|
||||
|
||||
_group_info_default = copy.deepcopy(group_info_default)
|
||||
model_fields = GroupInfo._meta.fields.keys() # type: ignore
|
||||
|
||||
final_data = {"group_id": group_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _group_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure group_id is correctly set from the argument
|
||||
final_data["group_id"] = group_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
def _db_safe_create_sync(g_data: dict):
|
||||
try:
|
||||
# 首先检查是否已存在
|
||||
existing = GroupInfo.get_or_none(GroupInfo.group_id == g_data["group_id"])
|
||||
if existing:
|
||||
logger.debug(f"群组 {g_data['group_id']} 已存在,跳过创建")
|
||||
return True
|
||||
|
||||
# 尝试创建
|
||||
GroupInfo.create(**g_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建群组 {g_data.get('group_id')},跳过错误")
|
||||
return True # 其他协程已创建,视为成功
|
||||
else:
|
||||
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_safe_create_sync, final_data)
|
||||
|
||||
async def update_one_field(self, group_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
if field_name not in GroupInfo._meta.fields: # type: ignore
|
||||
logger.debug(f"更新'{field_name}'失败,未在 GroupInfo Peewee 模型中定义的字段。")
|
||||
return
|
||||
|
||||
processed_value = value
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(value, (list, dict)):
|
||||
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = json.dumps([], ensure_ascii=False, indent=None)
|
||||
|
||||
def _db_update_sync(g_id: str, f_name: str, val_to_set):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
|
||||
query_time = time.time()
|
||||
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
record.save()
|
||||
save_time = time.time()
|
||||
|
||||
total_time = save_time - start_time
|
||||
if total_time > 0.5: # 如果超过500ms就记录日志
|
||||
logger.warning(
|
||||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) group_id={g_id}, field={f_name}"
|
||||
)
|
||||
|
||||
return True, False # Found and updated, no creation needed
|
||||
else:
|
||||
total_time = time.time() - start_time
|
||||
if total_time > 0.5:
|
||||
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 group_id={g_id}, field={f_name}")
|
||||
return False, True # Not found, needs creation
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||||
raise
|
||||
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, group_id, field_name, processed_value)
|
||||
|
||||
if needs_creation:
|
||||
logger.info(f"{group_id} 不存在,将新建。")
|
||||
creation_data = data if data is not None else {}
|
||||
# Ensure platform and group_number are present for context if available from 'data'
|
||||
# but primarily, set the field that triggered the update.
|
||||
# The create_group_info will handle defaults and serialization.
|
||||
creation_data[field_name] = value # Pass original value to create_group_info
|
||||
|
||||
# Ensure platform and group_number are in creation_data if available,
|
||||
# otherwise create_group_info will use defaults.
|
||||
if data and "platform" in data:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "group_number" in data:
|
||||
creation_data["group_number"] = data["group_number"]
|
||||
|
||||
# 使用安全的创建方法,处理竞态条件
|
||||
await self._safe_create_group_info(group_id, creation_data)
|
||||
|
||||
@staticmethod
|
||||
async def del_one_document(group_id: str):
|
||||
"""删除指定 group_id 的文档"""
|
||||
if not group_id:
|
||||
logger.debug("删除失败:group_id 不能为空")
|
||||
return
|
||||
|
||||
def _db_delete_sync(g_id: str):
|
||||
try:
|
||||
query = GroupInfo.delete().where(GroupInfo.group_id == g_id)
|
||||
deleted_count = query.execute()
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(f"删除 GroupInfo {g_id} 失败 (Peewee): {e}")
|
||||
return 0
|
||||
|
||||
deleted_count = await asyncio.to_thread(_db_delete_sync, group_id)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"删除成功:group_id={group_id} (Peewee)")
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 group_id={group_id} 或删除未影响行 (Peewee)")
|
||||
|
||||
@staticmethod
|
||||
async def get_value(group_id: str, field_name: str):
|
||||
"""获取指定群组指定字段的值"""
|
||||
default_value_for_field = group_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
|
||||
def _db_get_value_sync(g_id: str, f_name: str):
|
||||
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
|
||||
if record:
|
||||
val = getattr(record, f_name, None)
|
||||
if f_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {f_name} for {g_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
# If val is already a list/dict (e.g. if somehow set without serialization)
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val
|
||||
return None # Record not found
|
||||
|
||||
try:
|
||||
value_from_db = await asyncio.to_thread(_db_get_value_sync, group_id, field_name)
|
||||
if value_from_db is not None:
|
||||
return value_from_db
|
||||
if field_name in group_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 group_info_default 中未定义,且在数据库中未找到。")
|
||||
return None # Ultimate fallback
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段 {field_name} for {group_id} 时出错 (Peewee): {e}")
|
||||
# Fallback to default in case of any error during DB access
|
||||
return default_value_for_field if field_name in group_info_default else None
|
||||
|
||||
@staticmethod
|
||||
async def get_values(group_id: str, field_names: list) -> dict:
|
||||
"""获取指定group_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not group_id:
|
||||
logger.debug("get_values获取失败:group_id不能为空")
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
|
||||
def _db_get_record_sync(g_id: str):
|
||||
return GroupInfo.get_or_none(GroupInfo.group_id == g_id)
|
||||
|
||||
record = await asyncio.to_thread(_db_get_record_sync, group_id)
|
||||
|
||||
for field_name in field_names:
|
||||
if field_name not in GroupInfo._meta.fields: # type: ignore
|
||||
if field_name in group_info_default:
|
||||
result[field_name] = copy.deepcopy(group_info_default[field_name])
|
||||
logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
|
||||
else:
|
||||
logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
|
||||
result[field_name] = None
|
||||
continue
|
||||
|
||||
if record:
|
||||
value = getattr(record, field_name)
|
||||
if value is not None:
|
||||
result[field_name] = value
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
|
||||
|
||||
return result
|
||||
|
||||
async def add_member(self, group_id: str, member_info: dict):
|
||||
"""添加群成员(使用 last_active_time,不使用 join_time)"""
|
||||
if not group_id or not member_info:
|
||||
logger.debug("添加成员失败:group_id或member_info不能为空")
|
||||
return
|
||||
|
||||
# 规范化成员字段
|
||||
normalized_member = dict(member_info)
|
||||
normalized_member.pop("join_time", None)
|
||||
if "last_active_time" not in normalized_member:
|
||||
normalized_member["last_active_time"] = datetime.datetime.now().timestamp()
|
||||
|
||||
member_id = normalized_member.get("user_id")
|
||||
if not member_id:
|
||||
logger.debug("添加成员失败:缺少 user_id")
|
||||
return
|
||||
|
||||
# 获取当前成员列表
|
||||
current_members = await self.get_value(group_id, "member_list")
|
||||
if not isinstance(current_members, list):
|
||||
current_members = []
|
||||
|
||||
# 移除已存在的同 user_id 成员
|
||||
current_members = [m for m in current_members if m.get("user_id") != member_id]
|
||||
|
||||
# 添加新成员
|
||||
current_members.append(normalized_member)
|
||||
|
||||
# 更新成员列表和成员数量
|
||||
await self.update_one_field(group_id, "member_list", current_members)
|
||||
await self.update_one_field(group_id, "member_count", len(current_members))
|
||||
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
|
||||
|
||||
logger.info(f"群组 {group_id} 添加/更新成员 {normalized_member.get('nickname', member_id)} 成功")
|
||||
|
||||
async def remove_member(self, group_id: str, user_id: str):
|
||||
"""移除群成员"""
|
||||
if not group_id or not user_id:
|
||||
logger.debug("移除成员失败:group_id或user_id不能为空")
|
||||
return
|
||||
|
||||
# 获取当前成员列表
|
||||
current_members = await self.get_value(group_id, "member_list")
|
||||
if not isinstance(current_members, list):
|
||||
logger.debug(f"群组 {group_id} 成员列表为空或格式错误")
|
||||
return
|
||||
|
||||
# 移除指定成员
|
||||
original_count = len(current_members)
|
||||
current_members = [m for m in current_members if m.get("user_id") != user_id]
|
||||
new_count = len(current_members)
|
||||
|
||||
if new_count < original_count:
|
||||
# 更新成员列表和成员数量
|
||||
await self.update_one_field(group_id, "member_list", current_members)
|
||||
await self.update_one_field(group_id, "member_count", new_count)
|
||||
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
|
||||
logger.info(f"群组 {group_id} 移除成员 {user_id} 成功")
|
||||
else:
|
||||
logger.debug(f"群组 {group_id} 中未找到成员 {user_id}")
|
||||
|
||||
async def get_member_list(self, group_id: str) -> List[dict]:
|
||||
"""获取群成员列表"""
|
||||
if not group_id:
|
||||
logger.debug("获取成员列表失败:group_id不能为空")
|
||||
return []
|
||||
|
||||
members = await self.get_value(group_id, "member_list")
|
||||
if isinstance(members, list):
|
||||
return members
|
||||
return []
|
||||
|
||||
    async def get_or_create_group(
        self, platform: str, group_number: int, group_name: str = None
    ) -> str:
        """
        Get the group_id for the given platform and group_number.
        If the group does not exist, create it from the provided information.
        Uses try-except to handle the race condition and avoid duplicate-creation errors.
        """
        group_id = self.get_group_id(platform, group_number)

        def _db_get_or_create_sync(g_id: str, init_data: dict):
            """Atomic get-or-create operation."""
            # First, try to fetch an existing record
            record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
            if record:
                return record, False  # Record already exists; nothing created

            # No record yet; try to create one
            try:
                GroupInfo.create(**init_data)
                return GroupInfo.get(GroupInfo.group_id == g_id), True  # Created successfully
            except Exception as e:
                # If creation failed (likely a race condition), try fetching again
                if "UNIQUE constraint failed" in str(e):
                    logger.debug(f"Detected concurrent creation of group {g_id}; fetching the existing record")
                    record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
                    if record:
                        return record, False  # Another coroutine created it; return the existing record
                # Still failing; re-raise
                raise e

        initial_data = {
            "group_id": group_id,
            "platform": platform,
            "group_number": str(group_number),
            "group_name": group_name,
            "create_time": datetime.datetime.now().timestamp(),
            "last_active": datetime.datetime.now().timestamp(),
            "member_count": 0,
            "member_list": [],
            "group_info": {},
        }

        # Serialize JSON fields
        for key in JSON_SERIALIZED_FIELDS:
            if key in initial_data:
                if isinstance(initial_data[key], (list, dict)):
                    initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
                elif initial_data[key] is None:
                    initial_data[key] = json.dumps([], ensure_ascii=False)

        model_fields = GroupInfo._meta.fields.keys()  # type: ignore
        filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}

        record, was_created = await asyncio.to_thread(_db_get_or_create_sync, group_id, filtered_initial_data)

        if was_created:
            logger.info(
                f"Group {platform}:{group_number} (group_id: {group_id}) did not exist; "
                f"created a new record (Peewee) with initial data (filtered for model): {filtered_initial_data}"
            )
        else:
            logger.debug(f"Group {platform}:{group_number} (group_id: {group_id}) already exists; returning the existing record.")

        return group_id
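
Because creation races are resolved inside _db_get_or_create_sync, concurrent callers should all converge on the same row. A minimal sketch exercising this (the manager instance, platform, and group number are hypothetical):

import asyncio

async def demo_race(manager):
    # Hypothetical concurrent callers racing to create the same group.
    ids = await asyncio.gather(
        *[manager.get_or_create_group("qq", 123456, group_name="test group") for _ in range(5)]
    )
    # All five calls resolve to the same group_id; only one row was inserted.
    assert len(set(ids)) == 1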
    async def get_group_info_by_name(self, group_name: str) -> dict | None:
        """Look up a group by group_name and return its basic info if found."""
        if not group_name:
            logger.debug("get_group_info_by_name failed: group_name must not be empty")
            return None

        found_group_id = None
        for gid, name_in_cache in self.group_name_list.items():
            if name_in_cache == group_name:
                found_group_id = gid
                break

        if not found_group_id:

            def _db_find_by_name_sync(g_name_to_find: str):
                return GroupInfo.get_or_none(GroupInfo.group_name == g_name_to_find)

            record = await asyncio.to_thread(_db_find_by_name_sync, group_name)
            if record:
                found_group_id = record.group_id
                if (
                    found_group_id not in self.group_name_list
                    or self.group_name_list[found_group_id] != group_name
                ):
                    self.group_name_list[found_group_id] = group_name
            else:
                logger.debug(f"No group named '{group_name}' found in the database either (Peewee)")
                return None

        if found_group_id:
            required_fields = [
                "group_id",
                "platform",
                "group_number",
                "group_name",
                "group_impression",
                "short_impression",
                "member_count",
                "create_time",
                "last_active",
            ]
            valid_fields_to_get = [
                f
                for f in required_fields
                if f in GroupInfo._meta.fields or f in group_info_default  # type: ignore
            ]

            group_data = await self.get_values(found_group_id, valid_fields_to_get)

            if group_data:
                return {key: group_data.get(key) for key in required_fields}
            else:
                logger.warning(f"Found group_id '{found_group_id}' but get_values returned nothing (Peewee)")
                return None

        logger.error(f"Logic error: could not determine a group_id for '{group_name}' (Peewee)")
        return None


group_info_manager = None


def get_group_info_manager():
    global group_info_manager
    if group_info_manager is None:
        group_info_manager = GroupInfoManager()
    return group_info_manager

@ -1,183 +0,0 @@
import time
import json
import re
import asyncio
from typing import Any, Optional

from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.chat_message_builder import (
    get_raw_msg_by_timestamp_with_chat_inclusive,
    build_readable_messages,
)
from src.person_info.group_info import get_group_info_manager
from src.plugin_system.apis import message_api
from json_repair import repair_json


logger = get_logger("group_relationship_manager")


class GroupRelationshipManager:
    def __init__(self):
        self.group_llm = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="relationship.group"
        )
        self.last_group_impression_time = 0.0
        self.last_group_impression_message_count = 0

    async def build_relation(self, chat_id: str, platform: str) -> None:
        """Build group relationships, invoked the same way as relationship_builder.build_relation()."""
        current_time = time.time()
        talk_frequency = global_config.chat.get_current_talk_frequency(chat_id)

        # Compute the interval, dynamically adjusted by activity:
        # floored at 10 minutes, and up to 60 minutes at the lowest activity (frequency clamped at 0.5)
        interval_seconds = max(600, int(1800 / max(0.5, talk_frequency)))

        # Count new messages:
        # fetch all new messages first, then filter out the bot's own messages and command messages
        all_new_messages = message_api.get_messages_by_time_in_chat(
            chat_id=chat_id,
            start_time=self.last_group_impression_time,
            end_time=current_time,
            filter_mai=True,
            filter_command=True,
        )
        new_messages_since_last_impression = len(all_new_messages)

        # Trigger condition: time interval OR message-count threshold
        if (current_time - self.last_group_impression_time >= interval_seconds) or \
                (new_messages_since_last_impression >= 100):
            logger.info(f"[{chat_id}] Triggering group impression build (elapsed: {current_time - self.last_group_impression_time:.0f}s, messages: {new_messages_since_last_impression})")

            # Build the group impression asynchronously
            asyncio.create_task(
                self.build_group_impression(
                    chat_id=chat_id,
                    platform=platform,
                    lookback_hours=12,
                    max_messages=300
                )
            )

            self.last_group_impression_time = current_time
            self.last_group_impression_message_count = 0
        else:
            # Update the message count
            self.last_group_impression_message_count = new_messages_since_last_impression
            logger.debug(f"[{chat_id}] Group impression build waiting (time: {current_time - self.last_group_impression_time:.0f}s/{interval_seconds}s, messages: {new_messages_since_last_impression}/100)")

    async def build_group_impression(
        self,
        chat_id: str,
        platform: str,
        lookback_hours: int = 24,
        max_messages: int = 300,
    ) -> Optional[str]:
        """Build a group impression from recent chat history and store it.

        Returns the generated topic.
        """
        now = time.time()
        start_ts = now - lookback_hours * 3600

        # Fetch recent messages (boundaries inclusive)
        messages = get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_ts, now)
        if not messages:
            logger.info(f"[{chat_id}] No recent messages; skipping group impression build")
            return None

        # Cap the count, preferring the newest messages
        messages = sorted(messages, key=lambda m: m.get("time", 0))[-max_messages:]

        # Build readable text
        readable = build_readable_messages(
            messages=messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
        )
        if not readable:
            logger.info(f"[{chat_id}] Readable message text is empty; skipping")
            return None

        # Make sure the group exists
        group_info_manager = get_group_info_manager()
        # NOTE: chat_id (a str) is passed where get_or_create_group expects a group number
        group_id = await group_info_manager.get_or_create_group(platform, chat_id)

        group_name = await group_info_manager.get_value(group_id, "group_name") or chat_id
        alias_str = ", ".join(global_config.bot.alias_names)

        prompt = f"""
Your name is {global_config.bot.nickname}; your aliases are {alias_str}.
You are currently in the group "{group_name}" (platform: {platform}).
Based on the group's recent chat history below, summarize the impression this group gives you.

Requirements:
- Focus on the group's atmosphere (friendly/active/entertainment/study/serious, etc.), common topics, interaction style, active hours or frequency, and any notable culture or memes.
- Use plain, down-to-earth language; avoid exaggerated or flowery wording; keep the tone natural.
- Do not reveal any private personal information.
- Output strictly in JSON format, with nothing else:
{{
    "impression": "a plain, natural long description of the group, no more than 200 characters",
    "topic": "a one-sentence plain-language summary of what the group mainly talks about"
}}

Group chat (excerpt):
{readable}
"""
        # Generate the impression
        content, _ = await self.group_llm.generate_response_async(prompt=prompt)
        raw_text = (content or "").strip()

        def _strip_code_fences(text: str) -> str:
            if text.startswith("```") and text.endswith("```"):
                # Strip the leading and trailing fences
                return re.sub(r"^```[a-zA-Z0-9_\-]*\n|\n```$", "", text, flags=re.S)
            # Otherwise extract the body inside a fenced block, if any
            match = re.search(r"```[a-zA-Z0-9_\-]*\n([\s\S]*?)\n```", text)
            return match.group(1) if match else text

        parsed_text = _strip_code_fences(raw_text)

        long_impression: str = ""
        topic_val: Any = ""

        # Following the relationship module: repair_json first, then loads; tolerate list/dict/str returns
        try:
            fixed = repair_json(parsed_text)
            data = json.loads(fixed) if isinstance(fixed, str) else fixed
            if isinstance(data, list) and data and isinstance(data[0], dict):
                data = data[0]
            if isinstance(data, dict):
                long_impression = str(data.get("impression") or "").strip()
                topic_val = data.get("topic", "")
            else:
                # Not a dict; treat it as plain text
                text_fallback = str(data)
                long_impression = text_fallback[:400].strip()
                topic_val = ""
        except Exception:
            long_impression = parsed_text[:400].strip()
            topic_val = ""

        # Fallback
        if not long_impression and not topic_val:
            logger.info(f"[{chat_id}] The LLM produced no usable group impression; skipping")
            return None

        # Write to the database
        await group_info_manager.update_one_field(group_id, "group_impression", long_impression)
        if topic_val:
            await group_info_manager.update_one_field(group_id, "topic", topic_val)
        await group_info_manager.update_one_field(group_id, "last_active", now)

        logger.info(f"[{chat_id}] Group impression updated: topic={topic_val}")
        return str(topic_val) if topic_val else ""


group_relationship_manager: Optional[GroupRelationshipManager] = None


def get_group_relationship_manager() -> GroupRelationshipManager:
    global group_relationship_manager
    if group_relationship_manager is None:
        group_relationship_manager = GroupRelationshipManager()
    return group_relationship_manager

@ -71,7 +71,7 @@ class Person:
        person_id = get_person_id(platform, user_id)

        if is_person_known(person_id=person_id):
            logger.info(f"User {nickname} already exists")
            logger.debug(f"User {nickname} already exists")
            return Person(person_id=person_id)

        # Create the Person instance

@ -148,9 +148,13 @@ class Person:
        if not is_person_known(person_id=self.person_id):
            self.is_known = False
            logger.warning(f"User {platform}:{user_id}:{person_name}:{person_id} is not yet known")
            logger.debug(f"User {platform}:{user_id}:{person_name}:{person_id} is not yet known")
            self.person_name = f"Unknown user {self.person_id[:4]}"
            return
            # raise ValueError(f"User {platform}:{user_id}:{person_name}:{person_id} is not yet known")



        self.is_known = False

@ -300,15 +300,6 @@ class RelationshipBuilder:
        return cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0

    def force_cleanup_user_segments(self, person_id: str) -> bool:
        """Forcibly clean up all message segments for the given user."""
        if person_id in self.person_engaged_cache:
            segments_count = len(self.person_engaged_cache[person_id])
            del self.person_engaged_cache[person_id]
            self._save_cache()
            logger.info(f"{self.log_prefix} Forcibly cleaned up {segments_count} message segments for user {person_id}")
            return True
        return False

    def get_cache_status(self) -> str:
        # sourcery skip: merge-list-append, merge-list-appends-into-extend

@ -1,5 +1,5 @@
from src.common.logger import get_logger
from .person_info import Person,is_person_known
from .person_info import Person
import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config

@ -272,7 +272,7 @@ class RelationshipManager:
return ""
|
||||
|
||||
attitude_score = attitude_data["attitude"]
|
||||
confidence = attitude_data["confidence"]
|
||||
confidence = pow(attitude_data["confidence"],2)
|
||||
|
||||
new_confidence = total_confidence + confidence
|
||||
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
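
The change squares the per-observation confidence before folding it into the running weighted average, so low-confidence judgments are discounted much more aggressively. A quick numeric check of the update rule (all values made up):

current_attitude_score, total_confidence = 0.6, 4.0  # running state (illustrative)
attitude_score, raw_confidence = 0.2, 0.5            # new observation

confidence = pow(raw_confidence, 2)  # 0.25: squaring shrinks low confidence sharply
new_confidence = total_confidence + confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence
print(round(new_attitude_score, 4))  # 0.5765: barely moved by the uncertain observation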

@ -318,7 +318,7 @@ class RelationshipManager:
return ""
|
||||
|
||||
neuroticism_score = neuroticism_data["neuroticism"]
|
||||
confidence = neuroticism_data["confidence"]
|
||||
confidence = pow(neuroticism_data["confidence"],2)
|
||||
|
||||
new_confidence = total_confidence + confidence

@ -1,5 +1,5 @@
[inner]
version = "6.4.2"
version = "6.4.5"

#---- The following section is for developers; if you have only deployed MaiBot, you do not need to read it ----
# If you want to modify the config file, increment the version value

@ -130,10 +130,7 @@ filtration_prompt = "符合公序良俗" # Meme filtering requirement; only memes that meet it are saved
[memory]
enable_memory = true # Whether to enable the memory system
memory_build_interval = 600 # Memory build interval, in seconds; a lower interval makes the bot learn more, but also accumulates more redundant information
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # Memory build distribution; parameters: mean, std, weight of distribution 1, then mean, std, weight of distribution 2
memory_build_sample_num = 8 # Number of samples; higher values mean memory is sampled more often
memory_build_sample_length = 30 # Sample length; higher values make each memory segment richer
memory_build_frequency = 1 # Memory build frequency; higher values make the bot learn more
memory_compress_rate = 0.1 # Memory compression rate; controls how condensed memories are. Keeping the default is recommended; raising it yields more information but also more redundancy

forget_memory_interval = 3000 # Memory forgetting interval, in seconds; a lower interval makes the bot forget more often, keeping memory leaner but making learning harder
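
The six numbers in memory_build_distribution read as two Gaussian components (mean, std, weight each), presumably mixed to decide how far back samples are drawn. A hedged sketch of sampling such a mixture; this is an interpretation of the config comment, not the project's actual sampler:

import random

def sample_build_offset(dist=(6.0, 3.0, 0.6, 32.0, 12.0, 0.4)) -> float:
    mu1, sigma1, w1, mu2, sigma2, w2 = dist
    # Pick a component by weight, then draw from that Gaussian; clamp at 0.
    mu, sigma = (mu1, sigma1) if random.random() < w1 / (w1 + w2) else (mu2, sigma2)
    return max(0.0, random.gauss(mu, sigma))

print([round(sample_build_offset(), 1) for _ in range(5)])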