remove: remove unused code

pull/1273/head
SengokuCola 2025-09-29 21:24:23 +08:00
parent a3390f6cba
commit d519406e4a
5 changed files with 40 additions and 771 deletions

View File

@@ -248,8 +248,6 @@ class ExpressionLearner:
style,
_context,
_context_words,
_full_context,
_full_context_embedding,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
@@ -263,8 +261,6 @@ class ExpressionLearner:
style,
context,
context_words,
full_context,
full_context_embedding,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
@@ -274,8 +270,6 @@ class ExpressionLearner:
"style": style,
"context": context,
"context_words": context_words,
"full_context": full_context,
"full_context_embedding": full_context_embedding,
}
)
@@ -299,8 +293,6 @@ class ExpressionLearner:
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.full_context = new_expr["full_context"]
expr_obj.full_context_embedding = new_expr["full_context_embedding"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
@@ -315,8 +307,6 @@ class ExpressionLearner:
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
full_context=new_expr["full_context"],
full_context_embedding=new_expr["full_context_embedding"],
)
# 限制最大数量
exprs = list(
@@ -428,7 +418,7 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
"""从指定聊天流学习表达方式
Args:
@@ -482,19 +472,14 @@ class ExpressionLearner:
)
split_matched_expressions_w_emb = []
full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words, random_msg_match_str, full_context_embedding)
(self.chat_id, situation, style, context, context_words)
)
return split_matched_expressions_w_emb
async def get_full_context_embedding(self, context: str) -> List[float]:
embedding, _ = await self.embedding_model.get_embedding(context)
return embedding
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
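With the full-context text and its embedding dropped, each entry appended in learn_expression above is now a five-field tuple. A small illustrative sketch of that shape (all values are made up):

# (chat_id, situation, style, context, context_words): illustrative values only.
example_expression = (
    "chat_123",                          # chat_id
    "朋友迟到时催促对方",                  # situation
    "再不来我就先吃了",                    # style
    "等了你半小时了,再不来我就先吃了",      # context
    ["等", "半小时", "先吃"],              # context_words
)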

View File

@@ -7,21 +7,17 @@ import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any
from typing import List, Tuple, Set
from collections import Counter
import traceback
from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import model_config
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger
from src.chat.utils.utils import cut_key_words
from src.chat.utils.chat_message_builder import (
build_readable_messages,
) # 导入 build_readable_messages
# 添加cosine_similarity函数
@@ -86,29 +82,10 @@ class MemoryGraph:
if "memory_items" in self.G.nodes[concept]:
# 获取现有的记忆项已经是str格式
existing_memory = self.G.nodes[concept]["memory_items"]
# 如果现有记忆不为空则使用LLM整合新旧记忆
if existing_memory and hippocampus_instance and hippocampus_instance.model_small:
try:
integrated_memory = await self._integrate_memories_with_llm(
existing_memory, str(memory), hippocampus_instance.model_small
)
self.G.nodes[concept]["memory_items"] = integrated_memory
# 整合成功,增加权重
current_weight = self.G.nodes[concept].get("weight", 0.0)
self.G.nodes[concept]["weight"] = current_weight + 1.0
logger.debug(f"节点 {concept} 记忆整合成功,权重增加到 {current_weight + 1.0}")
logger.info(f"节点 {concept} 记忆内容已更新:{integrated_memory}")
except Exception as e:
logger.error(f"LLM整合记忆失败: {e}")
# 降级到简单连接
new_memory_str = f"{existing_memory} | {memory}"
self.G.nodes[concept]["memory_items"] = new_memory_str
logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}")
else:
new_memory_str = str(memory)
self.G.nodes[concept]["memory_items"] = new_memory_str
logger.info(f"节点 {concept} 记忆内容已直接更新:{new_memory_str}")
# 简单连接新旧记忆
new_memory_str = f"{existing_memory} | {memory}"
self.G.nodes[concept]["memory_items"] = new_memory_str
logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}")
else:
self.G.nodes[concept]["memory_items"] = str(memory)
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
@@ -164,53 +141,6 @@ class MemoryGraph:
return first_layer_items, second_layer_items
async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str:
"""
使用LLM整合新旧记忆内容
Args:
existing_memory: 现有的记忆内容字符串格式可能包含多条记忆
new_memory: 新的记忆内容
llm_model: LLM模型实例
Returns:
str: 整合后的记忆内容
"""
try:
# 构建整合提示
integration_prompt = f"""你是一个记忆整合专家。请将以下的旧记忆和新记忆整合成一条更完整、更准确的记忆内容。
旧记忆内容
{existing_memory}
新记忆内容
{new_memory}
整合要求
1. 保留重要信息去除重复内容
2. 如果新旧记忆有冲突合理整合矛盾的地方
3. 将相关信息合并形成更完整的描述
4. 保持语言简洁准确
5. 只返回整合后的记忆内容不要添加任何解释
整合后的记忆"""
# 调用LLM进行整合
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
integration_prompt
)
if content and content.strip():
integrated_content = content.strip()
logger.debug(f"LLM记忆整合成功模型: {model_name}")
return integrated_content
else:
logger.warning("LLM返回的整合结果为空使用默认连接方式")
return f"{existing_memory} | {new_memory}"
except Exception as e:
logger.error(f"LLM记忆整合过程中出错: {e}")
return f"{existing_memory} | {new_memory}"
@property
def dots(self):
@@ -242,7 +172,6 @@ class MemoryGraph:
class Hippocampus:
def __init__(self):
self.memory_graph = MemoryGraph()
self.model_small: LLMRequest = None # type: ignore
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
@@ -252,44 +181,11 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
)
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes())
def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float:
"""
计算考虑节点权重的激活值
Args:
current_activation: 当前激活值
edge_strength: 边的强度
target_node: 目标节点名称
Returns:
float: 计算后的激活值
"""
# 基础激活值计算
base_activation = current_activation - (1 / edge_strength)
if base_activation <= 0:
return 0.0
# 获取目标节点的权重
if target_node in self.memory_graph.G:
node_data = self.memory_graph.G.nodes[target_node]
node_weight = node_data.get("weight", 1.0)
# 权重加成每次整合增加10%激活值最大加成200%
weight_multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)
return base_activation * weight_multiplier
else:
return base_activation
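A worked example of the weighting rule above, restated as a standalone helper (illustrative only; the removed method reads the weight from the graph node instead of taking it as a parameter):

# Standalone restatement of the removed weighting rule.
def weighted_activation(current: float, strength: int, weight: float = 1.0) -> float:
    base = current - (1 / strength)
    if base <= 0:
        return 0.0
    # +10% per unit of extra weight, capped at +200%.
    multiplier = 1.0 + min((weight - 1.0) * 0.1, 2.0)
    return base * multiplier

# current=1.0, strength=2, weight=5.0 -> base 0.5, multiplier 1.4, result 0.7
assert abs(weighted_activation(1.0, 2, 5.0) - 0.7) < 1e-9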
@staticmethod
def calculate_node_hash(concept, memory_items) -> int:
"""计算节点的特征值"""
@@ -309,45 +205,8 @@ class Hippocampus:
# 直接使用元组,保证顺序一致性
return hash((source, target))
@staticmethod
def find_topic_llm(text: str, topic_num: int | list[int]):
# sourcery skip: inline-immediately-returned-variable
topic_num_str = ""
if isinstance(topic_num, list):
topic_num_str = f"{topic_num[0]}-{topic_num[1]}"
else:
topic_num_str = topic_num
prompt = (
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num_str}个关键的概念,必须是某种概念,比如人,事,物,概念,事件,地点 等等,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
f"如果确定找不出主题或者没有明显主题,返回<none>。"
)
return prompt
@staticmethod
def topic_what(text, topic):
# sourcery skip: inline-immediately-returned-variable
# 不再需要 time_info 参数
prompt = (
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成几句自然的话,'
f"要求包含对这个概念的定义,内容,知识,时间和人物,这些信息必须来自这段文字,不能添加信息。\n只输出几句自然的话就好"
)
return prompt
@staticmethod
def calculate_topic_num(text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content) / 2)
logger.debug(
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
f"topic_num: {topic_num}"
)
return topic_num
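A worked example of this heuristic with assumed inputs: 40 message lines at compress_rate 0.1 and an information content of 5.2:

# Illustrative numbers only.
compress_rate = 0.1
information_content = 5.2
topic_by_length = 40 * compress_rate                                         # 4.0
topic_by_information = max(1, min(5, int((information_content - 3) * 2)))    # 4
topic_num = int((topic_by_length + topic_by_information) / 2)                # 4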
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆。
@@ -1039,395 +898,16 @@ class EntorhinalCortex:
)
# 负责整合,遗忘,合并记忆
# 负责记忆管理
class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.memory_modify_model = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="memory.modify"
)
async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
Args:
messages (list): 消息列表每个消息是一个字典包含数据库消息结构
compress_rate (float, optional): 压缩率用于控制生成的主题数量默认为0.1
Returns:
tuple: (compressed_memory, similar_topics_dict)
- compressed_memory: set, 压缩后的记忆集合每个元素是一个元组 (topic, summary)
- similar_topics_dict: dict, 相似主题字典
Process:
1. 使用 build_readable_messages 生成包含时间人物信息的格式化文本
2. 使用LLM提取关键主题
3. 过滤掉包含禁用关键词的主题
4. 为每个主题生成摘要
5. 查找与现有记忆中的相似主题
"""
if not messages:
return set(), {}
# 1. 使用 build_readable_messages 生成格式化文本
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
messages,
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
replace_bot_name=False, # 保留原始用户名
)
# 如果生成的可读文本为空(例如所有消息都无效),则直接返回
if not input_text:
logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
return set(), {}
current_date = f"当前日期: {datetime.datetime.now().isoformat()}"
input_text = f"{current_date}\n{input_text}"
logger.debug(f"记忆来源:\n{input_text}")
# 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
topics_response, _ = await self.memory_modify_model.generate_response_async(
self.hippocampus.find_topic_llm(input_text, topic_num)
)
# 提取<>中的内容
topics = re.findall(r"<([^>]+)>", topics_response)
if not topics:
topics = ["none"]
else:
topics = [
topic.strip()
for topic in ",".join(topics).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip()
]
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
# 4. 创建所有话题的摘要生成任务
tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = []
topic_what_prompt: str = ""
for topic in filtered_topics:
# 调用修改后的 topic_what不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
try:
task = self.memory_modify_model.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
except Exception as e:
logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}")
continue
# 等待所有任务完成
compressed_memory: Set[Tuple[str, str]] = set()
similar_topics_dict = {}
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
existing_topics = list(self.memory_graph.G.nodes())
similar_topics = []
for existing_topic in existing_topics:
topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic))
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= 0.7:
similar_topics.append((existing_topic, similarity))
similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:3]
similar_topics_dict[topic] = similar_topics
if global_config.debug.show_prompt:
logger.info(f"prompt: {topic_what_prompt}")
logger.info(f"压缩后的记忆: {compressed_memory}")
logger.info(f"相似主题: {similar_topics_dict}")
return compressed_memory, similar_topics_dict
def get_similar_topics_from_keywords(
self,
keywords: list[str] | str,
top_k: int = 3,
threshold: float = 0.7,
) -> dict[str, list[tuple[str, float]]]:
"""基于输入的关键词,返回每个关键词对应的相似主题列表。
Args:
keywords: 关键词列表或以逗号/空格/顿号分隔的字符串
top_k: 每个关键词返回的相似主题数量上限
threshold: 相似度阈值低于该值的主题将被过滤
Returns:
dict[str, list[tuple[str, float]]]: {keyword: [(topic, similarity), ...]}
"""
# 规范化输入为列表[str]
if isinstance(keywords, str):
# 支持中英文逗号、顿号、空格分隔
parts = keywords.replace("，", ",").replace("、", ",").replace(" ", ",").strip(", ")
keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
else:
keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
if not keyword_list:
return {}
existing_topics = list(self.memory_graph.G.nodes())
result: dict[str, list[tuple[str, float]]] = {}
for kw in keyword_list:
kw_words = set(jieba.cut(kw))
similar_topics: list[tuple[str, float]] = []
for topic in existing_topics:
topic_words = set(jieba.cut(topic))
all_words = kw_words | topic_words
if not all_words:
continue
v1 = [1 if w in kw_words else 0 for w in all_words]
v2 = [1 if w in topic_words else 0 for w in all_words]
sim = cosine_similarity(v1, v2)
if sim >= threshold:
similar_topics.append((topic, sim))
similar_topics.sort(key=lambda x: x[1], reverse=True)
result[kw] = similar_topics[:top_k]
return result
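Both similarity loops above follow the same recipe: segment each string with jieba, build binary word-presence vectors over the union vocabulary, and compare them with cosine similarity. A self-contained sketch (the module's own cosine_similarity helper is not shown in this diff, so a plain dot-product version stands in for it):

import math
import jieba

def keyword_topic_similarity(keyword: str, topic: str) -> float:
    # Binary bag-of-words over the union of both token sets.
    kw_words = set(jieba.cut(keyword))
    topic_words = set(jieba.cut(topic))
    vocab = kw_words | topic_words
    if not vocab:
        return 0.0
    v1 = [1 if w in kw_words else 0 for w in vocab]
    v2 = [1 if w in topic_words else 0 for w in vocab]
    dot = sum(a * b for a, b in zip(v1, v2))
    norm = math.sqrt(sum(v1)) * math.sqrt(sum(v2))
    return dot / norm if norm else 0.0

# e.g. keyword_topic_similarity("周末旅行计划", "旅行"); kept only if >= threshold (0.7 above)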
async def add_memory_with_similar(
self,
memory_item: str,
similar_topics_dict: dict[str, list[tuple[str, float]]],
) -> bool:
"""将单条记忆内容与相似主题写入记忆网络并同步数据库。
build_memory_for_chat 的方式 similar_topics_dict 的每个键作为主题添加节点内容
并与其相似主题建立连接连接强度为 int(similarity * 10)
Args:
memory_item: 记忆内容字符串将作为每个主题节点的 memory_items
similar_topics_dict: {topic: [(similar_topic, similarity), ...]}
Returns:
bool: 是否成功执行添加与同步
"""
try:
if not memory_item or not isinstance(memory_item, str):
return False
if not similar_topics_dict or not isinstance(similar_topics_dict, dict):
return False
current_time = time.time()
# 为每个主题写入节点
for topic, similar_list in similar_topics_dict.items():
if not topic or not isinstance(topic, str):
continue
await self.hippocampus.memory_graph.add_dot(topic, memory_item, self.hippocampus)
# 连接相似主题
if isinstance(similar_list, list):
for item in similar_list:
try:
similar_topic, similarity = item
except Exception:
continue
if not isinstance(similar_topic, str):
continue
if topic == similar_topic:
continue
# 强度按 build_memory_for_chat 的规则
strength = int(max(0.0, float(similarity)) * 10) if similarity is not None else 0
if strength <= 0:
continue
# 确保相似主题节点存在如果没有也可以只建立边networkx会创建节点但需初始化属性
if similar_topic not in self.memory_graph.G:
# 创建一个空的相似主题节点避免悬空边memory_items 为空字符串
self.memory_graph.G.add_node(
similar_topic,
memory_items="",
weight=1.0,
created_time=current_time,
last_modified=current_time,
)
self.memory_graph.G.add_edge(
topic,
similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time,
)
# 同步数据库
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
return True
except Exception as e:
logger.error(f"添加记忆节点失败: {e}")
return False
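A hypothetical call showing the expected shape of similar_topics_dict and how edge strength is derived (topic names, scores, and the hippocampus instance name are all assumptions):

# Inside an async function; illustrative values only.
similar = {"旅行": [("旅游", 0.82), ("出行", 0.71)]}
ok = await hippocampus.parahippocampal_gyrus.add_memory_with_similar(
    memory_item="小明计划十月去云南旅行",
    similar_topics_dict=similar,
)
# Edges created: 旅行-旅游 with strength int(0.82 * 10) == 8, 旅行-出行 with strength int(0.71 * 10) == 7.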
async def operation_forget_topic(self, percentage=0.005):
start_time = time.time()
logger.info("[遗忘] 开始检查数据库...")
# 验证百分比参数
if not 0 <= percentage <= 1:
logger.warning(f"[遗忘] 无效的遗忘百分比: {percentage}, 使用默认值 0.005")
percentage = 0.005
all_nodes = list(self.memory_graph.G.nodes())
all_edges = list(self.memory_graph.G.edges())
if not all_nodes and not all_edges:
logger.info("[遗忘] 记忆图为空,无需进行遗忘操作")
return
# 确保至少检查1个节点和边且不超过总数
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
# 只有在有足够的节点和边时才进行采样
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
try:
nodes_to_check = random.sample(all_nodes, check_nodes_count)
edges_to_check = random.sample(all_edges, check_edges_count)
except ValueError as e:
logger.error(f"[遗忘] 采样错误: {str(e)}")
return
else:
logger.info("[遗忘] 没有足够的节点或边进行遗忘操作")
return
# 使用列表存储变化信息
edge_changes = {
"weakened": [], # 存储减弱的边
"removed": [], # 存储移除的边
}
node_changes = {
"reduced": [], # 存储减少记忆的节点
"removed": [], # 存储移除的节点
}
current_time = datetime.datetime.now().timestamp()
logger.info("[遗忘] 开始检查连接...")
edge_check_start = time.time()
for source, target in edges_to_check:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified")
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1
if new_strength <= 0:
self.memory_graph.G.remove_edge(source, target)
edge_changes["removed"].append(f"{source} -> {target}")
else:
edge_data["strength"] = new_strength
edge_data["last_modified"] = current_time
edge_changes["weakened"].append(f"{source}-{target} (强度: {current_strength} -> {new_strength})")
edge_check_end = time.time()
logger.info(f"[遗忘] 连接检查耗时: {edge_check_end - edge_check_start:.2f}")
logger.info("[遗忘] 开始检查节点...")
node_check_start = time.time()
for node in nodes_to_check:
# 检查节点是否存在,以防在迭代中被移除(例如边移除导致)
if node not in self.memory_graph.G:
continue
node_data = self.memory_graph.G.nodes[node]
# 首先获取记忆项
memory_items = node_data.get("memory_items", "")
# 直接检查记忆内容是否为空
if not memory_items or memory_items.strip() == "":
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(空节点)") # 标记为空节点移除
logger.debug(f"[遗忘] 移除了空的节点: {node}")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
continue # 处理下一个节点
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
continue
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
if any(edge_changes.values()) or any(node_changes.values()):
sync_start = time.time()
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
sync_end = time.time()
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}")
# 汇总输出所有变化
logger.info("[遗忘] 遗忘操作统计:")
if edge_changes["weakened"]:
logger.info(
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
)
if edge_changes["removed"]:
logger.info(
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
)
if node_changes["reduced"]:
logger.info(
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
)
if node_changes["removed"]:
logger.info(
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
)
else:
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
end_time = time.time()
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
class HippocampusManager:
@@ -1462,11 +942,6 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus
async def forget_memory(self, percentage: float = 0.005):
"""遗忘记忆的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3

View File

@@ -32,7 +32,7 @@ class MemoryChest:
request_type="memory_chest_build",
)
self.memory_build_threshold = 30
self.memory_build_threshold = 20
self.memory_size_limit = global_config.memory.max_memory_size
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
@@ -154,10 +154,13 @@ class MemoryChest:
def get_all_titles(self) -> list[str]:
def get_all_titles(self, exclude_locked: bool = False) -> list[str]:
"""
获取记忆仓库中的所有标题
Args:
exclude_locked: 是否排除锁定的记忆默认为 False
Returns:
list: 包含所有标题的列表
"""
@@ -166,6 +169,9 @@ class MemoryChest:
titles = []
for memory in MemoryChestModel.select():
if memory.title:
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
if exclude_locked and memory.locked:
continue
titles.append(memory.title)
return titles
except Exception as e:
@@ -261,8 +267,8 @@ class MemoryChest:
Returns:
str: 选择的标题
"""
# 获取所有标题并构建格式化字符串
titles = self.get_all_titles()
# 获取所有标题并构建格式化字符串(排除锁定的记忆)
titles = self.get_all_titles(exclude_locked=True)
formatted_titles = ""
for title in titles:
formatted_titles += f"{title}\n"
@@ -375,15 +381,15 @@ class MemoryChest:
async def choose_merge_target(self, memory_title: str) -> list[str]:
"""
选择与给定记忆标题相关的记忆目标
Args:
memory_title: 要匹配的记忆标题
Returns:
list[str]: 选中的记忆内容列表
"""
try:
all_titles = self.get_all_titles()
all_titles = self.get_all_titles(exclude_locked=True)
content = ""
for title in all_titles:
content += f"{title}\n"
@@ -430,10 +436,10 @@ class MemoryChest:
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
"""
根据标题列表查找对应的记忆内容
Args:
titles: 记忆标题列表
Returns:
list[str]: 记忆内容列表
"""
@@ -442,22 +448,32 @@ class MemoryChest:
for title in titles:
if not title or not title.strip():
continue
# 使用模糊查找匹配记忆
try:
best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
if best_match:
contents.append(best_match[1]) # best_match[1] 是 content
logger.debug(f"找到记忆: {best_match[0]} (相似度: {best_match[2]:.3f})")
# 检查记忆是否被锁定
memory_title = best_match[0]
memory_content = best_match[1]
# 查询数据库中的锁定状态
for memory in MemoryChestModel.select():
if memory.title == memory_title and memory.locked:
logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
continue
contents.append(memory_content)
logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
else:
logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
except Exception as e:
logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
continue
logger.info(f"成功找到 {len(contents)} 条记忆内容")
return contents
except Exception as e:
logger.error(f"根据标题查找记忆时出错: {e}")
return []
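A keyed Peewee lookup is a more direct way to honour the locked flag than scanning every row; a minimal sketch assuming the MemoryChest fields (title, content, locked) shown later in this commit:

# Sketch: fetch the matched row once and check its locked flag before merging.
row = MemoryChestModel.get_or_none(MemoryChestModel.title == memory_title)
if row is not None and not row.locked:
    contents.append(row.content)
else:
    logger.warning(f"记忆 '{memory_title}' 不存在或已锁定,跳过合并")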

View File

@@ -306,8 +306,6 @@ class Expression(BaseModel):
# new mode fields
context = TextField(null=True)
context_words = TextField(null=True)
full_context = TextField(null=True)
full_context_embedding = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)
@@ -324,6 +322,7 @@ class MemoryChest(BaseModel):
title = TextField() # 标题
content = TextField() # 内容
locked = BooleanField(default=False) # 是否锁定
class Meta:
table_name = "memory_chest"

View File

@@ -6,211 +6,6 @@ from src.common.logger import get_logger
logger = get_logger("migrate")
async def migrate_memory_items_to_string():
"""
将数据库中记忆节点的memory_items从list格式迁移到string格式
并根据原始list的项目数量设置weight值
"""
logger.info("开始迁移记忆节点格式...")
migration_stats = {
"total_nodes": 0,
"converted_nodes": 0,
"already_string_nodes": 0,
"empty_nodes": 0,
"error_nodes": 0,
"weight_updated_nodes": 0,
"truncated_nodes": 0,
}
try:
# 获取所有图节点
all_nodes = GraphNodes.select()
migration_stats["total_nodes"] = all_nodes.count()
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
for node in all_nodes:
try:
concept = node.concept
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 如果为空,跳过
if not memory_items_raw:
migration_stats["empty_nodes"] += 1
logger.debug(f"跳过空节点: {concept}")
continue
try:
# 尝试解析JSON
parsed_data = json.loads(memory_items_raw)
if isinstance(parsed_data, list):
# 如果是list格式需要转换
if parsed_data:
# 转换为字符串格式
new_memory_items = " | ".join(str(item) for item in parsed_data)
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
new_weight = float(len(parsed_data)) # weight = list项目数量
# 更新数据库
node.memory_items = new_memory_items
node.weight = new_weight
node.save()
migration_stats["converted_nodes"] += 1
migration_stats["weight_updated_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.info(
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
)
else:
# 空list设置为空字符串
node.memory_items = ""
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
logger.debug(f"转换空list节点: {concept}")
elif isinstance(parsed_data, str):
# 已经是字符串格式检查长度和weight
current_content = parsed_data
original_length = len(current_content)
content_truncated = False
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
content_truncated = True
migration_stats["truncated_nodes"] += 1
node.memory_items = current_content
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
# 检查weight是否需要更新
update_needed = False
if original_weight == 1.0:
# 如果weight还是默认值可以根据内容复杂度估算
content_parts = (
current_content.split(" | ") if " | " in current_content else [current_content]
)
estimated_weight = max(1.0, float(len(content_parts)))
if estimated_weight != original_weight:
node.weight = estimated_weight
update_needed = True
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
# 如果内容被截断或权重需要更新,保存到数据库
if content_truncated or update_needed:
node.save()
if update_needed:
migration_stats["weight_updated_nodes"] += 1
if content_truncated:
migration_stats["converted_nodes"] += 1 # 算作转换节点
else:
migration_stats["already_string_nodes"] += 1
else:
migration_stats["already_string_nodes"] += 1
else:
# 其他JSON类型转换为字符串
new_memory_items = str(parsed_data) if parsed_data else ""
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = new_memory_items
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"转换其他类型节点: {concept}{length_info}")
except json.JSONDecodeError:
# 不是JSON格式假设已经是纯字符串
# 检查是否是带引号的字符串
if memory_items_raw.startswith('"') and memory_items_raw.endswith('"'):
# 去掉引号
clean_content = memory_items_raw[1:-1]
original_length = len(clean_content)
# 检查长度并截断
if len(clean_content) > 100:
clean_content = clean_content[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = clean_content
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"去除引号节点: {concept}{length_info}")
else:
# 已经是纯字符串格式,检查长度
current_content = memory_items_raw
original_length = len(current_content)
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
node.memory_items = current_content
node.save()
migration_stats["converted_nodes"] += 1 # 算作转换节点
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
else:
migration_stats["already_string_nodes"] += 1
logger.debug(f"已是字符串格式节点: {concept}")
except Exception as e:
migration_stats["error_nodes"] += 1
logger.error(f"处理节点 {concept} 时发生错误: {e}")
continue
except Exception as e:
logger.error(f"迁移过程中发生严重错误: {e}")
raise
# 输出迁移统计
logger.info("=== 记忆节点迁移完成 ===")
logger.info(f"总节点数: {migration_stats['total_nodes']}")
logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
logger.info(f"空节点: {migration_stats['empty_nodes']}")
logger.info(f"错误节点: {migration_stats['error_nodes']}")
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
success_rate = (
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
/ migration_stats["total_nodes"]
* 100
if migration_stats["total_nodes"] > 0
else 0
)
logger.info(f"迁移成功率: {success_rate:.1f}%")
return migration_stats
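In short, the removed migration joined each JSON list of memory items into one " | "-separated string capped at 100 characters and set the node weight to the original item count; a condensed, illustrative restatement of that rule:

import json

def convert_memory_items(raw: str) -> tuple[str, float]:
    """Condensed restatement of the removed list-to-string migration rule."""
    items = json.loads(raw)                   # e.g. '["喜欢猫", "养了两只猫"]'
    text = " | ".join(str(i) for i in items)  # join items into one string
    return text[:100], float(len(items))      # truncate to 100 chars, weight = item count

# convert_memory_items('["喜欢猫", "养了两只猫"]') -> ("喜欢猫 | 养了两只猫", 2.0)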
async def set_all_person_known():
"""
将person_info库中所有记录的is_known字段设置为True
@@ -312,7 +107,6 @@ async def check_and_run_migrations():
# 执行迁移函数
# 依次执行两个异步函数
await asyncio.sleep(3)
await migrate_memory_items_to_string()
await set_all_person_known()
# 创建done.mem文件
with open(done_file, "w", encoding="utf-8") as f: