MaiBot/src/plugins/memory_system/draw_memory.py

299 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import os
import sys
import time
import jieba
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
# from src.common.logger import get_module_logger
# logger = get_module_logger("draw_memory")
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
print(root_path)
from src.common.database import db # noqa: E402
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".env.dev")
load_dotenv(env_path)
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {"concept": node}
db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ""
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)]) # 调试输出
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}"
)
if closest_record:
closest_time = closest_record["time"]
group_id = closest_record["group_id"] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
)
for record in chat_record:
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(record["time"])))
try:
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except (KeyError, TypeError):
# 处理缺少键或类型错误的情况
displayname = record.get("user_nickname", "") or "用户" + str(record.get("user_id", "未知"))
chat_text += f"[{time_str}] {displayname}: {record['processed_plain_text']}\n" # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self):
# 清空现有的图数据
db.graph_data.delete_many({})
# 保存节点
for node in self.G.nodes(data=True):
node_data = {
"concept": node[0],
"memory_items": node[1].get("memory_items", []), # 默认为空列表
}
db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
edge_data = {"source": edge[0], "target": edge[1]}
db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node["concept"], memory_items=memory_items)
# 加载边
edges = db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge["source"], edge["target"])
def main():
memory_graph = Memory_graph()
memory_graph.load_graph_from_db()
# 只显示一次优化后的图形
visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == "退出":
break
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
logger.debug("第一层记忆:")
for item in first_layer_items:
logger.debug(item)
logger.debug("第二层记忆:")
for item in second_layer_items:
logger.debug(item)
else:
logger.debug("未找到相关记忆。")
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。"
f"只需要列举{topic_num}个话题就好,不要告诉我其他内容。"
)
return prompt
def topic_what(text, topic):
prompt = (
f"这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。"
f"只输出这句话就好"
)
return prompt
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
logger.debug("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数和最大度数用于归一化
max_memories = 1
max_degree = 1
for node in nodes:
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
# 红色分量随着度数增加而增加
r = (degree / max_degree) ** 0.3
red = min(1.0, r)
# 蓝色分量随着度数减少而增加
blue = max(0.0, 1 - red)
# blue = 1
color = (red, 0.1, blue)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(
H,
pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family="SimHei",
font_weight="bold",
edge_color="gray",
width=0.5,
alpha=0.9,
)
title = "记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数"
plt.title(title, fontsize=16, fontfamily="SimHei")
plt.show()
if __name__ == "__main__":
main()