mirror of https://github.com/Mai-with-u/MaiBot.git

remove: remove unused code

parent a3390f6cba
commit d519406e4a
@@ -248,8 +248,6 @@ class ExpressionLearner:
                style,
                _context,
                _context_words,
                _full_context,
                _full_context_embedding,
            ) in learnt_expressions:
                learnt_expressions_str += f"{situation}->{style}\n"
|
@@ -263,8 +261,6 @@ class ExpressionLearner:
                style,
                context,
                context_words,
                full_context,
                full_context_embedding,
            ) in learnt_expressions:
                if chat_id not in chat_dict:
                    chat_dict[chat_id] = []
|
@@ -274,8 +270,6 @@ class ExpressionLearner:
                        "style": style,
                        "context": context,
                        "context_words": context_words,
                        "full_context": full_context,
                        "full_context_embedding": full_context_embedding,
                    }
                )
|
@@ -299,8 +293,6 @@ class ExpressionLearner:
                    expr_obj.style = new_expr["style"]
                    expr_obj.context = new_expr["context"]
                    expr_obj.context_words = new_expr["context_words"]
                    expr_obj.full_context = new_expr["full_context"]
                    expr_obj.full_context_embedding = new_expr["full_context_embedding"]
                    expr_obj.count = expr_obj.count + 1
                    expr_obj.last_active_time = current_time
                    expr_obj.save()
|
@@ -315,8 +307,6 @@ class ExpressionLearner:
                        create_date=current_time,  # set the creation date manually
                        context=new_expr["context"],
                        context_words=new_expr["context_words"],
                        full_context=new_expr["full_context"],
                        full_context_embedding=new_expr["full_context_embedding"],
                    )
                    # cap the total number of stored expressions
                    exprs = list(
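Note on the two hunks above: each learnt (situation, style) pair either bumps the counter on an existing Expression row or creates a fresh one with a manual create_date. A minimal, self-contained sketch of that upsert pattern with Peewee; the model below is a hypothetical stand-in with only the fields this flow touches, not the project's real database_model:

import time
from peewee import SqliteDatabase, Model, TextField, FloatField, IntegerField

db = SqliteDatabase(":memory:")

class Expression(Model):
    # hypothetical minimal columns mirroring the fields used in the hunks above
    situation = TextField()
    style = TextField()
    count = IntegerField(default=1)
    create_date = FloatField()
    last_active_time = FloatField()
    chat_id = TextField(index=True)

    class Meta:
        database = db

def upsert_expression(chat_id: str, situation: str, style: str) -> Expression:
    """Bump the counter on an existing row, or create it with a manual create_date."""
    now = time.time()
    expr_obj = Expression.get_or_none(
        (Expression.chat_id == chat_id)
        & (Expression.situation == situation)
        & (Expression.style == style)
    )
    if expr_obj is not None:
        expr_obj.count += 1
        expr_obj.last_active_time = now
        expr_obj.save()
        return expr_obj
    return Expression.create(
        chat_id=chat_id, situation=situation, style=style,
        count=1, create_date=now, last_active_time=now,
    )

db.connect()
db.create_tables([Expression])
print(upsert_expression("chat-1", "被追问细节", "先给结论再补充").count)  # 1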
|
@@ -428,7 +418,7 @@ class ExpressionLearner:
    async def learn_expression(
        self, num: int = 10
    ) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
    ) -> Optional[List[Tuple[str, str, str, List[str]]]]:
        """Learn expressions from the specified chat stream

        Args:
|
@@ -482,19 +472,14 @@ class ExpressionLearner:
        )

        split_matched_expressions_w_emb = []
        full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)

        for situation, style, context, context_words in split_matched_expressions:
            split_matched_expressions_w_emb.append(
                (self.chat_id, situation, style, context, context_words, random_msg_match_str, full_context_embedding)
                (self.chat_id, situation, style, context, context_words)
            )

        return split_matched_expressions_w_emb

    async def get_full_context_embedding(self, context: str) -> List[float]:
        embedding, _ = await self.embedding_model.get_embedding(context)
        return embedding

    def split_expression_context(
        self, matched_expressions: List[Tuple[str, str, str]]
    ) -> List[Tuple[str, str, str, List[str]]]:
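The full_context_embedding values dropped in this hunk came from get_full_context_embedding and were meant to be compared later. A small sketch of that comparison using plain cosine similarity; the embedding call itself is assumed and not shown:

import math
from typing import List, Tuple

def cosine_similarity(v1: List[float], v2: List[float]) -> float:
    """Plain cosine similarity; returns 0.0 for zero-norm vectors."""
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return dot / (norm1 * norm2)

def rank_by_context(query_embedding: List[float],
                    stored: List[Tuple[str, List[float]]]) -> List[Tuple[str, float]]:
    """Rank stored (full_context, embedding) pairs against a query embedding."""
    scored = [(ctx, cosine_similarity(query_embedding, emb)) for ctx, emb in stored]
    return sorted(scored, key=lambda x: x[1], reverse=True)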
|
@@ -7,21 +7,17 @@ import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any
from typing import List, Tuple, Set
from collections import Counter
import traceback

from rich.traceback import install

from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import model_config
from src.common.database.database_model import GraphNodes, GraphEdges  # Peewee model imports
from src.common.logger import get_logger
from src.chat.utils.utils import cut_key_words
from src.chat.utils.chat_message_builder import (
    build_readable_messages,
)  # import build_readable_messages


# cosine_similarity helper
|
@@ -86,29 +82,10 @@ class MemoryGraph:
        if "memory_items" in self.G.nodes[concept]:
            # fetch the existing memory item (already stored as a str)
            existing_memory = self.G.nodes[concept]["memory_items"]

            # if the existing memory is non-empty, use the LLM to merge old and new memories
            if existing_memory and hippocampus_instance and hippocampus_instance.model_small:
                try:
                    integrated_memory = await self._integrate_memories_with_llm(
                        existing_memory, str(memory), hippocampus_instance.model_small
                    )
                    self.G.nodes[concept]["memory_items"] = integrated_memory
                    # integration succeeded, increase the node weight
                    current_weight = self.G.nodes[concept].get("weight", 0.0)
                    self.G.nodes[concept]["weight"] = current_weight + 1.0
                    logger.debug(f"节点 {concept} 记忆整合成功,权重增加到 {current_weight + 1.0}")
                    logger.info(f"节点 {concept} 记忆内容已更新:{integrated_memory}")
                except Exception as e:
                    logger.error(f"LLM整合记忆失败: {e}")
                    # fall back to simple concatenation
                    new_memory_str = f"{existing_memory} | {memory}"
                    self.G.nodes[concept]["memory_items"] = new_memory_str
                    logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}")
            else:
                new_memory_str = str(memory)
                self.G.nodes[concept]["memory_items"] = new_memory_str
                logger.info(f"节点 {concept} 记忆内容已直接更新:{new_memory_str}")
            # simply concatenate the old and new memories
            new_memory_str = f"{existing_memory} | {memory}"
            self.G.nodes[concept]["memory_items"] = new_memory_str
            logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}")
        else:
            self.G.nodes[concept]["memory_items"] = str(memory)
            # if the node exists but has no memory_items yet, this is the first memory added, so set created_time
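For reference, the concatenation path kept by this hunk can be exercised on its own. A minimal sketch against a bare networkx graph, using the same node attribute name as above:

import networkx as nx

def append_memory(graph: nx.Graph, concept: str, memory: str) -> str:
    """Simple-concatenation update kept by this commit: join old and new memory with ' | '."""
    existing = graph.nodes[concept].get("memory_items", "") if concept in graph else ""
    new_value = f"{existing} | {memory}" if existing else str(memory)
    graph.add_node(concept)  # no-op when the node already exists
    graph.nodes[concept]["memory_items"] = new_value
    return new_value

g = nx.Graph()
g.add_node("dev-chat", memory_items="deployment issue under discussion")
print(append_memory(g, "dev-chat", "the problem was traced to the config file"))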
|
@@ -164,53 +141,6 @@ class MemoryGraph:

        return first_layer_items, second_layer_items

    async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str:
        """
        Merge old and new memory content with the LLM

        Args:
            existing_memory: the current memory content (string form, may contain several memories)
            new_memory: the new memory content
            llm_model: the LLM model instance

        Returns:
            str: the merged memory content
        """
        try:
            # build the integration prompt
            integration_prompt = f"""你是一个记忆整合专家。请将以下的旧记忆和新记忆整合成一条更完整、更准确的记忆内容。

旧记忆内容:
{existing_memory}

新记忆内容:
{new_memory}

整合要求:
1. 保留重要信息,去除重复内容
2. 如果新旧记忆有冲突,合理整合矛盾的地方
3. 将相关信息合并,形成更完整的描述
4. 保持语言简洁、准确
5. 只返回整合后的记忆内容,不要添加任何解释

整合后的记忆:"""

            # call the LLM to do the merge
            content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
                integration_prompt
            )

            if content and content.strip():
                integrated_content = content.strip()
                logger.debug(f"LLM记忆整合成功,模型: {model_name}")
                return integrated_content
            else:
                logger.warning("LLM返回的整合结果为空,使用默认连接方式")
                return f"{existing_memory} | {new_memory}"

        except Exception as e:
            logger.error(f"LLM记忆整合过程中出错: {e}")
            return f"{existing_memory} | {new_memory}"

    @property
    def dots(self):
|
@@ -242,7 +172,6 @@ class MemoryGraph:
class Hippocampus:
    def __init__(self):
        self.memory_graph = MemoryGraph()
        self.model_small: LLMRequest = None  # type: ignore
        self.entorhinal_cortex: EntorhinalCortex = None  # type: ignore
        self.parahippocampal_gyrus: ParahippocampalGyrus = None  # type: ignore
|
@@ -252,44 +181,11 @@ class Hippocampus:
        self.parahippocampal_gyrus = ParahippocampalGyrus(self)
        # load the memory graph from the database
        self.entorhinal_cortex.sync_memory_from_db()
        self.model_small = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
        )

    def get_all_node_names(self) -> list:
        """Return the names of all nodes in the memory graph"""
        return list(self.memory_graph.G.nodes())

    def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float:
        """
        Compute an activation value that takes the node weight into account

        Args:
            current_activation: the current activation value
            edge_strength: the strength of the edge
            target_node: the name of the target node

        Returns:
            float: the resulting activation value
        """
        # base activation
        base_activation = current_activation - (1 / edge_strength)

        if base_activation <= 0:
            return 0.0

        # look up the target node's weight
        if target_node in self.memory_graph.G:
            node_data = self.memory_graph.G.nodes[target_node]
            node_weight = node_data.get("weight", 1.0)

            # weight bonus: each integration adds 10% activation, capped at a 200% bonus
            weight_multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)

            return base_activation * weight_multiplier
        else:
            return base_activation

    @staticmethod
    def calculate_node_hash(concept, memory_items) -> int:
        """Compute the node's feature hash"""
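The weight bonus in calculate_weighted_activation is plain arithmetic, so it can be checked in isolation. A standalone sketch of the same rule with the graph lookup left out:

def weighted_activation(current_activation: float, edge_strength: int, node_weight: float = 1.0) -> float:
    """Standalone version of the rule above: activation decays by 1/edge_strength,
    then gets a weight bonus of +10% per integration, capped at +200%."""
    base = current_activation - (1 / edge_strength)
    if base <= 0:
        return 0.0
    multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)
    return base * multiplier

# e.g. activation 1.0 over an edge of strength 5 into a node of weight 3.0:
# base = 0.8, multiplier = 1.2, result is about 0.96
print(weighted_activation(1.0, 5, 3.0))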
|
@@ -309,45 +205,8 @@ class Hippocampus:
        # use a tuple directly to keep the ordering consistent
        return hash((source, target))

    @staticmethod
    def find_topic_llm(text: str, topic_num: int | list[int]):
        # sourcery skip: inline-immediately-returned-variable
        topic_num_str = ""
        if isinstance(topic_num, list):
            topic_num_str = f"{topic_num[0]}-{topic_num[1]}"
        else:
            topic_num_str = topic_num

        prompt = (
            f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num_str}个关键的概念,必须是某种概念,比如人,事,物,概念,事件,地点 等等,帮我列出来,"
            f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
            f"如果确定找不出主题或者没有明显主题,返回<none>。"
        )

        return prompt

    @staticmethod
    def topic_what(text, topic):
        # sourcery skip: inline-immediately-returned-variable
        # the time_info parameter is no longer needed
        prompt = (
            f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成几句自然的话,'
            f"要求包含对这个概念的定义,内容,知识,时间和人物,这些信息必须来自这段文字,不能添加信息。\n只输出几句自然的话就好"
        )
        return prompt

    @staticmethod
    def calculate_topic_num(text, compress_rate):
        """Estimate how many topics the text should yield"""
        information_content = calculate_information_content(text)
        topic_by_length = text.count("\n") * compress_rate
        topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
        topic_num = int((topic_by_length + topic_by_information_content) / 2)
        logger.debug(
            f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
            f"topic_num: {topic_num}"
        )
        return topic_num

    def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
        """Fetch related memories for a keyword.
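calculate_topic_num averages a length-based estimate with an information-content estimate. A sketch of that averaging; calculate_information_content itself is not shown in this diff, so the character-entropy version below is only an assumed stand-in:

import math
from collections import Counter

def calculate_information_content(text: str) -> float:
    """Shannon entropy of the character distribution, standing in for the helper
    referenced by calculate_topic_num (the project's own version is not part of this diff)."""
    if not text:
        return 0.0
    counts = Counter(text)
    total = len(text)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

def topic_num(text: str, compress_rate: float = 0.1) -> int:
    """Same averaging rule as calculate_topic_num above."""
    information_content = calculate_information_content(text)
    topic_by_length = text.count("\n") * compress_rate
    topic_by_information = max(1, min(5, int((information_content - 3) * 2)))
    return int((topic_by_length + topic_by_information) / 2)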
|
@@ -1039,395 +898,16 @@ class EntorhinalCortex:
        )


# responsible for integrating, forgetting, and merging memories
# responsible for memory management
class ParahippocampalGyrus:
    def __init__(self, hippocampus: Hippocampus):
        self.hippocampus = hippocampus
        self.memory_graph = hippocampus.memory_graph

        self.memory_modify_model = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="memory.modify"
        )

    async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
        """Compress and summarize message content, producing memory topics and summaries.

        Args:
            messages (list): the message list; each message is a dict with the database message structure.
            compress_rate (float, optional): compression rate controlling how many topics are produced. Defaults to 0.1.

        Returns:
            tuple: (compressed_memory, similar_topics_dict)
            - compressed_memory: set, the compressed memories; each element is a tuple (topic, summary)
            - similar_topics_dict: dict, mapping of similar topics

        Process:
            1. Use build_readable_messages to produce formatted text with time and speaker information.
            2. Use the LLM to extract key topics.
            3. Filter out topics that contain banned keywords.
            4. Generate a summary for each topic.
            5. Look up similar topics among the existing memories.
        """
        if not messages:
            return set(), {}

        # 1. use build_readable_messages to produce the formatted text
        # build_readable_messages returns a single string, no unpacking needed
        input_text = build_readable_messages(
            messages,
            timestamp_mode="normal_no_YMD",  # use the 'YYYY-MM-DD HH:MM:SS' format
            replace_bot_name=False,  # keep the original user names
        )

        # if the readable text is empty (e.g. every message was invalid), return immediately
        if not input_text:
            logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
            return set(), {}

        current_date = f"当前日期: {datetime.datetime.now().isoformat()}"
        input_text = f"{current_date}\n{input_text}"

        logger.debug(f"记忆来源:\n{input_text}")

        # 2. extract the key topics with the LLM
        topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
        topics_response, _ = await self.memory_modify_model.generate_response_async(
            self.hippocampus.find_topic_llm(input_text, topic_num)
        )

        # extract the content inside <>
        topics = re.findall(r"<([^>]+)>", topics_response)

        if not topics:
            topics = ["none"]
        else:
            topics = [
                topic.strip()
                for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
                if topic.strip()
            ]

        # 3. filter out topics containing banned keywords
        filtered_topics = [
            topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words)
        ]

        logger.debug(f"过滤后话题: {filtered_topics}")

        # 4. create the summary-generation tasks for all topics
        tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = []
        topic_what_prompt: str = ""
        for topic in filtered_topics:
            # call the updated topic_what, which no longer needs time_info
            topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
            try:
                task = self.memory_modify_model.generate_response_async(topic_what_prompt)
                tasks.append((topic.strip(), task))
            except Exception as e:
                logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}")
                continue

        # wait for all tasks to finish
        compressed_memory: Set[Tuple[str, str]] = set()
        similar_topics_dict = {}

        for topic, task in tasks:
            response = await task
            if response:
                compressed_memory.add((topic, response[0]))

                existing_topics = list(self.memory_graph.G.nodes())
                similar_topics = []

                for existing_topic in existing_topics:
                    topic_words = set(jieba.cut(topic))
                    existing_words = set(jieba.cut(existing_topic))

                    all_words = topic_words | existing_words
                    v1 = [1 if word in topic_words else 0 for word in all_words]
                    v2 = [1 if word in existing_words else 0 for word in all_words]

                    similarity = cosine_similarity(v1, v2)

                    if similarity >= 0.7:
                        similar_topics.append((existing_topic, similarity))

                similar_topics.sort(key=lambda x: x[1], reverse=True)
                similar_topics = similar_topics[:3]
                similar_topics_dict[topic] = similar_topics

        if global_config.debug.show_prompt:
            logger.info(f"prompt: {topic_what_prompt}")
            logger.info(f"压缩后的记忆: {compressed_memory}")
            logger.info(f"相似主题: {similar_topics_dict}")

        return compressed_memory, similar_topics_dict
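memory_compress compares topics by building 0/1 word vectors from jieba tokens and taking their cosine similarity against a 0.7 threshold. For binary vectors that computation reduces to a set formula, sketched below with illustrative tokens:

import math

def word_set_similarity(a: set[str], b: set[str]) -> float:
    """For the 0/1 vectors built above, cosine similarity reduces to
    |A & B| / sqrt(|A| * |B|) over the two word sets."""
    if not a or not b:
        return 0.0
    return len(a & b) / math.sqrt(len(a) * len(b))

# tokens such as jieba.cut("部署问题") vs an existing topic "部署流程" might share one word:
print(word_set_similarity({"部署", "问题"}, {"部署", "流程"}))  # 0.5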
    def get_similar_topics_from_keywords(
        self,
        keywords: list[str] | str,
        top_k: int = 3,
        threshold: float = 0.7,
    ) -> dict[str, list[tuple[str, float]]]:
        """Given the input keywords, return the list of similar topics for each keyword.

        Args:
            keywords: a list of keywords, or a string separated by commas, spaces, or 、.
            top_k: the maximum number of similar topics returned per keyword.
            threshold: the similarity threshold; topics below it are filtered out.

        Returns:
            dict[str, list[tuple[str, float]]]: {keyword: [(topic, similarity), ...]}
        """
        # normalize the input into a list[str]
        if isinstance(keywords, str):
            # accept Chinese/English commas, 、 and spaces as separators
            parts = keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
            keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
        else:
            keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]

        if not keyword_list:
            return {}

        existing_topics = list(self.memory_graph.G.nodes())
        result: dict[str, list[tuple[str, float]]] = {}

        for kw in keyword_list:
            kw_words = set(jieba.cut(kw))
            similar_topics: list[tuple[str, float]] = []

            for topic in existing_topics:
                topic_words = set(jieba.cut(topic))
                all_words = kw_words | topic_words
                if not all_words:
                    continue
                v1 = [1 if w in kw_words else 0 for w in all_words]
                v2 = [1 if w in topic_words else 0 for w in all_words]
                sim = cosine_similarity(v1, v2)
                if sim >= threshold:
                    similar_topics.append((topic, sim))

            similar_topics.sort(key=lambda x: x[1], reverse=True)
            result[kw] = similar_topics[:top_k]

        return result
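The keyword normalization at the top of get_similar_topics_from_keywords is worth seeing on its own. A sketch of that step as a standalone helper with the same replace/split behaviour:

def normalize_keywords(keywords) -> list[str]:
    """Accept a keyword list, or a string separated by commas, spaces, or 、."""
    if isinstance(keywords, str):
        parts = keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
        return [p.strip() for p in parts.split(",") if p.strip()]
    return [k.strip() for k in keywords if isinstance(k, str) and k.strip()]

print(normalize_keywords("部署, 配置文件、日志"))  # ['部署', '配置文件', '日志']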
    async def add_memory_with_similar(
        self,
        memory_item: str,
        similar_topics_dict: dict[str, list[tuple[str, float]]],
    ) -> bool:
        """Write a single memory item and its similar topics into the memory network and sync the database.

        Follows the build_memory_for_chat approach: every key of similar_topics_dict becomes a topic node
        holding this content, and edges to its similar topics are created with strength int(similarity * 10).

        Args:
            memory_item: the memory content string, stored as memory_items on every topic node.
            similar_topics_dict: {topic: [(similar_topic, similarity), ...]}

        Returns:
            bool: whether the add and sync completed successfully.
        """
        try:
            if not memory_item or not isinstance(memory_item, str):
                return False

            if not similar_topics_dict or not isinstance(similar_topics_dict, dict):
                return False

            current_time = time.time()

            # write a node for every topic
            for topic, similar_list in similar_topics_dict.items():
                if not topic or not isinstance(topic, str):
                    continue

                await self.hippocampus.memory_graph.add_dot(topic, memory_item, self.hippocampus)

                # connect the similar topics
                if isinstance(similar_list, list):
                    for item in similar_list:
                        try:
                            similar_topic, similarity = item
                        except Exception:
                            continue
                        if not isinstance(similar_topic, str):
                            continue
                        if topic == similar_topic:
                            continue
                        # strength follows the build_memory_for_chat rule
                        strength = int(max(0.0, float(similarity)) * 10) if similarity is not None else 0
                        if strength <= 0:
                            continue
                        # make sure the similar-topic node exists (networkx would create it through the edge,
                        # but its attributes need initializing)
                        if similar_topic not in self.memory_graph.G:
                            # create an empty similar-topic node to avoid dangling edges; memory_items stays empty
                            self.memory_graph.G.add_node(
                                similar_topic,
                                memory_items="",
                                weight=1.0,
                                created_time=current_time,
                                last_modified=current_time,
                            )
                        self.memory_graph.G.add_edge(
                            topic,
                            similar_topic,
                            strength=strength,
                            created_time=current_time,
                            last_modified=current_time,
                        )

            # sync to the database
            await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
            return True
        except Exception as e:
            logger.error(f"添加记忆节点失败: {e}")
            return False
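add_memory_with_similar converts a similarity score into an integer edge strength and skips anything that rounds down to zero. A tiny checked sketch of that mapping:

def edge_strength(similarity) -> int:
    """Connection strength rule used above: int(similarity * 10), clamped at 0 for
    missing or negative similarities; strengths of 0 are skipped entirely."""
    return int(max(0.0, float(similarity)) * 10) if similarity is not None else 0

assert edge_strength(0.73) == 7
assert edge_strength(None) == 0
assert edge_strength(-0.2) == 0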
    async def operation_forget_topic(self, percentage=0.005):
        start_time = time.time()
        logger.info("[遗忘] 开始检查数据库...")

        # validate the percentage argument
        if not 0 <= percentage <= 1:
            logger.warning(f"[遗忘] 无效的遗忘百分比: {percentage}, 使用默认值 0.005")
            percentage = 0.005

        all_nodes = list(self.memory_graph.G.nodes())
        all_edges = list(self.memory_graph.G.edges())

        if not all_nodes and not all_edges:
            logger.info("[遗忘] 记忆图为空,无需进行遗忘操作")
            return

        # check at least one node and one edge, and never more than the totals
        check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
        check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))

        # only sample when there are enough nodes and edges
        if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
            try:
                nodes_to_check = random.sample(all_nodes, check_nodes_count)
                edges_to_check = random.sample(all_edges, check_edges_count)
            except ValueError as e:
                logger.error(f"[遗忘] 采样错误: {str(e)}")
                return
        else:
            logger.info("[遗忘] 没有足够的节点或边进行遗忘操作")
            return

        # track the changes in lists
        edge_changes = {
            "weakened": [],  # weakened edges
            "removed": [],  # removed edges
        }
        node_changes = {
            "reduced": [],  # nodes whose memories were reduced
            "removed": [],  # removed nodes
        }

        current_time = datetime.datetime.now().timestamp()

        logger.info("[遗忘] 开始检查连接...")
        edge_check_start = time.time()
        for source, target in edges_to_check:
            edge_data = self.memory_graph.G[source][target]
            last_modified = edge_data.get("last_modified")

            if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
                current_strength = edge_data.get("strength", 1)
                new_strength = current_strength - 1

                if new_strength <= 0:
                    self.memory_graph.G.remove_edge(source, target)
                    edge_changes["removed"].append(f"{source} -> {target}")
                else:
                    edge_data["strength"] = new_strength
                    edge_data["last_modified"] = current_time
                    edge_changes["weakened"].append(f"{source}-{target} (强度: {current_strength} -> {new_strength})")
        edge_check_end = time.time()
        logger.info(f"[遗忘] 连接检查耗时: {edge_check_end - edge_check_start:.2f}秒")

        logger.info("[遗忘] 开始检查节点...")
        node_check_start = time.time()
        for node in nodes_to_check:
            # make sure the node still exists, in case it was removed during iteration (e.g. by an edge removal)
            if node not in self.memory_graph.G:
                continue

            node_data = self.memory_graph.G.nodes[node]

            # fetch the memory items first
            memory_items = node_data.get("memory_items", "")
            # check directly whether the memory content is empty
            if not memory_items or memory_items.strip() == "":
                try:
                    self.memory_graph.G.remove_node(node)
                    node_changes["removed"].append(f"{node}(空节点)")  # removed because the node was empty
                    logger.debug(f"[遗忘] 移除了空的节点: {node}")
                except nx.NetworkXError as e:
                    logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
                continue  # move on to the next node

            # --- if the node is not empty, run the original inactivity check and removal logic ---
            last_modified = node_data.get("last_modified", current_time)
            node_weight = node_data.get("weight", 1.0)

            # condition 1: check whether it has gone unmodified for a long time (uses the configured forget time)
            time_threshold = 3600 * global_config.memory.memory_forget_time

            # scale the forget threshold by the weight: the higher the weight, the longer it takes to forget
            # weight 1 uses the default threshold; higher weights raise it (harder to forget)
            adjusted_threshold = time_threshold * node_weight

            if current_time - last_modified > adjusted_threshold and memory_items:
                # since each node now holds one complete memory, remove the whole node directly
                try:
                    self.memory_graph.G.remove_node(node)
                    node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
                    logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
                except nx.NetworkXError as e:
                    logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
                continue
        node_check_end = time.time()
        logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒")

        if any(edge_changes.values()) or any(node_changes.values()):
            sync_start = time.time()

            await self.hippocampus.entorhinal_cortex.resync_memory_to_db()

            sync_end = time.time()
            logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒")

            # report all changes
            logger.info("[遗忘] 遗忘操作统计:")
            if edge_changes["weakened"]:
                logger.info(
                    f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
                )

            if edge_changes["removed"]:
                logger.info(
                    f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
                )

            if node_changes["reduced"]:
                logger.info(
                    f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
                )

            if node_changes["removed"]:
                logger.info(
                    f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
                )
        else:
            logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")

        end_time = time.time()
        logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")


class HippocampusManager:
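operation_forget_topic scales the configured forget time by the node weight before deciding whether an inactive node is removed. A standalone sketch of that check; forget_time_hours stands in for global_config.memory.memory_forget_time:

def should_forget(age_seconds: float, forget_time_hours: float, node_weight: float = 1.0) -> bool:
    """Forgetting rule from the node loop above: a node becomes eligible for removal once it has
    been unmodified for longer than the configured forget time multiplied by its weight."""
    adjusted_threshold = 3600 * forget_time_hours * node_weight
    return age_seconds > adjusted_threshold

# a node with weight 3.0 and a 24h forget time survives until roughly 72h of inactivity
print(should_forget(3600 * 48, 24, 3.0))  # False
print(should_forget(3600 * 80, 24, 3.0))  # True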
|
@@ -1462,11 +942,6 @@ class HippocampusManager:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return self._hippocampus

    async def forget_memory(self, percentage: float = 0.005):
        """Public entry point for forgetting memories"""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)

    async def get_memory_from_topic(
        self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
@@ -32,7 +32,7 @@ class MemoryChest:
            request_type="memory_chest_build",
        )

        self.memory_build_threshold = 30
        self.memory_build_threshold = 20
        self.memory_size_limit = global_config.memory.max_memory_size

        self.running_content_list = {}  # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
|
@@ -154,10 +154,13 @@ class MemoryChest:


    def get_all_titles(self) -> list[str]:
    def get_all_titles(self, exclude_locked: bool = False) -> list[str]:
        """
        Get every title in the memory chest

        Args:
            exclude_locked: whether to exclude locked memories; defaults to False

        Returns:
            list: a list containing all titles
        """
|
@@ -166,6 +169,9 @@ class MemoryChest:
            titles = []
            for memory in MemoryChestModel.select():
                if memory.title:
                    # skip when exclude_locked is True and the memory is locked
                    if exclude_locked and memory.locked:
                        continue
                    titles.append(memory.title)
            return titles
        except Exception as e:
|
@@ -261,8 +267,8 @@ class MemoryChest:
        Returns:
            str: the chosen title
        """
        # fetch all titles and build a formatted string
        titles = self.get_all_titles()
        # fetch all titles and build a formatted string (excluding locked memories)
        titles = self.get_all_titles(exclude_locked=True)
        formatted_titles = ""
        for title in titles:
            formatted_titles += f"{title}\n"
|
@@ -375,15 +381,15 @@ class MemoryChest:
    async def choose_merge_target(self, memory_title: str) -> list[str]:
        """
        Select memory targets related to the given memory title


        Args:
            memory_title: the memory title to match against


        Returns:
            list[str]: the selected memory contents
        """
        try:
            all_titles = self.get_all_titles()
            all_titles = self.get_all_titles(exclude_locked=True)
            content = ""
            for title in all_titles:
                content += f"{title}\n"
|
|
@@ -430,10 +436,10 @@ class MemoryChest:
    def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
        """
        Look up the memory contents for a list of titles


        Args:
            titles: the list of memory titles


        Returns:
            list[str]: the list of memory contents
        """
|
@@ -442,22 +448,32 @@ class MemoryChest:
            for title in titles:
                if not title or not title.strip():
                    continue

                # use fuzzy matching to find the memory
                try:
                    best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
                    if best_match:
                        contents.append(best_match[1])  # best_match[1] is the content
                        logger.debug(f"找到记忆: {best_match[0]} (相似度: {best_match[2]:.3f})")
                        # check whether the memory is locked
                        memory_title = best_match[0]
                        memory_content = best_match[1]

                        # look up the locked flag in the database
                        for memory in MemoryChestModel.select():
                            if memory.title == memory_title and memory.locked:
                                logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
                                continue

                        contents.append(memory_content)
                        logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
                    else:
                        logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
                except Exception as e:
                    logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
                    continue

            logger.info(f"成功找到 {len(contents)} 条记忆内容")
            return contents

        except Exception as e:
            logger.error(f"根据标题查找记忆时出错: {e}")
            return []
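The new lookup above checks the locked flag before appending a matched memory. A reduced sketch of one way to express that skip, with the database query replaced by a plain set of locked titles (an assumption for illustration, not the project's API):

def filter_unlocked(matches: list[tuple[str, str]], locked_titles: set[str]) -> list[str]:
    """Collect contents whose titles are not locked, skipping locked matches
    before anything is appended."""
    contents: list[str] = []
    for title, content in matches:
        if title in locked_titles:
            continue  # locked memories never reach the merge
        contents.append(content)
    return contents

print(filter_unlocked([("部署笔记", "content A"), ("群友列表", "content B")], {"部署笔记"}))  # ['content B']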
|
@@ -306,8 +306,6 @@ class Expression(BaseModel):
    # new mode fields
    context = TextField(null=True)
    context_words = TextField(null=True)
    full_context = TextField(null=True)
    full_context_embedding = TextField(null=True)

    last_active_time = FloatField()
    chat_id = TextField(index=True)
|
@@ -324,6 +322,7 @@ class MemoryChest(BaseModel):

    title = TextField()  # title
    content = TextField()  # content
    locked = BooleanField(default=False)  # whether the memory is locked

    class Meta:
        table_name = "memory_chest"
|
@ -6,211 +6,6 @@ from src.common.logger import get_logger
|
|||
|
||||
logger = get_logger("migrate")
|
||||
|
||||
|
||||
async def migrate_memory_items_to_string():
|
||||
"""
|
||||
将数据库中记忆节点的memory_items从list格式迁移到string格式
|
||||
并根据原始list的项目数量设置weight值
|
||||
"""
|
||||
logger.info("开始迁移记忆节点格式...")
|
||||
|
||||
migration_stats = {
|
||||
"total_nodes": 0,
|
||||
"converted_nodes": 0,
|
||||
"already_string_nodes": 0,
|
||||
"empty_nodes": 0,
|
||||
"error_nodes": 0,
|
||||
"weight_updated_nodes": 0,
|
||||
"truncated_nodes": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# 获取所有图节点
|
||||
all_nodes = GraphNodes.select()
|
||||
migration_stats["total_nodes"] = all_nodes.count()
|
||||
|
||||
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
|
||||
|
||||
for node in all_nodes:
|
||||
try:
|
||||
concept = node.concept
|
||||
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
|
||||
original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
|
||||
|
||||
# 如果为空,跳过
|
||||
if not memory_items_raw:
|
||||
migration_stats["empty_nodes"] += 1
|
||||
logger.debug(f"跳过空节点: {concept}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析JSON
|
||||
parsed_data = json.loads(memory_items_raw)
|
||||
|
||||
if isinstance(parsed_data, list):
|
||||
# 如果是list格式,需要转换
|
||||
if parsed_data:
|
||||
# 转换为字符串格式
|
||||
new_memory_items = " | ".join(str(item) for item in parsed_data)
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
new_weight = float(len(parsed_data)) # weight = list项目数量
|
||||
|
||||
# 更新数据库
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = new_weight
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
migration_stats["weight_updated_nodes"] += 1
|
||||
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.info(
|
||||
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
|
||||
)
|
||||
else:
|
||||
# 空list,设置为空字符串
|
||||
node.memory_items = ""
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
logger.debug(f"转换空list节点: {concept}")
|
||||
|
||||
elif isinstance(parsed_data, str):
|
||||
# 已经是字符串格式,检查长度和weight
|
||||
current_content = parsed_data
|
||||
original_length = len(current_content)
|
||||
content_truncated = False
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
content_truncated = True
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
node.memory_items = current_content
|
||||
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
# 检查weight是否需要更新
|
||||
update_needed = False
|
||||
if original_weight == 1.0:
|
||||
# 如果weight还是默认值,可以根据内容复杂度估算
|
||||
content_parts = (
|
||||
current_content.split(" | ") if " | " in current_content else [current_content]
|
||||
)
|
||||
estimated_weight = max(1.0, float(len(content_parts)))
|
||||
|
||||
if estimated_weight != original_weight:
|
||||
node.weight = estimated_weight
|
||||
update_needed = True
|
||||
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
|
||||
|
||||
# 如果内容被截断或权重需要更新,保存到数据库
|
||||
if content_truncated or update_needed:
|
||||
node.save()
|
||||
if update_needed:
|
||||
migration_stats["weight_updated_nodes"] += 1
|
||||
if content_truncated:
|
||||
migration_stats["converted_nodes"] += 1 # 算作转换节点
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
|
||||
else:
|
||||
# 其他JSON类型,转换为字符串
|
||||
new_memory_items = str(parsed_data) if parsed_data else ""
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"转换其他类型节点: {concept}{length_info}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 不是JSON格式,假设已经是纯字符串
|
||||
# 检查是否是带引号的字符串
|
||||
if memory_items_raw.startswith('"') and memory_items_raw.endswith('"'):
|
||||
# 去掉引号
|
||||
clean_content = memory_items_raw[1:-1]
|
||||
original_length = len(clean_content)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(clean_content) > 100:
|
||||
clean_content = clean_content[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
node.memory_items = clean_content
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"去除引号节点: {concept}{length_info}")
|
||||
else:
|
||||
# 已经是纯字符串格式,检查长度
|
||||
current_content = memory_items_raw
|
||||
original_length = len(current_content)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
node.memory_items = current_content
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1 # 算作转换节点
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
logger.debug(f"已是字符串格式节点: {concept}")
|
||||
|
||||
except Exception as e:
|
||||
migration_stats["error_nodes"] += 1
|
||||
logger.error(f"处理节点 {concept} 时发生错误: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生严重错误: {e}")
|
||||
raise
|
||||
|
||||
# 输出迁移统计
|
||||
logger.info("=== 记忆节点迁移完成 ===")
|
||||
logger.info(f"总节点数: {migration_stats['total_nodes']}")
|
||||
logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
|
||||
logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
|
||||
logger.info(f"空节点: {migration_stats['empty_nodes']}")
|
||||
logger.info(f"错误节点: {migration_stats['error_nodes']}")
|
||||
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
|
||||
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
|
||||
|
||||
success_rate = (
|
||||
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
|
||||
/ migration_stats["total_nodes"]
|
||||
* 100
|
||||
if migration_stats["total_nodes"] > 0
|
||||
else 0
|
||||
)
|
||||
logger.info(f"迁移成功率: {success_rate:.1f}%")
|
||||
|
||||
return migration_stats
|
||||
|
||||
|
||||
async def set_all_person_known():
|
||||
"""
|
||||
将person_info库中所有记录的is_known字段设置为True
|
||||
|
|
@@ -312,7 +107,6 @@ async def check_and_run_migrations():
        # run the migration functions
        # run the two async functions in sequence
        await asyncio.sleep(3)
        await migrate_memory_items_to_string()
        await set_all_person_known()
        # create the done.mem marker file
        with open(done_file, "w", encoding="utf-8") as f: