diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 669fb8a1..2a5f497f 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -3,10 +3,11 @@ name: Docker Build and Push
on:
push:
branches:
- - main # 推送到main分支时触发
+ - main
+ - debug # 新增 debug 分支触发
tags:
- - 'v*' # 推送v开头的tag时触发(例如v1.0.0)
- workflow_dispatch: # 允许手动触发
+ - 'v*'
+ workflow_dispatch:
jobs:
build-and-push:
@@ -24,15 +25,24 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
+ - name: Determine Image Tags
+ id: tags
+ run: |
+ if [[ "${{ github.ref }}" == refs/tags/* ]]; then
+ echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
+ elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
+ echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
+ elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
+ echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
+ fi
+
- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
- context: . # Docker构建上下文路径
- file: ./Dockerfile # Dockerfile路径
- platforms: linux/amd64,linux/arm64 # 支持arm架构
- tags: |
- ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }}
- ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
- push: true
- cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
- cache-to: type=inline
+ context: .
+ file: ./Dockerfile
+ platforms: linux/amd64,linux/arm64
+ tags: ${{ steps.tags.outputs.tags }}
+ push: true
+ cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache
+ cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index f704a19b..51a11d8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -185,3 +185,6 @@ cython_debug/
# PyPI configuration file
.pypirc
.env
+
+# jieba
+jieba.cache
diff --git a/README.md b/README.md
index c09b33c4..7bfa465a 100644
--- a/README.md
+++ b/README.md
@@ -42,22 +42,22 @@
## 🎯 功能介绍
### 💬 聊天功能
-- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言,目前有bug,所以现在只会检测主题,不会进行存储
+- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言
- 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置
-- 使用硅基流动的api进行回复生成,可随机使用R1,V3,R1-distill等模型,未来将加入官网api支持
+- 支持多模型,多厂商自定义配置
- 动态的prompt构建器,更拟人
- 支持图片,转发消息,回复消息的识别
- 错别字和多条回复功能:麦麦可以随机生成错别字,会多条发送回复以及对消息进行reply
### 😊 表情包功能
-- 支持根据发言内容发送对应情绪的表情包:未完善,可以用
-- 会自动偷群友的表情包(未完善,暂时禁用)目前有bug
+- 支持根据发言内容发送对应情绪的表情包
+- 会自动偷群友的表情包
### 📅 日程功能
- 麦麦会自动生成一天的日程,实现更拟人的回复
### 🧠 记忆功能
-- 对聊天记录进行概括存储,在需要时调用,没写完
+- 对聊天记录进行概括存储,在需要时调用,待完善
### 📚 知识库功能
- 基于embedding模型的知识库,手动放入txt会自动识别,写完了,暂时禁用
diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml
index 5ad837f6..28ffb0ce 100644
--- a/config/bot_config_template.toml
+++ b/config/bot_config_template.toml
@@ -11,7 +11,7 @@ prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的
[message]
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
-max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃
+max_context_size = 15 # 麦麦获得的上文数量
emoji_chance = 0.2 # 麦麦使用表情包的概率
ban_words = [
# "403","张三"
@@ -31,6 +31,7 @@ model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概
[memory]
build_memory_interval = 300 # 记忆构建间隔 单位秒
+forget_memory_interval = 300 # 记忆遗忘间隔 单位秒
[others]
enable_advance_output = true # 是否启用高级输出
diff --git a/src/plugins/chat/Segment_builder.py b/src/plugins/chat/Segment_builder.py
new file mode 100644
index 00000000..09673a04
--- /dev/null
+++ b/src/plugins/chat/Segment_builder.py
@@ -0,0 +1,165 @@
+from typing import Dict, List, Union, Optional, Any
+import base64
+import os
+
+"""
+OneBot v11 Message Segment Builder
+
+This module provides classes for building message segments that conform to the
+OneBot v11 standard. These segments can be used to construct complex messages
+for sending through bots that implement the OneBot interface.
+"""
+
+
+
+class Segment:
+ """Base class for all message segments."""
+
+ def __init__(self, type_: str, data: Dict[str, Any]):
+ self.type = type_
+ self.data = data
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert the segment to a dictionary format."""
+ return {
+ "type": self.type,
+ "data": self.data
+ }
+
+
+class Text(Segment):
+ """Text message segment."""
+
+ def __init__(self, text: str):
+ super().__init__("text", {"text": text})
+
+
+class Face(Segment):
+ """Face/emoji message segment."""
+
+ def __init__(self, face_id: int):
+ super().__init__("face", {"id": str(face_id)})
+
+
+class Image(Segment):
+ """Image message segment."""
+
+ @classmethod
+ def from_url(cls, url: str) -> 'Image':
+ """Create an Image segment from a URL."""
+ return cls(url=url)
+
+ @classmethod
+ def from_path(cls, path: str) -> 'Image':
+ """Create an Image segment from a file path."""
+ with open(path, 'rb') as f:
+ file_b64 = base64.b64encode(f.read()).decode('utf-8')
+ return cls(file=f"base64://{file_b64}")
+
+ def __init__(self, file: str = None, url: str = None, cache: bool = True):
+ data = {}
+ if file:
+ data["file"] = file
+ if url:
+ data["url"] = url
+ if not cache:
+ data["cache"] = "0"
+ super().__init__("image", data)
+
+
+class At(Segment):
+ """@Someone message segment."""
+
+ def __init__(self, user_id: Union[int, str]):
+ data = {"qq": str(user_id)}
+ super().__init__("at", data)
+
+
+class Record(Segment):
+ """Voice message segment."""
+
+ def __init__(self, file: str, magic: bool = False, cache: bool = True):
+ data = {"file": file}
+ if magic:
+ data["magic"] = "1"
+ if not cache:
+ data["cache"] = "0"
+ super().__init__("record", data)
+
+
+class Video(Segment):
+ """Video message segment."""
+
+ def __init__(self, file: str):
+ super().__init__("video", {"file": file})
+
+
+class Reply(Segment):
+ """Reply message segment."""
+
+ def __init__(self, message_id: int):
+ super().__init__("reply", {"id": str(message_id)})
+
+
+class MessageBuilder:
+ """Helper class for building complex messages."""
+
+ def __init__(self):
+ self.segments: List[Segment] = []
+
+ def text(self, text: str) -> 'MessageBuilder':
+ """Add a text segment."""
+ self.segments.append(Text(text))
+ return self
+
+ def face(self, face_id: int) -> 'MessageBuilder':
+ """Add a face/emoji segment."""
+ self.segments.append(Face(face_id))
+ return self
+
+ def image(self, file: str = None) -> 'MessageBuilder':
+ """Add an image segment."""
+ self.segments.append(Image(file=file))
+ return self
+
+ def at(self, user_id: Union[int, str]) -> 'MessageBuilder':
+ """Add an @someone segment."""
+ self.segments.append(At(user_id))
+ return self
+
+ def record(self, file: str, magic: bool = False) -> 'MessageBuilder':
+ """Add a voice record segment."""
+ self.segments.append(Record(file, magic))
+ return self
+
+ def video(self, file: str) -> 'MessageBuilder':
+ """Add a video segment."""
+ self.segments.append(Video(file))
+ return self
+
+ def reply(self, message_id: int) -> 'MessageBuilder':
+ """Add a reply segment."""
+ self.segments.append(Reply(message_id))
+ return self
+
+ def build(self) -> List[Dict[str, Any]]:
+ """Build the message into a list of segment dictionaries."""
+ return [segment.to_dict() for segment in self.segments]
+
+
+'''Convenience functions
+def text(content: str) -> Dict[str, Any]:
+ """Create a text message segment."""
+ return Text(content).to_dict()
+
+def image_url(url: str) -> Dict[str, Any]:
+ """Create an image message segment from URL."""
+ return Image.from_url(url).to_dict()
+
+def image_path(path: str) -> Dict[str, Any]:
+ """Create an image message segment from file path."""
+ return Image.from_path(path).to_dict()
+
+def at(user_id: Union[int, str]) -> Dict[str, Any]:
+ """Create an @someone message segment."""
+ return At(user_id).to_dict()'''
\ No newline at end of file
diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py
index ac04866a..ab99f647 100644
--- a/src/plugins/chat/__init__.py
+++ b/src/plugins/chat/__init__.py
@@ -13,6 +13,7 @@ from .willing_manager import willing_manager
from nonebot.rule import to_me
from .bot import chat_bot
from .emoji_manager import emoji_manager
+import time
# 获取驱动器
@@ -86,19 +87,27 @@ async def _(bot: Bot):
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
-'''
-@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
-async def monitor_relationships():
- """每15秒打印一次关系数据"""
- relationship_manager.print_all_relationships()
-'''
-
# 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""每30秒执行一次记忆构建"""
- print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
- await hippocampus.build_memory(chat_size=30)
- print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
+ print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
+ start_time = time.time()
+ await hippocampus.operation_build_memory(chat_size=20)
+ end_time = time.time()
+ print(f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
+
+@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
+async def forget_memory_task():
+ """每30秒执行一次记忆构建"""
+ # print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
+ # await hippocampus.operation_forget_topic(percentage=0.1)
+ # print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
+@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
+async def merge_memory_task():
+ """每30秒执行一次记忆构建"""
+ # print("\033[1;32m[记忆整合]\033[0m 开始整合")
+ # await hippocampus.operation_merge_memory(percentage=0.1)
+ # print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py
index f9488b96..c2651fa8 100644
--- a/src/plugins/chat/bot.py
+++ b/src/plugins/chat/bot.py
@@ -69,11 +69,9 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
- identifier=topic_identifier.identify_topic()
- if global_config.topic_extract=='llm':
- topic=await identifier(message.processed_plain_text)
- else:
- topic=identifier(message.detailed_plain_text)
+
+ topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
+
# topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
# topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py
index 96c83dfe..be599f48 100644
--- a/src/plugins/chat/config.py
+++ b/src/plugins/chat/config.py
@@ -26,7 +26,8 @@ class BotConfig:
talk_frequency_down_groups = set()
ban_user_id = set()
- build_memory_interval: int = 60 # 记忆构建间隔(秒)
+ build_memory_interval: int = 30 # 记忆构建间隔(秒)
+ forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
@@ -155,6 +156,7 @@ class BotConfig:
if "memory" in toml_dict:
memory_config = toml_dict["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
+ config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
# 群组配置
if "groups" in toml_dict:
@@ -188,6 +190,6 @@ global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
- # logger.remove()
+ logger.remove()
pass
diff --git a/src/plugins/chat/del.message_send_control.py b/src/plugins/chat/del.message_send_control.py
deleted file mode 100644
index 30ade9cd..00000000
--- a/src/plugins/chat/del.message_send_control.py
+++ /dev/null
@@ -1,251 +0,0 @@
-from typing import Union, List, Optional, Deque, Dict
-from nonebot.adapters.onebot.v11 import Bot, MessageSegment
-import asyncio
-import random
-import os
-from .message import Message, Message_Thinking, MessageSet
-from .cq_code import CQCode
-from collections import deque
-import time
-from .storage import MessageStorage
-from .config import global_config
-from .cq_code import cq_code_tool
-
-if os.name == "nt":
- from .message_visualizer import message_visualizer
-
-
-
-class SendTemp:
- """单个群组的临时消息队列管理器"""
- def __init__(self, group_id: int, max_size: int = 100):
- self.group_id = group_id
- self.max_size = max_size
- self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
- self.last_send_time = 0
-
- def add(self, message: Message) -> None:
- """按时间顺序添加消息到队列"""
- if not self.messages:
- self.messages.append(message)
- return
-
- # 按时间顺序插入
- if message.time >= self.messages[-1].time:
- self.messages.append(message)
- return
-
- # 使用二分查找找到合适的插入位置
- messages_list = list(self.messages)
- left, right = 0, len(messages_list)
-
- while left < right:
- mid = (left + right) // 2
- if messages_list[mid].time < message.time:
- left = mid + 1
- else:
- right = mid
-
- # 重建消息队列,保持时间顺序
- new_messages = deque(maxlen=self.max_size)
- new_messages.extend(messages_list[:left])
- new_messages.append(message)
- new_messages.extend(messages_list[left:])
- self.messages = new_messages
- def get_earliest_message(self) -> Optional[Message]:
- """获取时间最早的消息"""
- message = self.messages.popleft() if self.messages else None
- return message
-
- def clear(self) -> None:
- """清空队列"""
- self.messages.clear()
-
- def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
- """获取所有待发送的消息"""
- if group_id is None:
- return list(self.messages)
- return [msg for msg in self.messages if msg.group_id == group_id]
-
- def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
- """查看下一条要发送的消息(不移除)"""
- return self.messages[0] if self.messages else None
-
- def has_messages(self) -> bool:
- """检查是否有待发送的消息"""
- return bool(self.messages)
-
- def count(self, group_id: Optional[int] = None) -> int:
- """获取待发送消息数量"""
- if group_id is None:
- return len(self.messages)
- return len([msg for msg in self.messages if msg.group_id == group_id])
-
- def get_last_send_time(self) -> float:
- """获取最后一次发送时间"""
- return self.last_send_time
-
- def update_send_time(self):
- """更新最后发送时间"""
- self.last_send_time = time.time()
-
-class SendTempContainer:
- """管理所有群组的消息缓存容器"""
- def __init__(self):
- self.temp_queues: Dict[int, SendTemp] = {}
-
- def get_queue(self, group_id: int) -> SendTemp:
- """获取或创建群组的消息队列"""
- if group_id not in self.temp_queues:
- self.temp_queues[group_id] = SendTemp(group_id)
- return self.temp_queues[group_id]
-
- def add_message(self, message: Message) -> None:
- """添加消息到对应群组的队列"""
- queue = self.get_queue(message.group_id)
- queue.add(message)
-
- def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
- """获取指定群组的所有待发送消息"""
- queue = self.get_queue(group_id)
- return queue.get_all()
-
- def has_messages(self, group_id: int) -> bool:
- """检查指定群组是否有待发送消息"""
- queue = self.get_queue(group_id)
- return queue.has_messages()
-
- def get_all_groups(self) -> List[int]:
- """获取所有有待发送消息的群组ID"""
- return list(self.temp_queues.keys())
-
- def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
- queue = self.get_queue(message_obj.group_id)
- # 使用列表解析找到匹配的消息索引
- matching_indices = [
- i for i, msg in enumerate(queue.messages)
- if msg.message_id == message_obj.message_id
- ]
-
- if not matching_indices:
- return False
-
- index = matching_indices[0] # 获取第一个匹配的索引
-
- # 将消息转换为列表以便修改
- messages = list(queue.messages)
-
- # 根据消息类型处理
- if isinstance(message_obj, MessageSet):
- messages.pop(index)
- # 在原位置插入新消息组
- for i, single_message in enumerate(message_obj.messages):
- messages.insert(index + i, single_message)
- # print(f"\033[1;34m[调试]\033[0m 添加消息组中的第{i+1}条消息: {single_message}")
- else:
- # 直接替换原消息
- messages[index] = message_obj
- # print(f"\033[1;34m[调试]\033[0m 已更新消息: {message_obj}")
-
- # 重建队列
- queue.messages.clear()
- for msg in messages:
- queue.messages.append(msg)
-
- return True
-
-
-class MessageSendControl:
- """消息发送控制器"""
- def __init__(self):
- self.typing_speed = (0.1, 0.3) # 每个字符的打字时间范围(秒)
- self.message_interval = (0.5, 1) # 多条消息间的间隔时间范围(秒)
- self.max_retry = 3 # 最大重试次数
- self.send_temp_container = SendTempContainer()
- self._running = True
- self._paused = False
- self._current_bot = None
- self.storage = MessageStorage() # 添加存储实例
- try:
- message_visualizer.start()
- except(NameError):
- pass
-
- async def process_group_messages(self, group_id: int):
- queue = self.send_temp_container.get_queue(group_id)
- if queue.has_messages():
- message = queue.peek_next()
- # 处理消息的逻辑
- if isinstance(message, Message_Thinking):
- message.update_thinking_time()
- thinking_time = message.thinking_time
- if message.interupt:
- print(f"\033[1;34m[调试]\033[0m 思考不打算回复,移除")
- queue.get_earliest_message()
- return
- elif thinking_time < 90: # 最少思考2秒
- if int(thinking_time) % 15 == 0:
- print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{thinking_time:.1f}秒")
- return
- else:
- print(f"\033[1;34m[调试]\033[0m 思考消息超时,移除")
- queue.get_earliest_message() # 移除超时的思考消息
- return
- elif isinstance(message, Message):
- message = queue.get_earliest_message()
- if message and message.processed_plain_text:
- print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
- cost_time = round(time.time(), 2) - message.time
- if cost_time > 40:
- message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
- cur_time = time.time()
- await self._current_bot.send_group_msg(
- group_id=group_id,
- message=str(message.processed_plain_text),
- auto_escape=False
- )
- cost_time = round(time.time(), 2) - cur_time
- print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒")
- current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
- print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
-
- if message.is_emoji:
- message.processed_plain_text = "[表情包]"
- await self.storage.store_message(message, None)
- else:
- await self.storage.store_message(message, None)
-
-
-
- queue.update_send_time()
- if queue.has_messages():
- await asyncio.sleep(
- random.uniform(
- self.message_interval[0],
- self.message_interval[1]
- )
- )
-
- async def start_processor(self, bot: Bot):
- """启动消息处理器"""
- self._current_bot = bot
-
- while self._running:
- await asyncio.sleep(1.5)
- tasks = []
- for group_id in self.send_temp_container.get_all_groups():
- tasks.append(self.process_group_messages(group_id))
-
- # 并行处理所有群组的消息
- await asyncio.gather(*tasks)
- try:
- message_visualizer.update_content(self.send_temp_container)
- except(NameError):
- pass
-
- def set_typing_speed(self, min_speed: float, max_speed: float):
- """设置打字速度范围"""
- self.typing_speed = (min_speed, max_speed)
-
-# 创建全局实例
-message_sender_control = MessageSendControl()
diff --git a/src/plugins/chat/del.message_stream.py b/src/plugins/chat/del.message_stream.py
deleted file mode 100644
index 07809caa..00000000
--- a/src/plugins/chat/del.message_stream.py
+++ /dev/null
@@ -1,271 +0,0 @@
-from typing import List, Optional, Dict
-from .message import Message
-import time
-from collections import deque
-from datetime import datetime, timedelta
-import os
-import json
-import asyncio
-
-class MessageStream:
- """单个群组的消息流容器"""
- def __init__(self, group_id: int, max_size: int = 1000):
- self.group_id = group_id
- self.messages = deque(maxlen=max_size)
- self.max_size = max_size
- self.last_save_time = time.time()
-
- # 确保日志目录存在
- self.log_dir = os.path.join("log", str(self.group_id))
- os.makedirs(self.log_dir, exist_ok=True)
-
- # 启动自动保存任务
- asyncio.create_task(self._auto_save())
-
- async def _auto_save(self):
- """每30秒自动保存一次消息记录"""
- while True:
- await asyncio.sleep(30) # 等待30秒
- await self.save_to_log()
-
- async def save_to_log(self):
- """将消息保存到日志文件"""
- try:
- current_time = time.time()
- # 只有有新消息时才保存
- if not self.messages or self.last_save_time == current_time:
- return
-
- # 生成日志文件名 (使用当前日期)
- date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
- log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")
-
- # 获取需要保存的新消息
- new_messages = [
- msg for msg in self.messages
- if msg.time > self.last_save_time
- ]
-
- if not new_messages:
- return
-
- # 将消息转换为可序列化的格式
- message_logs = []
- for msg in new_messages:
- message_logs.append({
- "time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
- "user_id": msg.user_id,
- "user_nickname": msg.user_nickname,
- "user_cardname": msg.user_cardname,
- "message_id": msg.message_id,
- "raw_message": msg.raw_message,
- "processed_text": msg.processed_plain_text
- })
-
- # 追加写入日志文件
- with open(log_file, "a", encoding="utf-8") as f:
- for log in message_logs:
- f.write(json.dumps(log, ensure_ascii=False) + "\n")
-
- self.last_save_time = current_time
-
- except Exception as e:
- print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}")
-
- def add_message(self, message: Message) -> None:
- """按时间顺序添加新消息到队列
-
- 使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。
-
- Args:
- message: Message对象,要添加的新消息
- """
-
- # 空队列或消息应该添加到末尾的情况
- if (not self.messages or
- message.time >= self.messages[-1].time):
- self.messages.append(message)
- return
-
- # 消息应该添加到开头的情况
- if message.time <= self.messages[0].time:
- self.messages.appendleft(message)
- return
-
- # 使用二分查找在现有队列中找到合适的插入位置
- left, right = 0, len(self.messages) - 1
- while left <= right:
- mid = (left + right) // 2
- if self.messages[mid].time < message.time:
- left = mid + 1
- else:
- right = mid - 1
-
- temp = list(self.messages)
- temp.insert(left, message)
-
- # 如果超出最大长度,移除多余的消息
- if len(temp) > self.max_size:
- temp = temp[-self.max_size:]
-
- # 重建队列
- self.messages = deque(temp, maxlen=self.max_size)
-
- async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
- """从数据库中获取最近的消息记录
-
- Args:
- count: 需要获取的消息数量
-
- Returns:
- List[Message]: 最近的消息列表
- """
- try:
- from ...common.database import Database
- db = Database.get_instance()
-
- # 从数据库中查询最近的消息
- recent_messages = list(db.db.messages.find(
- {"group_id": self.group_id},
- # {
- # "time": 1,
- # "user_id": 1,
- # "user_nickname": 1,
- # # "user_cardname": 1,
- # "message_id": 1,
- # "raw_message": 1,
- # "processed_text": 1
- # }
- ).sort("time", -1).limit(count))
-
- if not recent_messages:
- return []
-
- # 转换为 Message 对象
- from .message import Message
- messages = []
- for msg_data in recent_messages:
- try:
- msg = Message(
- time=msg_data["time"],
- user_id=msg_data["user_id"],
- user_nickname=msg_data.get("user_nickname", ""),
- user_cardname=msg_data.get("user_cardname", ""),
- message_id=msg_data["message_id"],
- raw_message=msg_data["raw_message"],
- processed_plain_text=msg_data.get("processed_text", ""),
- group_id=self.group_id
- )
- messages.append(msg)
- except KeyError:
- print("[WARNING] 数据库中存在无效的消息")
- continue
-
- return list(reversed(messages)) # 返回按时间正序的消息
-
- except Exception as e:
- print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}")
- return []
-
- def get_recent_messages(self, count: int = 10) -> List[Message]:
- """获取最近的n条消息(从内存队列)"""
- print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录")
- return list(self.messages)[-count:]
-
- def get_messages_in_timerange(self,
- start_time: Optional[float] = None,
- end_time: Optional[float] = None) -> List[Message]:
- """获取时间范围内的消息"""
- if start_time is None:
- start_time = time.time() - 3600
- if end_time is None:
- end_time = time.time()
-
- return [
- msg for msg in self.messages
- if start_time <= msg.time <= end_time
- ]
-
- def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
- """获取特定用户的最近消息"""
- user_messages = [msg for msg in self.messages if msg.user_id == user_id]
- return user_messages[-count:]
-
- def clear_old_messages(self, hours: int = 24) -> None:
- """清理旧消息"""
- cutoff_time = time.time() - (hours * 3600)
- self.messages = deque(
- [msg for msg in self.messages if msg.time > cutoff_time],
- maxlen=self.max_size
- )
-
-class MessageStreamContainer:
- """管理所有群组的消息流容器"""
- def __init__(self, max_size: int = 1000):
- self.streams: Dict[int, MessageStream] = {}
- self.max_size = max_size
-
- async def save_all_logs(self):
- """保存所有群组的消息日志"""
- for stream in self.streams.values():
- await stream.save_to_log()
-
- def add_message(self, message: Message) -> None:
- """添加消息到对应群组的消息流"""
- if not message.group_id:
- return
-
- if message.group_id not in self.streams:
- self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)
-
- self.streams[message.group_id].add_message(message)
-
- def get_stream(self, group_id: int) -> Optional[MessageStream]:
- """获取特定群组的消息流"""
- return self.streams.get(group_id)
-
- def get_all_streams(self) -> Dict[int, MessageStream]:
- """获取所有群组的消息流"""
- return self.streams
-
- def clear_old_messages(self, hours: int = 24) -> None:
- """清理所有群组的旧消息"""
- for stream in self.streams.values():
- stream.clear_old_messages(hours)
-
- def get_group_stats(self, group_id: int) -> Dict:
- """获取群组的消息统计信息"""
- stream = self.streams.get(group_id)
- if not stream:
- return {
- "total_messages": 0,
- "unique_users": 0,
- "active_hours": [],
- "most_active_user": None
- }
-
- messages = stream.messages
- user_counts = {}
- hour_counts = {}
-
- for msg in messages:
- user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
- hour = datetime.fromtimestamp(msg.time).hour
- hour_counts[hour] = hour_counts.get(hour, 0) + 1
-
- most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
- active_hours = sorted(
- hour_counts.items(),
- key=lambda x: x[1],
- reverse=True
- )[:5]
-
- return {
- "total_messages": len(messages),
- "unique_users": len(user_counts),
- "active_hours": active_hours,
- "most_active_user": most_active_user
- }
-
-# 创建全局实例
-message_stream_container = MessageStreamContainer()
diff --git a/src/plugins/chat/del.message_visualizer.py b/src/plugins/chat/del.message_visualizer.py
deleted file mode 100644
index 0469af8f..00000000
--- a/src/plugins/chat/del.message_visualizer.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import subprocess
-import threading
-import queue
-import os
-import time
-from typing import Dict
-from .message import Message_Thinking
-
-class MessageVisualizer:
- def __init__(self):
- self.process = None
- self.message_queue = queue.Queue()
- self.is_running = False
- self.content_file = "message_queue_content.txt"
-
- def start(self):
- if self.process is None:
- # 创建用于显示的批处理文件
- with open("message_queue_window.bat", "w", encoding="utf-8") as f:
- f.write('@echo off\n')
- f.write('chcp 65001\n') # 设置UTF-8编码
- f.write('title Message Queue Visualizer\n')
- f.write('echo Waiting for message queue updates...\n')
- f.write(':loop\n')
- f.write('if exist "queue_update.txt" (\n')
- f.write(' type "queue_update.txt" > "message_queue_content.txt"\n')
- f.write(' del "queue_update.txt"\n')
- f.write(' cls\n')
- f.write(' type "message_queue_content.txt"\n')
- f.write(')\n')
- f.write('timeout /t 1 /nobreak >nul\n')
- f.write('goto loop\n')
-
- # 清空内容文件
- with open(self.content_file, "w", encoding="utf-8") as f:
- f.write("")
-
- # 启动新窗口
- startupinfo = subprocess.STARTUPINFO()
- startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
- self.process = subprocess.Popen(
- ['cmd', '/c', 'start', 'message_queue_window.bat'],
- shell=True,
- startupinfo=startupinfo
- )
- self.is_running = True
-
- # 启动处理线程
- threading.Thread(target=self._process_messages, daemon=True).start()
-
- def _process_messages(self):
- while self.is_running:
- try:
- # 获取新消息
- text = self.message_queue.get(timeout=1)
- # 写入更新文件
- with open("queue_update.txt", "w", encoding="utf-8") as f:
- f.write(text)
- except queue.Empty:
- continue
- except Exception as e:
- print(f"处理队列可视化内容时出错: {e}")
-
- def update_content(self, send_temp_container):
- """更新显示内容"""
- if not self.is_running:
- return
-
- current_time = time.strftime("%Y-%m-%d %H:%M:%S")
- display_text = f"Message Queue Status - {current_time}\n"
- display_text += "=" * 50 + "\n\n"
-
- # 遍历所有群组的队列
- for group_id, queue in send_temp_container.temp_queues.items():
- display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n"
- display_text += f"消息队列长度: {len(queue.messages)}\n"
- display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
- display_text += "\n消息队列内容:\n"
-
- # 显示队列中的消息
- if not queue.messages:
- display_text += " [空队列]\n"
- else:
- for i, msg in enumerate(queue.messages):
- msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
- display_text += f"\n--- 消息 {i+1} ---\n"
-
- if isinstance(msg, Message_Thinking):
- display_text += f"类型: \033[1;33m思考中消息\033[0m\n"
- display_text += f"时间: {msg_time}\n"
- display_text += f"消息ID: {msg.message_id}\n"
- display_text += f"群组: {msg.group_id}\n"
- display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
- display_text += f"内容: {msg.thinking_text}\n"
- display_text += f"思考时间: {int(msg.thinking_time)}秒\n"
- else:
- display_text += f"类型: 普通消息\n"
- display_text += f"时间: {msg_time}\n"
- display_text += f"消息ID: {msg.message_id}\n"
- display_text += f"群组: {msg.group_id}\n"
- display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
- if hasattr(msg, 'is_emoji') and msg.is_emoji:
- display_text += f"内容: [表情包消息]\n"
- else:
- # 显示原始消息和处理后的消息
- display_text += f"原始内容: {msg.raw_message[:50]}...\n"
- display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n"
-
- if msg.reply_message:
- display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n"
-
- display_text += f"\n{'-' * 50}\n"
-
- # 添加统计信息
- display_text += "\n总体统计:\n"
- display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n"
- total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
- display_text += f"总消息数: {total_messages}\n"
- thinking_messages = sum(
- sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
- for q in send_temp_container.temp_queues.values()
- )
- display_text += f"思考中消息数: {thinking_messages}\n"
-
- self.message_queue.put(display_text)
-
- def stop(self):
- self.is_running = False
- if self.process:
- self.process.terminate()
- self.process = None
- # 清理文件
- for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
- if os.path.exists(file):
- os.remove(file)
-
-# 创建全局单例
-message_visualizer = MessageVisualizer()
diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py
index 034ff734..04f2e73a 100644
--- a/src/plugins/chat/llm_generator.py
+++ b/src/plugins/chat/llm_generator.py
@@ -95,7 +95,11 @@ class ResponseGenerator:
# return None
# 生成回复
- content, reasoning_content = await model.generate_response(prompt)
+ try:
+ content, reasoning_content = await model.generate_response(prompt)
+ except Exception as e:
+ print(f"生成回复时出错: {e}")
+ return None
# 保存到数据库
self._save_to_db(
@@ -138,9 +142,12 @@ class ResponseGenerator:
内容:{content}
输出:
'''
-
content, _ = await self.model_v3.generate_response(prompt)
- return [content.strip()] if content else ["neutral"]
+ content=content.strip()
+ if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
+ return [content]
+ else:
+ return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py
index 99858694..970fd368 100644
--- a/src/plugins/chat/message_sender.py
+++ b/src/plugins/chat/message_sender.py
@@ -52,12 +52,16 @@ class Message_Sender:
await asyncio.sleep(typing_time)
# 发送消息
- await self._current_bot.send_group_msg(
- group_id=group_id,
- message=message,
- auto_escape=auto_escape
- )
- print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
+ try:
+ await self._current_bot.send_group_msg(
+ group_id=group_id,
+ message=message,
+ auto_escape=auto_escape
+ )
+ print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
+ except Exception as e:
+ print(f"发生错误 {e}")
+ print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer:
diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py
index c354631c..ba22a403 100644
--- a/src/plugins/chat/prompt_builder.py
+++ b/src/plugins/chat/prompt_builder.py
@@ -36,7 +36,9 @@ class PromptBuilder:
memory_prompt = ''
start_time = time.time() # 记录开始时间
- topic = topic_identifier.identify_topic_jieba(message_txt)
+ # topic = await topic_identifier.identify_topic_llm(message_txt)
+ topic = topic_identifier.identify_topic_snownlp(message_txt)
+
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
all_first_layer_items = [] # 存储所有第一层记忆
@@ -64,15 +66,7 @@ class PromptBuilder:
if overlap:
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
overlapping_second_layer.update(overlap)
-
- # 合并所有需要的记忆
- # if all_first_layer_items:
- # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
- # if overlapping_second_layer:
- # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
-
- # 使用集合去重
- # 从每个来源随机选择2条记忆(如果有的话)
+
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py
index 07749e83..812d4e32 100644
--- a/src/plugins/chat/topic_identifier.py
+++ b/src/plugins/chat/topic_identifier.py
@@ -15,16 +15,6 @@ class TopicIdentifier:
self.llm_client = LLM_request(model=global_config.llm_topic_extract)
self.select=global_config.topic_extract
- def identify_topic(self):
- if self.select=='jieba':
- return self.identify_topic_jieba
- elif self.select=='snownlp':
- return self.identify_topic_snownlp
- elif self.select=='llm':
- return self.identify_topic_llm
- else:
- return self.identify_topic_snownlp
-
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
@@ -48,56 +38,10 @@ class TopicIdentifier:
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
+
+ print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
return topic_list if topic_list else None
- def identify_topic_jieba(self, text: str) -> Optional[str]:
- """使用jieba识别主题"""
- words = jieba.lcut(text)
- # 去除停用词和标点符号
- stop_words = {
- '的', '了', '和', '是', '就', '都', '而', '及', '与', '这', '那', '但', '然', '却',
- '因为', '所以', '如果', '虽然', '一个', '我', '你', '他', '她', '它', '我们', '你们',
- '他们', '在', '有', '个', '把', '被', '让', '给', '从', '向', '到', '又', '也', '很',
- '啊', '吧', '呢', '吗', '呀', '哦', '哈', '么', '嘛', '啦', '哎', '唉', '哇', '嗯',
- '哼', '哪', '什么', '怎么', '为什么', '怎样', '如何', '什么样', '这样', '那样', '这么',
- '那么', '多少', '几', '谁', '哪里', '哪儿', '什么时候', '何时', '为何', '怎么办',
- '怎么样', '这些', '那些', '一些', '一点', '一下', '一直', '一定', '一般', '一样',
- '一会儿', '一边', '一起',
- # 添加更多量词
- '个', '只', '条', '张', '片', '块', '本', '册', '页', '幅', '面', '篇', '份',
- '朵', '颗', '粒', '座', '幢', '栋', '间', '层', '家', '户', '位', '名', '群',
- '双', '对', '打', '副', '套', '批', '组', '串', '包', '箱', '袋', '瓶', '罐',
- # 添加更多介词
- '按', '按照', '把', '被', '比', '比如', '除', '除了', '当', '对', '对于',
- '根据', '关于', '跟', '和', '将', '经', '经过', '靠', '连', '论', '通过',
- '同', '往', '为', '为了', '围绕', '于', '由', '由于', '与', '在', '沿', '沿着',
- '依', '依照', '以', '因', '因为', '用', '由', '与', '自', '自从'
- }
-
- # 过滤掉停用词和标点符号,只保留名词和动词
- filtered_words = []
- for word in words:
- if word not in stop_words and not word.strip() in {
- '。', ',', '、', ':', ';', '!', '?', '"', '"', ''', ''',
- '(', ')', '【', '】', '《', '》', '…', '—', '·', '、', '~',
- '~', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
- '^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
- ';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷', '≠',
- '≈', '∈', '∉', '⊆', '⊇', '⊂', '⊃', '∪', '∩', '∧', '∨'
- }:
- filtered_words.append(word)
-
- # 统计词频
- word_freq = {}
- for word in filtered_words:
- word_freq[word] = word_freq.get(word, 0) + 1
-
- # 按词频排序,取前3个
- sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
- top_words = [word for word, freq in sorted_words[:3]]
-
- return top_words if top_words else None
-
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
"""使用 SnowNLP 进行主题识别
@@ -113,7 +57,7 @@ class TopicIdentifier:
try:
s = SnowNLP(text)
# 提取前3个关键词作为主题
- keywords = s.keywords(3)
+ keywords = s.keywords(5)
return keywords if keywords else None
except Exception as e:
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py
index 63151592..aa16268e 100644
--- a/src/plugins/chat/utils.py
+++ b/src/plugins/chat/utils.py
@@ -75,13 +75,11 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2)
-def calculate_information_content(text):
+def calculate_information_content(text):
"""计算文本的信息量(熵)"""
- # 统计字符频率
char_count = Counter(text)
total_chars = len(text)
- # 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
@@ -90,27 +88,37 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
- # 从数据库中根据时间戳获取离其最近的聊天记录
+ """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
- closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
- # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
+ closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
- if closest_record:
+ if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息,且groupid相同
- chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
- for record in chat_record:
- time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
- try:
- displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
- except:
- displayname=record["user_nickname"] or "用户" + str(record["user_id"])
- chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
+ chat_records = list(db.db.messages.find(
+ {"time": {"$gt": closest_time}, "group_id": group_id}
+ ).sort('time', 1).limit(length))
+
+ # 更新每条消息的memorized属性
+ for record in chat_records:
+ # 检查当前记录的memorized值
+ current_memorized = record.get('memorized', 0)
+ if current_memorized > 3:
+ # print(f"消息已读取3次,跳过")
+ return ''
+
+ # 更新memorized值
+ db.db.messages.update_one(
+ {"_id": record["_id"]},
+ {"$set": {"memorized": current_memorized + 1}}
+ )
+
+ chat_text += record["detailed_plain_text"]
+
return chat_text
-
- return [] # 如果没有找到记录,返回空列表
-
+ print(f"消息已读取3次,跳过")
+ return ''
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py
index efe2f1c9..922ab522 100644
--- a/src/plugins/chat/utils_image.py
+++ b/src/plugins/chat/utils_image.py
@@ -7,6 +7,7 @@ from ...common.database import Database
import zlib # 用于 CRC32
import base64
from nonebot import get_driver
+from loguru import logger
driver = get_driver()
config = driver.config
@@ -213,11 +214,11 @@ def storage_image(image_data: bytes) -> bytes:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
-def compress_base64_image_by_scale(base64_data: str, scale: float = 0.5) -> str:
- """按比例压缩base64格式的图片
+def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
+ """压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
- scale: 压缩比例(0-1之间的浮点数)
+ target_size: 目标文件大小(字节),默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
@@ -225,34 +226,64 @@ def compress_base64_image_by_scale(base64_data: str, scale: float = 0.5) -> str:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
+ # 如果已经小于目标大小,直接返回原图
+ if len(image_data) <= target_size:
+ return base64_data
+
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
- # 如果是动图,直接返回原图
- if getattr(img, 'is_animated', False):
- return base64_data
-
+ # 获取原始尺寸
+ original_width, original_height = img.size
+
+ # 计算缩放比例
+ scale = min(1.0, (target_size / len(image_data)) ** 0.5)
+
# 计算新的尺寸
- new_width = int(img.width * scale)
- new_height = int(img.height * scale)
+ new_width = int(original_width * scale)
+ new_height = int(original_height * scale)
- # 缩放图片
- img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+ # 创建内存缓冲区
+ output_buffer = io.BytesIO()
- # 转换为RGB模式(去除透明通道)
- if img.mode in ('RGBA', 'P'):
- img = img.convert('RGB')
+ # 如果是GIF,处理所有帧
+ if getattr(img, "is_animated", False):
+ frames = []
+ for frame_idx in range(img.n_frames):
+ img.seek(frame_idx)
+ new_frame = img.copy()
+ new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
+ frames.append(new_frame)
+
+ # 保存到缓冲区
+ frames[0].save(
+ output_buffer,
+ format='GIF',
+ save_all=True,
+ append_images=frames[1:],
+ optimize=True,
+ duration=img.info.get('duration', 100),
+ loop=img.info.get('loop', 0)
+ )
+ else:
+ # 处理静态图片
+ resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # 保存到缓冲区,保持原始格式
+ if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
+ resized_img.save(output_buffer, format='PNG', optimize=True)
+ else:
+ resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
- # 保存压缩后的图片
- output = io.BytesIO()
- img.save(output, format='JPEG', quality=85, optimize=True)
- compressed_data = output.getvalue()
+ # 获取压缩后的数据并转换为base64
+ compressed_data = output_buffer.getvalue()
+ logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
+ logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
- # 转换回base64
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
- print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
+ logger.error(f"压缩图片失败: {str(e)}")
import traceback
- print(traceback.format_exc())
+ logger.error(traceback.format_exc())
return base64_data
\ No newline at end of file
diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py
index f90889f7..7559406f 100644
--- a/src/plugins/chat/willing_manager.py
+++ b/src/plugins/chat/willing_manager.py
@@ -9,7 +9,7 @@ class WillingManager:
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
- await asyncio.sleep(3)
+ await asyncio.sleep(5)
for group_id in self.group_reply_willing:
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
@@ -39,11 +39,11 @@ class WillingManager:
if interested_rate > 0.65:
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
- current_willing += interested_rate-0.5
+ current_willing += interested_rate-0.6
self.group_reply_willing[group_id] = min(current_willing, 3.0)
- reply_probability = max((current_willing - 0.5) * 2, 0)
+ reply_probability = max((current_willing - 0.55) * 1.9, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
@@ -52,8 +52,8 @@ class WillingManager:
reply_probability = reply_probability / 3.5
reply_probability = min(reply_probability, 1)
- if reply_probability < 0.1:
- reply_probability = 0.1
+ if reply_probability < 0:
+ reply_probability = 0
return reply_probability
def change_reply_willing_sent(self, group_id: int):
@@ -65,7 +65,7 @@ class WillingManager:
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
- self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
+ self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""
diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py
index cd6122b4..d7071985 100644
--- a/src/plugins/knowledege/knowledge_library.py
+++ b/src/plugins/knowledege/knowledge_library.py
@@ -79,7 +79,7 @@ class KnowledgeLibrary:
content = f.read()
# 按1024字符分段
- segments = [content[i:i+400] for i in range(0, len(content), 400)]
+ segments = [content[i:i+600] for i in range(0, len(content), 600)]
# 处理每个分段
for segment in segments:
diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py
index ddb11d57..fad3f5f3 100644
--- a/src/plugins/memory_system/draw_memory.py
+++ b/src/plugins/memory_system/draw_memory.py
@@ -22,63 +22,6 @@ from src.common.database import Database # 使用正确的导入语法
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
-class LLMModel:
- def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
- self.model_name = model_name
- self.params = kwargs
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
-
- async def generate_response(self, prompt: str) -> Tuple[str, str]:
- """根据输入的提示生成模型的响应"""
- headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json"
- }
-
- # 构建请求体
- data = {
- "model": self.model_name,
- "messages": [{"role": "user", "content": prompt}],
- "temperature": 0.5,
- **self.params
- }
-
- # 发送请求到完整的chat/completions端点
- api_url = f"{self.base_url.rstrip('/')}/chat/completions"
-
- max_retries = 3
- base_wait_time = 15
-
- for retry in range(max_retries):
- try:
- async with aiohttp.ClientSession() as session:
- async with session.post(api_url, headers=headers, json=data) as response:
- if response.status == 429:
- wait_time = base_wait_time * (2 ** retry) # 指数退避
- print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
- await asyncio.sleep(wait_time)
- continue
-
- response.raise_for_status() # 检查其他响应状态
-
- result = await response.json()
- if "choices" in result and len(result["choices"]) > 0:
- content = result["choices"][0]["message"]["content"]
- reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
- return content, reasoning_content
- return "没有返回结果", ""
-
- except Exception as e:
- if retry < max_retries - 1: # 如果还有重试机会
- wait_time = base_wait_time * (2 ** retry)
- print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
- await asyncio.sleep(wait_time)
- else:
- return f"请求失败: {str(e)}", ""
-
- return "达到最大重试次数,请求仍然失败", ""
-
class Memory_graph:
def __init__(self):
@@ -232,19 +175,10 @@ def main():
)
memory_graph = Memory_graph()
- # 创建LLM模型实例
-
memory_graph.load_graph_from_db()
- # 展示两种不同的可视化方式
- print("\n按连接数量着色的图谱:")
- # visualize_graph(memory_graph, color_by_memory=False)
- visualize_graph_lite(memory_graph, color_by_memory=False)
- print("\n按记忆数量着色的图谱:")
- # visualize_graph(memory_graph, color_by_memory=True)
- visualize_graph_lite(memory_graph, color_by_memory=True)
-
- # memory_graph.save_graph_to_db()
+ # 只显示一次优化后的图形
+ visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
@@ -327,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
- node_size=2000,
+ node_size=200,
font_size=10,
font_family='SimHei',
font_weight='bold')
@@ -353,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
- if memory_count <= 2 or degree <= 2:
+ if memory_count < 5 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
@@ -366,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
- # 根据连接条数或记忆数量设置节点颜色
+ # 计算节点大小和颜色
node_colors = []
- nodes = list(H.nodes()) # 获取图中实际的节点列表
+ node_sizes = []
+ nodes = list(H.nodes())
- if color_by_memory:
- # 计算每个节点的记忆数量
- memory_counts = []
- for node in nodes:
- memory_items = H.nodes[node].get('memory_items', [])
- if isinstance(memory_items, list):
- count = len(memory_items)
- else:
- count = 1 if memory_items else 0
- memory_counts.append(count)
- max_memories = max(memory_counts) if memory_counts else 1
+ # 获取最大记忆数和最大度数用于归一化
+ max_memories = 1
+ max_degree = 1
+ for node in nodes:
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ degree = H.degree(node)
+ max_memories = max(max_memories, memory_count)
+ max_degree = max(max_degree, degree)
+
+ # 计算每个节点的大小和颜色
+ for node in nodes:
+ # 计算节点大小(基于记忆数量)
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ # 使用指数函数使变化更明显
+ ratio = memory_count / max_memories
+ size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显
+ node_sizes.append(size)
- for count in memory_counts:
- # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
- if max_memories > 0:
- intensity = min(1.0, count / max_memories)
- color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
- else:
- color = (0, 0, 1) # 如果没有记忆,则为蓝色
- node_colors.append(color)
- else:
- # 使用原来的连接数量着色方案
- max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
- for node in nodes:
- degree = H.degree(node)
- if max_degree > 0:
- red = min(1.0, degree / max_degree)
- blue = 1.0 - red
- color = (red, 0, blue)
- else:
- color = (0, 0, 1)
- node_colors.append(color)
+ # 计算节点颜色(基于连接数)
+ degree = H.degree(node)
+ # 红色分量随着度数增加而增加
+ red = min(1.0, degree / max_degree)
+ # 蓝色分量随着度数减少而增加
+ blue = 1.0 - red
+ color = (red, 0, blue)
+ node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
- pos = nx.spring_layout(H, k=1, iterations=50)
+ pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
- node_size=2000,
+ node_size=node_sizes,
font_size=10,
font_family='SimHei',
- font_weight='bold')
+ font_weight='bold',
+ edge_color='gray',
+ width=0.5,
+ alpha=0.7)
- title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
+ title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py
index e0095dad..4d20d05a 100644
--- a/src/plugins/memory_system/memory.py
+++ b/src/plugins/memory_system/memory.py
@@ -9,15 +9,26 @@ import random
import time
from ..chat.config import global_config
from ...common.database import Database # 使用正确的导入语法
-from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
from ..models.utils_model import LLM_request
+import math
+from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
+
+
+
+
+
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
- self.G.add_edge(concept1, concept2)
+ # 如果边已存在,增加 strength
+ if self.G.has_edge(concept1, concept2):
+ self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
+ else:
+ # 如果是新边,初始化 strength 为 1
+ self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
@@ -38,9 +49,7 @@ class Memory_graph:
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
- # print(node_data)
- # 创建新的Memory_dot对象
- return concept,node_data
+ return concept, node_data
return None
def get_related_item(self, topic, depth=1):
@@ -52,7 +61,6 @@ class Memory_graph:
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
- # print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
@@ -69,7 +77,6 @@ class Memory_graph:
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
- # print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
@@ -87,87 +94,59 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
- def save_graph_to_db(self):
- # 保存节点
- for node in self.G.nodes(data=True):
- concept = node[0]
- memory_items = node[1].get('memory_items', [])
+ def forget_topic(self, topic):
+ """随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点"""
+ if topic not in self.G:
+ return None
- # 查找是否存在同名节点
- existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
- if existing_node:
- # 如果存在,合并memory_items并去重
- existing_items = existing_node.get('memory_items', [])
- if not isinstance(existing_items, list):
- existing_items = [existing_items] if existing_items else []
-
- # 合并并去重
- all_items = list(set(existing_items + memory_items))
-
- # 更新节点
- self.db.db.graph_data.nodes.update_one(
- {'concept': concept},
- {'$set': {'memory_items': all_items}}
- )
- else:
- # 如果不存在,创建新节点
- node_data = {
- 'concept': concept,
- 'memory_items': memory_items
- }
- self.db.db.graph_data.nodes.insert_one(node_data)
+ # 获取话题节点数据
+ node_data = self.G.nodes[topic]
- # 保存边
- for edge in self.G.edges():
- source, target = edge
+ # 如果节点存在memory_items
+ if 'memory_items' in node_data:
+ memory_items = node_data['memory_items']
- # 查找是否存在同样的边
- existing_edge = self.db.db.graph_data.edges.find_one({
- 'source': source,
- 'target': target
- })
-
- if existing_edge:
- # 如果存在,增加num属性
- num = existing_edge.get('num', 1) + 1
- self.db.db.graph_data.edges.update_one(
- {'source': source, 'target': target},
- {'$set': {'num': num}}
- )
- else:
- # 如果不存在,创建新边
- edge_data = {
- 'source': source,
- 'target': target,
- 'num': 1
- }
- self.db.db.graph_data.edges.insert_one(edge_data)
-
- def load_graph_from_db(self):
- # 清空当前图
- self.G.clear()
- # 加载节点
- nodes = self.db.db.graph_data.nodes.find()
- for node in nodes:
- memory_items = node.get('memory_items', [])
+ # 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
- self.G.add_node(node['concept'], memory_items=memory_items)
- # 加载边
- edges = self.db.db.graph_data.edges.find()
- for edge in edges:
- self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
-
-
-
+
+ # 如果有记忆项可以删除
+ if memory_items:
+ # 随机选择一个记忆项删除
+ removed_item = random.choice(memory_items)
+ memory_items.remove(removed_item)
+
+ # 更新节点的记忆项
+ if memory_items:
+ self.G.nodes[topic]['memory_items'] = memory_items
+ else:
+ # 如果没有记忆项了,删除整个节点
+ self.G.remove_node(topic)
+
+ return removed_item
+
+ return None
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
- self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
- self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
+ self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
+ self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5)
+
+ def calculate_node_hash(self, concept, memory_items):
+ """计算节点的特征值"""
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ sorted_items = sorted(memory_items)
+ content = f"{concept}:{'|'.join(sorted_items)}"
+ return hash(content)
+
+ def calculate_edge_hash(self, source, target):
+ """计算边的特征值"""
+ nodes = sorted([source, target])
+ return hash(f"{nodes[0]}:{nodes[1]}")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
@@ -175,82 +154,340 @@ class Hippocampus:
#短期:1h 中期:4h 长期:24h
for _ in range(time_frequency.get('near')): # 循环10次
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
- # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 循环10次
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
- # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 循环10次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
- # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
- return chat_text
+ return [text for text in chat_text if text]
- async def memory_compress(self, input_text, rate=1):
- information_content = calculate_information_content(input_text)
- print(f"文本的信息量(熵): {information_content:.4f} bits")
- topic_num = max(1, min(5, int(information_content * rate / 4)))
- topic_prompt = find_topic(input_text, topic_num)
- topic_response = await self.llm_model.generate_response(topic_prompt)
- # 检查 topic_response 是否为元组
- if isinstance(topic_response, tuple):
- topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
- else:
- topics = topic_response.split(",")
- compressed_memory = set()
+ async def memory_compress(self, input_text, compress_rate=0.1):
+ print(input_text)
+
+ #获取topics
+ topic_num = self.calculate_topic_num(input_text, compress_rate)
+ topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
+ # 修改话题处理逻辑
+ print(f"话题: {topics_response[0]}")
+ topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
+ print(f"话题: {topics}")
+
+ # 创建所有话题的请求任务
+ tasks = []
for topic in topics:
- topic_what_prompt = topic_what(input_text,topic)
- topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
- compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
+ topic_what_prompt = self.topic_what(input_text, topic)
+ # 创建异步任务
+ task = self.llm_model_summary.generate_response_async(topic_what_prompt)
+ tasks.append((topic.strip(), task))
+
+ # 等待所有任务完成
+ compressed_memory = set()
+ for topic, task in tasks:
+ response = await task
+ if response:
+ compressed_memory.add((topic, response[0]))
+
return compressed_memory
-
- async def build_memory(self,chat_size=12):
- #最近消息获取频率
- time_frequency = {'near':1,'mid':2,'far':2}
+
+ def calculate_topic_num(self,text, compress_rate):
+ """计算文本的话题数量"""
+ information_content = calculate_information_content(text)
+ topic_by_length = text.count('\n')*compress_rate
+ topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
+ topic_num = int((topic_by_length + topic_by_information_content)/2)
+ print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
+ return topic_num
+
+ async def operation_build_memory(self,chat_size=20):
+ # 最近消息获取频率
+ time_frequency = {'near':2,'mid':4,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
- # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
+
for i, input_text in enumerate(memory_sample, 1):
- #加载进度可视化
+ # 加载进度可视化
+ all_topics = []
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
- if input_text:
- # 生成压缩后记忆
- first_memory = set()
- first_memory = await self.memory_compress(input_text, 2.5)
- #将记忆加入到图谱中
- for topic, memory in first_memory:
- topics = segment_text(topic)
- print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
- for split_topic in topics:
- self.memory_graph.add_dot(split_topic,memory)
- for split_topic in topics:
- for other_split_topic in topics:
- if split_topic != other_split_topic:
- self.memory_graph.connect_dot(split_topic, other_split_topic)
+
+ # 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
+ compressed_memory = set()
+ compress_rate = 0.1
+ compressed_memory = await self.memory_compress(input_text, compress_rate)
+ print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
+
+ # 将记忆加入到图谱中
+ for topic, memory in compressed_memory:
+ print(f"\033[1;32m添加节点\033[0m: {topic}")
+ self.memory_graph.add_dot(topic, memory)
+ all_topics.append(topic) # 收集所有话题
+ for i in range(len(all_topics)):
+ for j in range(i + 1, len(all_topics)):
+ print(f"\033[1;32m连接节点\033[0m: {all_topics[i]} 和 {all_topics[j]}")
+ self.memory_graph.connect_dot(all_topics[i], all_topics[j])
+
+ self.sync_memory_to_db()
+
+ def sync_memory_to_db(self):
+ """检查并同步内存中的图结构与数据库"""
+ # 获取数据库中所有节点和内存中所有节点
+ db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
+ memory_nodes = list(self.memory_graph.G.nodes(data=True))
+
+ # 转换数据库节点为字典格式,方便查找
+ db_nodes_dict = {node['concept']: node for node in db_nodes}
+
+ # 检查并更新节点
+ for concept, data in memory_nodes:
+ memory_items = data.get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+
+ # 计算内存中节点的特征值
+ memory_hash = self.calculate_node_hash(concept, memory_items)
+
+ if concept not in db_nodes_dict:
+ # 数据库中缺少的节点,添加
+ node_data = {
+ 'concept': concept,
+ 'memory_items': memory_items,
+ 'hash': memory_hash
+ }
+ self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
- print(f"空消息 跳过")
- self.memory_graph.save_graph_to_db()
+ # 获取数据库中节点的特征值
+ db_node = db_nodes_dict[concept]
+ db_hash = db_node.get('hash', None)
+
+ # 如果特征值不同,则更新节点
+ if db_hash != memory_hash:
+ self.memory_graph.db.db.graph_data.nodes.update_one(
+ {'concept': concept},
+ {'$set': {
+ 'memory_items': memory_items,
+ 'hash': memory_hash
+ }}
+ )
+
+ # 检查并删除数据库中多余的节点
+ memory_concepts = set(node[0] for node in memory_nodes)
+ for db_node in db_nodes:
+ if db_node['concept'] not in memory_concepts:
+ self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
+
+ # 处理边的信息
+ db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
+ memory_edges = list(self.memory_graph.G.edges())
+
+ # 创建边的哈希值字典
+ db_edge_dict = {}
+ for edge in db_edges:
+ edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
+ db_edge_dict[(edge['source'], edge['target'])] = {
+ 'hash': edge_hash,
+ 'strength': edge.get('strength', 1)
+ }
+
+ # 检查并更新边
+ for source, target in memory_edges:
+ edge_hash = self.calculate_edge_hash(source, target)
+ edge_key = (source, target)
+ strength = self.memory_graph.G[source][target].get('strength', 1)
+
+ if edge_key not in db_edge_dict:
+ # 添加新边
+ edge_data = {
+ 'source': source,
+ 'target': target,
+ 'strength': strength,
+ 'hash': edge_hash
+ }
+ self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
+ else:
+ # 检查边的特征值是否变化
+ if db_edge_dict[edge_key]['hash'] != edge_hash:
+ self.memory_graph.db.db.graph_data.edges.update_one(
+ {'source': source, 'target': target},
+ {'$set': {
+ 'hash': edge_hash,
+ 'strength': strength
+ }}
+ )
+
+ # 删除多余的边
+ memory_edge_set = set(memory_edges)
+ for edge_key in db_edge_dict:
+ if edge_key not in memory_edge_set:
+ source, target = edge_key
+ self.memory_graph.db.db.graph_data.edges.delete_one({
+ 'source': source,
+ 'target': target
+ })
+
+ def sync_memory_from_db(self):
+ """从数据库同步数据到内存中的图结构"""
+ # 清空当前图
+ self.memory_graph.G.clear()
+
+ # 从数据库加载所有节点
+ nodes = self.memory_graph.db.db.graph_data.nodes.find()
+ for node in nodes:
+ concept = node['concept']
+ memory_items = node.get('memory_items', [])
+ # 确保memory_items是列表
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ # 添加节点到图中
+ self.memory_graph.G.add_node(concept, memory_items=memory_items)
+
+ # 从数据库加载所有边
+ edges = self.memory_graph.db.db.graph_data.edges.find()
+ for edge in edges:
+ source = edge['source']
+ target = edge['target']
+ strength = edge.get('strength', 1) # 获取 strength,默认为 1
+ # 只有当源节点和目标节点都存在时才添加边
+ if source in self.memory_graph.G and target in self.memory_graph.G:
+ self.memory_graph.G.add_edge(source, target, strength=strength)
+
+ async def operation_forget_topic(self, percentage=0.1):
+ """随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
+ # 获取所有节点
+ all_nodes = list(self.memory_graph.G.nodes())
+ # 计算要检查的节点数量
+ check_count = max(1, int(len(all_nodes) * percentage))
+ # 随机选择节点
+ nodes_to_check = random.sample(all_nodes, check_count)
+
+ forgotten_nodes = []
+ for node in nodes_to_check:
+ # 获取节点的连接数
+ connections = self.memory_graph.G.degree(node)
+
+ # 获取节点的内容条数
+ memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ content_count = len(memory_items)
+
+ # 检查连接强度
+ weak_connections = True
+ if connections > 1: # 只有当连接数大于1时才检查强度
+ for neighbor in self.memory_graph.G.neighbors(node):
+ strength = self.memory_graph.G[node][neighbor].get('strength', 1)
+ if strength > 2:
+ weak_connections = False
+ break
+
+ # 如果满足遗忘条件
+ if (connections <= 1 and weak_connections) or content_count <= 2:
+ removed_item = self.memory_graph.forget_topic(node)
+ if removed_item:
+ forgotten_nodes.append((node, removed_item))
+ print(f"遗忘节点 {node} 的记忆: {removed_item}")
+
+ # 同步到数据库
+ if forgotten_nodes:
+ self.sync_memory_to_db()
+ print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
+ else:
+ print("本次检查没有节点满足遗忘条件")
+
+ async def merge_memory(self, topic):
+ """
+ 对指定话题的记忆进行合并压缩
+
+ Args:
+ topic: 要合并的话题节点
+ """
+ # 获取节点的记忆项
+ memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+
+ # 如果记忆项不足,直接返回
+ if len(memory_items) < 10:
+ return
+
+ # 随机选择10条记忆
+ selected_memories = random.sample(memory_items, 10)
+
+ # 拼接成文本
+ merged_text = "\n".join(selected_memories)
+ print(f"\n[合并记忆] 话题: {topic}")
+ print(f"选择的记忆:\n{merged_text}")
+
+ # 使用memory_compress生成新的压缩记忆
+ compressed_memories = await self.memory_compress(merged_text, 0.1)
+
+ # 从原记忆列表中移除被选中的记忆
+ for memory in selected_memories:
+ memory_items.remove(memory)
+
+ # 添加新的压缩记忆
+ for _, compressed_memory in compressed_memories:
+ memory_items.append(compressed_memory)
+ print(f"添加压缩记忆: {compressed_memory}")
+
+ # 更新节点的记忆项
+ self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
+ print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
+
+ async def operation_merge_memory(self, percentage=0.1):
+ """
+ 随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
+
+ Args:
+ percentage: 要检查的节点比例,默认为0.1(10%)
+ """
+ # 获取所有节点
+ all_nodes = list(self.memory_graph.G.nodes())
+ # 计算要检查的节点数量
+ check_count = max(1, int(len(all_nodes) * percentage))
+ # 随机选择节点
+ nodes_to_check = random.sample(all_nodes, check_count)
+
+ merged_nodes = []
+ for node in nodes_to_check:
+ # 获取节点的内容条数
+ memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ content_count = len(memory_items)
+
+ # 如果内容数量超过100,进行合并
+ if content_count > 100:
+ print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
+ await self.merge_memory(node)
+ merged_nodes.append(node)
+
+ # 同步到数据库
+ if merged_nodes:
+ self.sync_memory_to_db()
+ print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
+ else:
+ print("\n本次检查没有需要合并的节点")
+
+ def find_topic_llm(self,text, topic_num):
+ prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
+ return prompt
+
+ def topic_what(self,text, topic):
+ prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
+ return prompt
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
-def find_topic(text, topic_num):
- prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
- return prompt
-
-def topic_what(text, topic):
- prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
- return prompt
-
from nonebot import get_driver
driver = get_driver()
@@ -268,10 +505,10 @@ Database.initialize(
)
#创建记忆图
memory_graph = Memory_graph()
-#加载数据库中存储的记忆图
-memory_graph.load_graph_from_db()
#创建海马体
hippocampus = Hippocampus(memory_graph)
+#从数据库加载记忆图
+hippocampus.sync_memory_from_db()
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
\ No newline at end of file
diff --git a/src/plugins/memory_system/memory_make.py b/src/plugins/memory_system/memory_make.py
deleted file mode 100644
index d1757b24..00000000
--- a/src/plugins/memory_system/memory_make.py
+++ /dev/null
@@ -1,463 +0,0 @@
-# -*- coding: utf-8 -*-
-import sys
-import jieba
-import networkx as nx
-import matplotlib.pyplot as plt
-import math
-from collections import Counter
-import datetime
-import random
-import time
-import os
-# from chat.config import global_config
-sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
-from src.common.database import Database # 使用正确的导入语法
-from src.plugins.memory_system.llm_module import LLMModel
-
-def calculate_information_content(text):
- """计算文本的信息量(熵)"""
- # 统计字符频率
- char_count = Counter(text)
- total_chars = len(text)
-
- # 计算熵
- entropy = 0
- for count in char_count.values():
- probability = count / total_chars
- entropy -= probability * math.log2(probability)
-
- return entropy
-
-def get_cloest_chat_from_db(db, length: int, timestamp: str):
- """从数据库中获取最接近指定时间戳的聊天记录"""
- chat_text = ''
- closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
-
- if closest_record:
- closest_time = closest_record['time']
- group_id = closest_record['group_id'] # 获取groupid
- # 获取该时间戳之后的length条消息,且groupid相同
- chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
- for record in chat_record:
- time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
- chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
- return chat_text
-
- return ''
-
-class Memory_graph:
- def __init__(self):
- self.G = nx.Graph() # 使用 networkx 的图结构
- self.db = Database.get_instance()
-
- def connect_dot(self, concept1, concept2):
- self.G.add_edge(concept1, concept2)
-
- def add_dot(self, concept, memory):
- if concept in self.G:
- # 如果节点已存在,将新记忆添加到现有列表中
- if 'memory_items' in self.G.nodes[concept]:
- if not isinstance(self.G.nodes[concept]['memory_items'], list):
- # 如果当前不是列表,将其转换为列表
- self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
- self.G.nodes[concept]['memory_items'].append(memory)
- else:
- self.G.nodes[concept]['memory_items'] = [memory]
- else:
- # 如果是新节点,创建新的记忆列表
- self.G.add_node(concept, memory_items=[memory])
-
- def get_dot(self, concept):
- # 检查节点是否存在于图中
- if concept in self.G:
- # 从图中获取节点数据
- node_data = self.G.nodes[concept]
- # print(node_data)
- # 创建新的Memory_dot对象
- return concept,node_data
- return None
-
- def get_related_item(self, topic, depth=1):
- if topic not in self.G:
- return [], []
-
- first_layer_items = []
- second_layer_items = []
-
- # 获取相邻节点
- neighbors = list(self.G.neighbors(topic))
- # print(f"第一层: {topic}")
-
- # 获取当前节点的记忆项
- node_data = self.get_dot(topic)
- if node_data:
- concept, data = node_data
- if 'memory_items' in data:
- memory_items = data['memory_items']
- if isinstance(memory_items, list):
- first_layer_items.extend(memory_items)
- else:
- first_layer_items.append(memory_items)
-
- # 只在depth=2时获取第二层记忆
- if depth >= 2:
- # 获取相邻节点的记忆项
- for neighbor in neighbors:
- # print(f"第二层: {neighbor}")
- node_data = self.get_dot(neighbor)
- if node_data:
- concept, data = node_data
- if 'memory_items' in data:
- memory_items = data['memory_items']
- if isinstance(memory_items, list):
- second_layer_items.extend(memory_items)
- else:
- second_layer_items.append(memory_items)
-
- return first_layer_items, second_layer_items
-
- def store_memory(self):
- for node in self.G.nodes():
- dot_data = {
- "concept": node
- }
- self.db.db.store_memory_dots.insert_one(dot_data)
-
- @property
- def dots(self):
- # 返回所有节点对应的 Memory_dot 对象
- return [self.get_dot(node) for node in self.G.nodes()]
-
-
- def get_random_chat_from_db(self, length: int, timestamp: str):
- # 从数据库中根据时间戳获取离其最近的聊天记录
- chat_text = ''
- closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
-
- # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
-
- if closest_record:
- closest_time = closest_record['time']
- group_id = closest_record['group_id'] # 获取groupid
- # 获取该时间戳之后的length条消息,且groupid相同
- chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
- for record in chat_record:
- if record:
- time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
- try:
- displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
- except:
- displayname=record["user_nickname"] or "用户" + str(record["user_id"])
- chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
- return chat_text
-
- return [] # 如果没有找到记录,返回空列表
-
- def save_graph_to_db(self):
- # 保存节点
- for node in self.G.nodes(data=True):
- concept = node[0]
- memory_items = node[1].get('memory_items', [])
-
- # 查找是否存在同名节点
- existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
- if existing_node:
- # 如果存在,合并memory_items并去重
- existing_items = existing_node.get('memory_items', [])
- if not isinstance(existing_items, list):
- existing_items = [existing_items] if existing_items else []
-
- # 合并并去重
- all_items = list(set(existing_items + memory_items))
-
- # 更新节点
- self.db.db.graph_data.nodes.update_one(
- {'concept': concept},
- {'$set': {'memory_items': all_items}}
- )
- else:
- # 如果不存在,创建新节点
- node_data = {
- 'concept': concept,
- 'memory_items': memory_items
- }
- self.db.db.graph_data.nodes.insert_one(node_data)
-
- # 保存边
- for edge in self.G.edges():
- source, target = edge
-
- # 查找是否存在同样的边
- existing_edge = self.db.db.graph_data.edges.find_one({
- 'source': source,
- 'target': target
- })
-
- if existing_edge:
- # 如果存在,增加num属性
- num = existing_edge.get('num', 1) + 1
- self.db.db.graph_data.edges.update_one(
- {'source': source, 'target': target},
- {'$set': {'num': num}}
- )
- else:
- # 如果不存在,创建新边
- edge_data = {
- 'source': source,
- 'target': target,
- 'num': 1
- }
- self.db.db.graph_data.edges.insert_one(edge_data)
-
- def load_graph_from_db(self):
- # 清空当前图
- self.G.clear()
- # 加载节点
- nodes = self.db.db.graph_data.nodes.find()
- for node in nodes:
- memory_items = node.get('memory_items', [])
- if not isinstance(memory_items, list):
- memory_items = [memory_items] if memory_items else []
- self.G.add_node(node['concept'], memory_items=memory_items)
- # 加载边
- edges = self.db.db.graph_data.edges.find()
- for edge in edges:
- self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
-
-# 海马体
-class Hippocampus:
- def __init__(self,memory_graph:Memory_graph):
- self.memory_graph = memory_graph
- self.llm_model = LLMModel()
- self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-
- def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
- current_timestamp = datetime.datetime.now().timestamp()
- chat_text = []
- #短期:1h 中期:4h 长期:24h
- for _ in range(time_frequency.get('near')): # 循环10次
- random_time = current_timestamp - random.randint(1, 3600) # 随机时间
- chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
- chat_text.append(chat_)
- for _ in range(time_frequency.get('mid')): # 循环10次
- random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
- chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
- chat_text.append(chat_)
- for _ in range(time_frequency.get('far')): # 循环10次
- random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
- chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
- chat_text.append(chat_)
- return chat_text
-
- def build_memory(self,chat_size=12):
- #最近消息获取频率
- time_frequency = {'near':1,'mid':2,'far':2}
- memory_sample = self.get_memory_sample(chat_size,time_frequency)
-
- #加载进度可视化
- for i, input_text in enumerate(memory_sample, 1):
- progress = (i / len(memory_sample)) * 100
- bar_length = 30
- filled_length = int(bar_length * i // len(memory_sample))
- bar = '█' * filled_length + '-' * (bar_length - filled_length)
- print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
- # print(f"第{i}条消息: {input_text}")
- if input_text:
- # 生成压缩后记忆
- first_memory = set()
- first_memory = self.memory_compress(input_text, 2.5)
- #将记忆加入到图谱中
- for topic, memory in first_memory:
- topics = segment_text(topic)
- print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
- for split_topic in topics:
- self.memory_graph.add_dot(split_topic,memory)
- for split_topic in topics:
- for other_split_topic in topics:
- if split_topic != other_split_topic:
- self.memory_graph.connect_dot(split_topic, other_split_topic)
- else:
- print(f"空消息 跳过")
-
- self.memory_graph.save_graph_to_db()
-
- def memory_compress(self, input_text, rate=1):
- information_content = calculate_information_content(input_text)
- print(f"文本的信息量(熵): {information_content:.4f} bits")
- topic_num = max(1, min(5, int(information_content * rate / 4)))
- topic_prompt = find_topic(input_text, topic_num)
- topic_response = self.llm_model.generate_response(topic_prompt)
- # 检查 topic_response 是否为元组
- if isinstance(topic_response, tuple):
- topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
- else:
- topics = topic_response.split(",")
- compressed_memory = set()
- for topic in topics:
- topic_what_prompt = topic_what(input_text,topic)
- topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
- compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
- return compressed_memory
-
-def segment_text(text):
- seg_text = list(jieba.cut(text))
- return seg_text
-
-def find_topic(text, topic_num):
- prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
- return prompt
-
-def topic_what(text, topic):
- prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
- return prompt
-
-def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
- # 设置中文字体
- plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
- plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
-
- G = memory_graph.G
-
- # 创建一个新图用于可视化
- H = G.copy()
-
- # 移除只有一条记忆的节点和连接数少于3的节点
- nodes_to_remove = []
- for node in H.nodes():
- memory_items = H.nodes[node].get('memory_items', [])
- memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
- degree = H.degree(node)
- if memory_count <= 1 or degree <= 2:
- nodes_to_remove.append(node)
-
- H.remove_nodes_from(nodes_to_remove)
-
- # 如果过滤后没有节点,则返回
- if len(H.nodes()) == 0:
- print("过滤后没有符合条件的节点可显示")
- return
-
- # 保存图到本地
- nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
-
- # 根据连接条数或记忆数量设置节点颜色
- node_colors = []
- nodes = list(H.nodes()) # 获取图中实际的节点列表
-
- if color_by_memory:
- # 计算每个节点的记忆数量
- memory_counts = []
- for node in nodes:
- memory_items = H.nodes[node].get('memory_items', [])
- if isinstance(memory_items, list):
- count = len(memory_items)
- else:
- count = 1 if memory_items else 0
- memory_counts.append(count)
- max_memories = max(memory_counts) if memory_counts else 1
-
- for count in memory_counts:
- # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
- if max_memories > 0:
- intensity = min(1.0, count / max_memories)
- color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
- else:
- color = (0, 0, 1) # 如果没有记忆,则为蓝色
- node_colors.append(color)
- else:
- # 使用原来的连接数量着色方案
- max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
- for node in nodes:
- degree = H.degree(node)
- if max_degree > 0:
- red = min(1.0, degree / max_degree)
- blue = 1.0 - red
- color = (red, 0, blue)
- else:
- color = (0, 0, 1)
- node_colors.append(color)
-
- # 绘制图形
- plt.figure(figsize=(12, 8))
- pos = nx.spring_layout(H, k=1, iterations=50)
- nx.draw(H, pos,
- with_labels=True,
- node_color=node_colors,
- node_size=2000,
- font_size=10,
- font_family='SimHei',
- font_weight='bold')
-
- title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
- plt.title(title, fontsize=16, fontfamily='SimHei')
- plt.show()
-
-def main():
- # 初始化数据库
- Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
- )
-
- start_time = time.time()
-
- # 创建记忆图
- memory_graph = Memory_graph()
- # 加载数据库中存储的记忆图
- memory_graph.load_graph_from_db()
- # 创建海马体
- hippocampus = Hippocampus(memory_graph)
-
- end_time = time.time()
- print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
-
- # 构建记忆
- hippocampus.build_memory(chat_size=25)
-
- # 展示两种不同的可视化方式
- print("\n按连接数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=False)
-
- print("\n按记忆数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=True)
-
- # 交互式查询
- while True:
- query = input("请输入新的查询概念(输入'退出'以结束):")
- if query.lower() == '退出':
- break
- items_list = memory_graph.get_related_item(query)
- if items_list:
- for memory_item in items_list:
- print(memory_item)
- else:
- print("未找到相关记忆。")
-
- while True:
- query = input("请输入问题:")
-
- if query.lower() == '退出':
- break
-
- topic_prompt = find_topic(query, 3)
- topic_response = hippocampus.llm_model.generate_response(topic_prompt)
- # 检查 topic_response 是否为元组
- if isinstance(topic_response, tuple):
- topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
- else:
- topics = topic_response.split(",")
- print(topics)
-
- for keyword in topics:
- items_list = memory_graph.get_related_item(keyword)
- if items_list:
- print(items_list)
-
-if __name__ == "__main__":
- main()
-
-
diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py
new file mode 100644
index 00000000..d6aa2f66
--- /dev/null
+++ b/src/plugins/memory_system/memory_manual_build.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+import sys
+import jieba
+import networkx as nx
+import matplotlib.pyplot as plt
+import math
+from collections import Counter
+import datetime
+import random
+import time
+import os
+from dotenv import load_dotenv
+import pymongo
+from loguru import logger
+from pathlib import Path
+from snownlp import SnowNLP
+# from chat.config import global_config
+sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
+from src.common.database import Database
+from src.plugins.memory_system.offline_llm import LLMModel
+
+# 获取当前文件的目录
+current_dir = Path(__file__).resolve().parent
+# 获取项目根目录(上三层目录)
+project_root = current_dir.parent.parent.parent
+# env.dev文件路径
+env_path = project_root / ".env.dev"
+
+# 加载环境变量
+if env_path.exists():
+ logger.info(f"从 {env_path} 加载环境变量")
+ load_dotenv(env_path)
+else:
+ logger.warning(f"未找到环境变量文件: {env_path}")
+ logger.info("将使用默认配置")
+
+class Database:
+ _instance = None
+ db = None
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+ def __init__(self):
+ if not Database.db:
+ Database.initialize(
+ host=os.getenv("MONGODB_HOST"),
+ port=int(os.getenv("MONGODB_PORT")),
+ db_name=os.getenv("DATABASE_NAME"),
+ username=os.getenv("MONGODB_USERNAME"),
+ password=os.getenv("MONGODB_PASSWORD"),
+ auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ )
+
+ @classmethod
+ def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
+ try:
+ if username and password:
+ uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
+ else:
+ uri = f"mongodb://{host}:{port}"
+
+ client = pymongo.MongoClient(uri)
+ cls.db = client[db_name]
+ # 测试连接
+ client.server_info()
+ logger.success("MongoDB连接成功!")
+
+ except Exception as e:
+ logger.error(f"初始化MongoDB失败: {str(e)}")
+ raise
+
+def calculate_information_content(text):
+ """计算文本的信息量(熵)"""
+ char_count = Counter(text)
+ total_chars = len(text)
+
+ entropy = 0
+ for count in char_count.values():
+ probability = count / total_chars
+ entropy -= probability * math.log2(probability)
+
+ return entropy
+
+def get_cloest_chat_from_db(db, length: int, timestamp: str):
+ """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
+ chat_text = ''
+ closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
+
+ if closest_record and closest_record.get('memorized', 0) < 4:
+ closest_time = closest_record['time']
+ group_id = closest_record['group_id'] # 获取groupid
+ # 获取该时间戳之后的length条消息,且groupid相同
+ chat_records = list(db.db.messages.find(
+ {"time": {"$gt": closest_time}, "group_id": group_id}
+ ).sort('time', 1).limit(length))
+
+ # 更新每条消息的memorized属性
+ for record in chat_records:
+ # 检查当前记录的memorized值
+ current_memorized = record.get('memorized', 0)
+ if current_memorized > 3:
+ print(f"消息已读取3次,跳过")
+ return ''
+
+ # 更新memorized值
+ db.db.messages.update_one(
+ {"_id": record["_id"]},
+ {"$set": {"memorized": current_memorized + 1}}
+ )
+
+ chat_text += record["detailed_plain_text"]
+
+ return chat_text
+ print(f"消息已读取3次,跳过")
+ return ''
+
+class Memory_graph:
+ def __init__(self):
+ self.G = nx.Graph() # 使用 networkx 的图结构
+ self.db = Database.get_instance()
+
+ def connect_dot(self, concept1, concept2):
+ # 如果边已存在,增加 strength
+ if self.G.has_edge(concept1, concept2):
+ self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
+ else:
+ # 如果是新边,初始化 strength 为 1
+ self.G.add_edge(concept1, concept2, strength=1)
+
+ def add_dot(self, concept, memory):
+ if concept in self.G:
+ # 如果节点已存在,将新记忆添加到现有列表中
+ if 'memory_items' in self.G.nodes[concept]:
+ if not isinstance(self.G.nodes[concept]['memory_items'], list):
+ # 如果当前不是列表,将其转换为列表
+ self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
+ self.G.nodes[concept]['memory_items'].append(memory)
+ else:
+ self.G.nodes[concept]['memory_items'] = [memory]
+ else:
+ # 如果是新节点,创建新的记忆列表
+ self.G.add_node(concept, memory_items=[memory])
+
+ def get_dot(self, concept):
+ # 检查节点是否存在于图中
+ if concept in self.G:
+ # 从图中获取节点数据
+ node_data = self.G.nodes[concept]
+ return concept, node_data
+ return None
+
+ def get_related_item(self, topic, depth=1):
+ if topic not in self.G:
+ return [], []
+
+ first_layer_items = []
+ second_layer_items = []
+
+ # 获取相邻节点
+ neighbors = list(self.G.neighbors(topic))
+
+ # 获取当前节点的记忆项
+ node_data = self.get_dot(topic)
+ if node_data:
+ concept, data = node_data
+ if 'memory_items' in data:
+ memory_items = data['memory_items']
+ if isinstance(memory_items, list):
+ first_layer_items.extend(memory_items)
+ else:
+ first_layer_items.append(memory_items)
+
+ # 只在depth=2时获取第二层记忆
+ if depth >= 2:
+ # 获取相邻节点的记忆项
+ for neighbor in neighbors:
+ node_data = self.get_dot(neighbor)
+ if node_data:
+ concept, data = node_data
+ if 'memory_items' in data:
+ memory_items = data['memory_items']
+ if isinstance(memory_items, list):
+ second_layer_items.extend(memory_items)
+ else:
+ second_layer_items.append(memory_items)
+
+ return first_layer_items, second_layer_items
+
+ @property
+ def dots(self):
+ # 返回所有节点对应的 Memory_dot 对象
+ return [self.get_dot(node) for node in self.G.nodes()]
+
+# 海马体
+class Hippocampus:
+ def __init__(self, memory_graph: Memory_graph):
+ self.memory_graph = memory_graph
+ self.llm_model = LLMModel()
+ self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
+ self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
+ self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
+
+ def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}):
+ current_timestamp = datetime.datetime.now().timestamp()
+ chat_text = []
+ #短期:1h 中期:4h 长期:24h
+ for _ in range(time_frequency.get('near')): # 循环10次
+ random_time = current_timestamp - random.randint(1, 3600*4) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ for _ in range(time_frequency.get('mid')): # 循环10次
+ random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ for _ in range(time_frequency.get('far')): # 循环10次
+ random_time = current_timestamp - random.randint(3600*24, 3600*24*7) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ return [chat for chat in chat_text if chat]
+
+ def calculate_topic_num(self,text, compress_rate):
+ """计算文本的话题数量"""
+ information_content = calculate_information_content(text)
+ topic_by_length = text.count('\n')*compress_rate
+ topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
+ topic_num = int((topic_by_length + topic_by_information_content)/2)
+ print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
+ return topic_num
+
+ async def memory_compress(self, input_text, compress_rate=0.1):
+ print(input_text)
+
+ #获取topics
+ topic_num = self.calculate_topic_num(input_text, compress_rate)
+ topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num))
+ # 修改话题处理逻辑
+ topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
+ print(f"话题: {topics}")
+
+ # 创建所有话题的请求任务
+ tasks = []
+ for topic in topics:
+ topic_what_prompt = self.topic_what(input_text, topic)
+ # 创建异步任务
+ task = self.llm_model_small.generate_response_async(topic_what_prompt)
+ tasks.append((topic.strip(), task))
+
+ # 等待所有任务完成
+ compressed_memory = set()
+ for topic, task in tasks:
+ response = await task
+ if response:
+ compressed_memory.add((topic, response[0]))
+
+ return compressed_memory
+
+ async def operation_build_memory(self, chat_size=12):
+ # 最近消息获取频率
+ time_frequency = {'near': 3, 'mid': 8, 'far': 5}
+ memory_sample = self.get_memory_sample(chat_size, time_frequency)
+
+ all_topics = [] # 用于存储所有话题
+
+ for i, input_text in enumerate(memory_sample, 1):
+ # 加载进度可视化
+ all_topics = []
+ progress = (i / len(memory_sample)) * 100
+ bar_length = 30
+ filled_length = int(bar_length * i // len(memory_sample))
+ bar = '█' * filled_length + '-' * (bar_length - filled_length)
+ print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
+
+ # 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
+ compressed_memory = set()
+ compress_rate = 0.1
+ compressed_memory = await self.memory_compress(input_text, compress_rate)
+ print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
+
+ # 将记忆加入到图谱中
+ for topic, memory in compressed_memory:
+ print(f"\033[1;32m添加节点\033[0m: {topic}")
+ self.memory_graph.add_dot(topic, memory)
+ all_topics.append(topic) # 收集所有话题
+ for i in range(len(all_topics)):
+ for j in range(i + 1, len(all_topics)):
+ print(f"\033[1;32m连接节点\033[0m: {all_topics[i]} 和 {all_topics[j]}")
+ self.memory_graph.connect_dot(all_topics[i], all_topics[j])
+
+
+
+
+ self.sync_memory_to_db()
+
+ def sync_memory_from_db(self):
+ """
+ 从数据库同步数据到内存中的图结构
+ 将清空当前内存中的图,并从数据库重新加载所有节点和边
+ """
+ # 清空当前图
+ self.memory_graph.G.clear()
+
+ # 从数据库加载所有节点
+ nodes = self.memory_graph.db.db.graph_data.nodes.find()
+ for node in nodes:
+ concept = node['concept']
+ memory_items = node.get('memory_items', [])
+ # 确保memory_items是列表
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ # 添加节点到图中
+ self.memory_graph.G.add_node(concept, memory_items=memory_items)
+
+ # 从数据库加载所有边
+ edges = self.memory_graph.db.db.graph_data.edges.find()
+ for edge in edges:
+ source = edge['source']
+ target = edge['target']
+ strength = edge.get('strength', 1) # 获取 strength,默认为 1
+ # 只有当源节点和目标节点都存在时才添加边
+ if source in self.memory_graph.G and target in self.memory_graph.G:
+ self.memory_graph.G.add_edge(source, target, strength=strength)
+
+ logger.success("从数据库同步记忆图谱完成")
+
+ def calculate_node_hash(self, concept, memory_items):
+ """
+ 计算节点的特征值
+ """
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ # 将记忆项排序以确保相同内容生成相同的哈希值
+ sorted_items = sorted(memory_items)
+ # 组合概念和记忆项生成特征值
+ content = f"{concept}:{'|'.join(sorted_items)}"
+ return hash(content)
+
+ def calculate_edge_hash(self, source, target):
+ """
+ 计算边的特征值
+ """
+ # 对源节点和目标节点排序以确保相同的边生成相同的哈希值
+ nodes = sorted([source, target])
+ return hash(f"{nodes[0]}:{nodes[1]}")
+
+ def sync_memory_to_db(self):
+ """
+ 检查并同步内存中的图结构与数据库
+ 使用特征值(哈希值)快速判断是否需要更新
+ """
+ # 获取数据库中所有节点和内存中所有节点
+ db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
+ memory_nodes = list(self.memory_graph.G.nodes(data=True))
+
+ # 转换数据库节点为字典格式,方便查找
+ db_nodes_dict = {node['concept']: node for node in db_nodes}
+
+ # 检查并更新节点
+ for concept, data in memory_nodes:
+ memory_items = data.get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+
+ # 计算内存中节点的特征值
+ memory_hash = self.calculate_node_hash(concept, memory_items)
+
+ if concept not in db_nodes_dict:
+ # 数据库中缺少的节点,添加
+ logger.info(f"添加新节点: {concept}")
+ node_data = {
+ 'concept': concept,
+ 'memory_items': memory_items,
+ 'hash': memory_hash
+ }
+ self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
+ else:
+ # 获取数据库中节点的特征值
+ db_node = db_nodes_dict[concept]
+ db_hash = db_node.get('hash', None)
+
+ # 如果特征值不同,则更新节点
+ if db_hash != memory_hash:
+ logger.info(f"更新节点内容: {concept}")
+ self.memory_graph.db.db.graph_data.nodes.update_one(
+ {'concept': concept},
+ {'$set': {
+ 'memory_items': memory_items,
+ 'hash': memory_hash
+ }}
+ )
+
+ # 检查并删除数据库中多余的节点
+ memory_concepts = set(node[0] for node in memory_nodes)
+ for db_node in db_nodes:
+ if db_node['concept'] not in memory_concepts:
+ logger.info(f"删除多余节点: {db_node['concept']}")
+ self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
+
+ # 处理边的信息
+ db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
+ memory_edges = list(self.memory_graph.G.edges())
+
+ # 创建边的哈希值字典
+ db_edge_dict = {}
+ for edge in db_edges:
+ edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
+ db_edge_dict[(edge['source'], edge['target'])] = {
+ 'hash': edge_hash,
+ 'num': edge.get('num', 1)
+ }
+
+ # 检查并更新边
+ for source, target in memory_edges:
+ edge_hash = self.calculate_edge_hash(source, target)
+ edge_key = (source, target)
+
+ if edge_key not in db_edge_dict:
+ # 添加新边
+ logger.info(f"添加新边: {source} - {target}")
+ edge_data = {
+ 'source': source,
+ 'target': target,
+ 'num': 1,
+ 'hash': edge_hash
+ }
+ self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
+ else:
+ # 检查边的特征值是否变化
+ if db_edge_dict[edge_key]['hash'] != edge_hash:
+ logger.info(f"更新边: {source} - {target}")
+ self.memory_graph.db.db.graph_data.edges.update_one(
+ {'source': source, 'target': target},
+ {'$set': {'hash': edge_hash}}
+ )
+
+ # 删除多余的边
+ memory_edge_set = set(memory_edges)
+ for edge_key in db_edge_dict:
+ if edge_key not in memory_edge_set:
+ source, target = edge_key
+ logger.info(f"删除多余边: {source} - {target}")
+ self.memory_graph.db.db.graph_data.edges.delete_one({
+ 'source': source,
+ 'target': target
+ })
+
+ logger.success("完成记忆图谱与数据库的差异同步")
+
+ def find_topic_llm(self,text, topic_num):
+ # prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
+ prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
+ return prompt
+
+ def topic_what(self,text, topic):
+ # prompt = f'这是一段文字:{text}。我想知道这段文字里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
+ prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
+ return prompt
+
+ def remove_node_from_db(self, topic):
+ """
+ 从数据库中删除指定节点及其相关的边
+
+ Args:
+ topic: 要删除的节点概念
+ """
+ # 删除节点
+ self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
+ # 删除所有涉及该节点的边
+ self.memory_graph.db.db.graph_data.edges.delete_many({
+ '$or': [
+ {'source': topic},
+ {'target': topic}
+ ]
+ })
+
+ def forget_topic(self, topic):
+ """
+ 随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点
+ 只在内存中的图上操作,不直接与数据库交互
+
+ Args:
+ topic: 要删除记忆的话题
+
+ Returns:
+ removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None
+ """
+ if topic not in self.memory_graph.G:
+ return None
+
+ # 获取话题节点数据
+ node_data = self.memory_graph.G.nodes[topic]
+
+ # 如果节点存在memory_items
+ if 'memory_items' in node_data:
+ memory_items = node_data['memory_items']
+
+ # 确保memory_items是列表
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+
+ # 如果有记忆项可以删除
+ if memory_items:
+ # 随机选择一个记忆项删除
+ removed_item = random.choice(memory_items)
+ memory_items.remove(removed_item)
+
+ # 更新节点的记忆项
+ if memory_items:
+ self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
+ else:
+ # 如果没有记忆项了,删除整个节点
+ self.memory_graph.G.remove_node(topic)
+
+ return removed_item
+
+ return None
+
+ async def operation_forget_topic(self, percentage=0.1):
+ """
+ 随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘
+
+ Args:
+ percentage: 要检查的节点比例,默认为0.1(10%)
+ """
+ # 获取所有节点
+ all_nodes = list(self.memory_graph.G.nodes())
+ # 计算要检查的节点数量
+ check_count = max(1, int(len(all_nodes) * percentage))
+ # 随机选择节点
+ nodes_to_check = random.sample(all_nodes, check_count)
+
+ forgotten_nodes = []
+ for node in nodes_to_check:
+ # 获取节点的连接数
+ connections = self.memory_graph.G.degree(node)
+
+ # 获取节点的内容条数
+ memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ content_count = len(memory_items)
+
+ # 检查连接强度
+ weak_connections = True
+ if connections > 1: # 只有当连接数大于1时才检查强度
+ for neighbor in self.memory_graph.G.neighbors(node):
+ strength = self.memory_graph.G[node][neighbor].get('strength', 1)
+ if strength > 2:
+ weak_connections = False
+ break
+
+ # 如果满足遗忘条件
+ if (connections <= 1 and weak_connections) or content_count <= 2:
+ removed_item = self.forget_topic(node)
+ if removed_item:
+ forgotten_nodes.append((node, removed_item))
+ logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
+
+ # 同步到数据库
+ if forgotten_nodes:
+ self.sync_memory_to_db()
+ logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
+ else:
+ logger.info("本次检查没有节点满足遗忘条件")
+
+ async def merge_memory(self, topic):
+ """
+ 对指定话题的记忆进行合并压缩
+
+ Args:
+ topic: 要合并的话题节点
+ """
+ # 获取节点的记忆项
+ memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+
+ # 如果记忆项不足,直接返回
+ if len(memory_items) < 10:
+ return
+
+ # 随机选择10条记忆
+ selected_memories = random.sample(memory_items, 10)
+
+ # 拼接成文本
+ merged_text = "\n".join(selected_memories)
+ print(f"\n[合并记忆] 话题: {topic}")
+ print(f"选择的记忆:\n{merged_text}")
+
+ # 使用memory_compress生成新的压缩记忆
+ compressed_memories = await self.memory_compress(merged_text, 0.1)
+
+ # 从原记忆列表中移除被选中的记忆
+ for memory in selected_memories:
+ memory_items.remove(memory)
+
+ # 添加新的压缩记忆
+ for _, compressed_memory in compressed_memories:
+ memory_items.append(compressed_memory)
+ print(f"添加压缩记忆: {compressed_memory}")
+
+ # 更新节点的记忆项
+ self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
+ print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
+
+ async def operation_merge_memory(self, percentage=0.1):
+ """
+ 随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
+
+ Args:
+ percentage: 要检查的节点比例,默认为0.1(10%)
+ """
+ # 获取所有节点
+ all_nodes = list(self.memory_graph.G.nodes())
+ # 计算要检查的节点数量
+ check_count = max(1, int(len(all_nodes) * percentage))
+ # 随机选择节点
+ nodes_to_check = random.sample(all_nodes, check_count)
+
+ merged_nodes = []
+ for node in nodes_to_check:
+ # 获取节点的内容条数
+ memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
+ if not isinstance(memory_items, list):
+ memory_items = [memory_items] if memory_items else []
+ content_count = len(memory_items)
+
+ # 如果内容数量超过100,进行合并
+ if content_count > 100:
+ print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
+ await self.merge_memory(node)
+ merged_nodes.append(node)
+
+ # 同步到数据库
+ if merged_nodes:
+ self.sync_memory_to_db()
+ print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
+ else:
+ print("\n本次检查没有需要合并的节点")
+
+
+def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
+ # 设置中文字体
+ plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
+ plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
+
+ G = memory_graph.G
+
+ # 创建一个新图用于可视化
+ H = G.copy()
+
+ # 计算节点大小和颜色
+ node_colors = []
+ node_sizes = []
+ nodes = list(H.nodes())
+
+ # 获取最大记忆数用于归一化节点大小
+ max_memories = 1
+ for node in nodes:
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ max_memories = max(max_memories, memory_count)
+
+ # 计算每个节点的大小和颜色
+ for node in nodes:
+ # 计算节点大小(基于记忆数量)
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ # 使用指数函数使变化更明显
+ ratio = memory_count / max_memories
+ size = 400 + 2000 * (ratio ** 2) # 增大节点大小
+ node_sizes.append(size)
+
+ # 计算节点颜色(基于连接数)
+ degree = H.degree(node)
+ if degree >= 30:
+ node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000)
+ else:
+ # 将1-10映射到0-1的范围
+ color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
+ # 使用蓝到红的渐变
+ red = min(0.9, color_ratio)
+ blue = max(0.0, 1.0 - color_ratio)
+ node_colors.append((red, 0, blue))
+
+ # 绘制图形
+ plt.figure(figsize=(16, 12)) # 减小图形尺寸
+ pos = nx.spring_layout(H,
+ k=1, # 调整节点间斥力
+ iterations=100, # 增加迭代次数
+ scale=1.5, # 减小布局尺寸
+ weight='strength') # 使用边的strength属性作为权重
+
+ nx.draw(H, pos,
+ with_labels=True,
+ node_color=node_colors,
+ node_size=node_sizes,
+ font_size=12, # 保持增大的字体大小
+ font_family='SimHei',
+ font_weight='bold',
+ edge_color='gray',
+ width=1.5) # 统一的边宽度
+
+ title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近'
+ plt.title(title, fontsize=16, fontfamily='SimHei')
+ plt.show()
+
+async def main():
+ # 初始化数据库
+ logger.info("正在初始化数据库连接...")
+ db = Database.get_instance()
+ start_time = time.time()
+
+ test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
+
+ # 创建记忆图
+ memory_graph = Memory_graph()
+
+ # 创建海马体
+ hippocampus = Hippocampus(memory_graph)
+
+ # 从数据库同步数据
+ hippocampus.sync_memory_from_db()
+
+ end_time = time.time()
+ logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+ # 构建记忆
+ if test_pare['do_build_memory']:
+ logger.info("开始构建记忆...")
+ chat_size = 20
+ await hippocampus.operation_build_memory(chat_size=chat_size)
+
+ end_time = time.time()
+ logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
+
+ if test_pare['do_forget_topic']:
+ logger.info("开始遗忘记忆...")
+ await hippocampus.operation_forget_topic(percentage=0.1)
+
+ end_time = time.time()
+ logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+ if test_pare['do_merge_memory']:
+ logger.info("开始合并记忆...")
+ await hippocampus.operation_merge_memory(percentage=0.1)
+
+ end_time = time.time()
+ logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+ if test_pare['do_visualize_graph']:
+ # 展示优化后的图形
+ logger.info("生成记忆图谱可视化...")
+ print("\n生成优化后的记忆图谱:")
+ visualize_graph_lite(memory_graph)
+
+ if test_pare['do_query']:
+ # 交互式查询
+ while True:
+ query = input("\n请输入新的查询概念(输入'退出'以结束):")
+ if query.lower() == '退出':
+ break
+
+ items_list = memory_graph.get_related_item(query)
+ if items_list:
+ first_layer, second_layer = items_list
+ if first_layer:
+ print("\n直接相关的记忆:")
+ for item in first_layer:
+ print(f"- {item}")
+ if second_layer:
+ print("\n间接相关的记忆:")
+ for item in second_layer:
+ print(f"- {item}")
+ else:
+ print("未找到相关记忆。")
+
+
+if __name__ == "__main__":
+ import asyncio
+ asyncio.run(main())
+
+
diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/offline_llm.py
similarity index 50%
rename from src/plugins/memory_system/llm_module_memory_make.py
rename to src/plugins/memory_system/offline_llm.py
index 41a5d7c0..5e877dce 100644
--- a/src/plugins/memory_system/llm_module_memory_make.py
+++ b/src/plugins/memory_system/offline_llm.py
@@ -2,28 +2,23 @@ import os
import requests
from typing import Tuple, Union
import time
-from nonebot import get_driver
import aiohttp
import asyncio
from loguru import logger
-from src.plugins.chat.config import BotConfig, global_config
-
-driver = get_driver()
-config = driver.config
class LLMModel:
- def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
+ def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
- self.api_key = config.siliconflow_key
- self.base_url = config.siliconflow_base_url
+ self.api_key = os.getenv("SILICONFLOW_KEY")
+ self.base_url = os.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
- async def generate_response(self, prompt: str) -> Tuple[str, str]:
+ def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
@@ -47,7 +42,60 @@ class LLMModel:
for retry in range(max_retries):
try:
- async with aiohttp.ClientSession() as session:
+ response = requests.post(api_url, headers=headers, json=data)
+
+ if response.status_code == 429:
+ wait_time = base_wait_time * (2 ** retry) # 指数退避
+ logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
+ time.sleep(wait_time)
+ continue
+
+ response.raise_for_status() # 检查其他响应状态
+
+ result = response.json()
+ if "choices" in result and len(result["choices"]) > 0:
+ content = result["choices"][0]["message"]["content"]
+ reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
+ return content, reasoning_content
+ return "没有返回结果", ""
+
+ except Exception as e:
+ if retry < max_retries - 1: # 如果还有重试机会
+ wait_time = base_wait_time * (2 ** retry)
+ logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
+ time.sleep(wait_time)
+ else:
+ logger.error(f"请求失败: {str(e)}")
+ return f"请求失败: {str(e)}", ""
+
+ logger.error("达到最大重试次数,请求仍然失败")
+ return "达到最大重试次数,请求仍然失败", ""
+
+ async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
+ """异步方式根据输入的提示生成模型的响应"""
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json"
+ }
+
+ # 构建请求体
+ data = {
+ "model": self.model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ "temperature": 0.5,
+ **self.params
+ }
+
+ # 发送请求到完整的 chat/completions 端点
+ api_url = f"{self.base_url.rstrip('/')}/chat/completions"
+ logger.info(f"Request URL: {api_url}") # 记录请求的 URL
+
+ max_retries = 3
+ base_wait_time = 15
+
+ async with aiohttp.ClientSession() as session:
+ for retry in range(max_retries):
+ try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
@@ -63,15 +111,15 @@ class LLMModel:
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
-
- except Exception as e:
- if retry < max_retries - 1: # 如果还有重试机会
- wait_time = base_wait_time * (2 ** retry)
- logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
- await asyncio.sleep(wait_time)
- else:
- logger.error(f"请求失败: {str(e)}")
- return f"请求失败: {str(e)}", ""
-
- logger.error("达到最大重试次数,请求仍然失败")
- return "达到最大重试次数,请求仍然失败", ""
+
+ except Exception as e:
+ if retry < max_retries - 1: # 如果还有重试机会
+ wait_time = base_wait_time * (2 ** retry)
+ logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
+ await asyncio.sleep(wait_time)
+ else:
+ logger.error(f"请求失败: {str(e)}")
+ return f"请求失败: {str(e)}", ""
+
+ logger.error("达到最大重试次数,请求仍然失败")
+ return "达到最大重试次数,请求仍然失败", ""
diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 4741d259..11d7e2b7 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -55,13 +55,24 @@ class LLM_request:
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
+
+ if response.status in [500, 503]:
+ logger.error(f"服务器错误: {response.status}")
+ raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
- content = result["choices"][0]["message"]["content"]
- reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
+ message = result["choices"][0]["message"]
+ content = message.get("content", "")
+ think_match = None
+ reasoning_content = message.get("reasoning_content", "")
+ if not reasoning_content:
+ think_match = re.search(r'(.*?)', content, re.DOTALL)
+ if think_match:
+ reasoning_content = think_match.group(1).strip()
+ content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip()
return content, reasoning_content
return "没有返回结果", ""
@@ -117,6 +128,7 @@ class LLM_request:
base_wait_time = 15
current_image_base64 = image_base64
+ current_image_base64 = compress_base64_image_by_scale(current_image_base64)
for retry in range(max_retries):
@@ -163,6 +175,61 @@ class LLM_request:
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数,API请求仍然失败")
+ async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
+ """异步方式根据输入的提示生成模型的响应"""
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json"
+ }
+
+ # 构建请求体
+ data = {
+ "model": self.model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ "temperature": 0.5,
+ **self.params
+ }
+
+ # 发送请求到完整的 chat/completions 端点
+ api_url = f"{self.base_url.rstrip('/')}/chat/completions"
+ logger.info(f"Request URL: {api_url}") # 记录请求的 URL
+
+ max_retries = 3
+ base_wait_time = 15
+
+ async with aiohttp.ClientSession() as session:
+ for retry in range(max_retries):
+ try:
+ async with session.post(api_url, headers=headers, json=data) as response:
+ if response.status == 429:
+ wait_time = base_wait_time * (2 ** retry) # 指数退避
+ logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
+ await asyncio.sleep(wait_time)
+ continue
+
+ response.raise_for_status() # 检查其他响应状态
+
+ result = await response.json()
+ if "choices" in result and len(result["choices"]) > 0:
+ content = result["choices"][0]["message"]["content"]
+ reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
+ return content, reasoning_content
+ return "没有返回结果", ""
+
+ except Exception as e:
+ if retry < max_retries - 1: # 如果还有重试机会
+ wait_time = base_wait_time * (2 ** retry)
+ logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
+ await asyncio.sleep(wait_time)
+ else:
+ logger.error(f"请求失败: {str(e)}")
+ return f"请求失败: {str(e)}", ""
+
+ logger.error("达到最大重试次数,请求仍然失败")
+ return "达到最大重试次数,请求仍然失败", ""
+
+
+
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""同步方法:根据输入的提示和图片生成模型的响应"""
headers = {
@@ -170,6 +237,8 @@ class LLM_request:
"Content-Type": "application/json"
}
+ image_base64=compress_base64_image_by_scale(image_base64)
+
# 构建请求体
data = {
"model": self.model_name,
diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py
index 93bb0413..f2b11c33 100644
--- a/src/plugins/schedule/schedule_generator.py
+++ b/src/plugins/schedule/schedule_generator.py
@@ -1,12 +1,12 @@
import datetime
import os
-from typing import List, Dict
+from typing import List, Dict, Union
from ...common.database import Database # 使用正确的导入语法
from src.plugins.chat.config import global_config
from nonebot import get_driver
from ..models.utils_model import LLM_request
from loguru import logger
-
+import json
driver = get_driver()
config = driver.config
@@ -58,19 +58,20 @@ class ScheduleGenerator:
elif read_only == False:
print(f"{date_str}的日程不存在,准备生成新的日程。")
- prompt = f"""我是{global_config.BOT_NICKNAME},{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}({weekday})的日程安排,包括:
+ prompt = f"""我是{global_config.BOT_NICKNAME},{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}({weekday})的日程安排,包括:"""+\
+ """
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
- 请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
+ 请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用JSON格式返回日程表,仅返回内容,不要返回注释,时间采用24小时制,格式为{"时间": "活动","时间": "活动",...}。"""
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
+ self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
# print(self.schedule_text)
- self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
else:
print(f"{date_str}的日程不存在。")
schedule_text = "忘了"
@@ -80,20 +81,15 @@ class ScheduleGenerator:
schedule_form = self._parse_schedule(schedule_text)
return schedule_text,schedule_form
- def _parse_schedule(self, schedule_text: str) -> Dict[str, str]:
+ def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
"""解析日程文本,转换为时间和活动的字典"""
- schedule_dict = {}
- # 按行分割日程文本
- lines = schedule_text.strip().split('\n')
- for line in lines:
- # print(line)
- if ',' in line:
- # 假设格式为 "时间: 活动"
- time_str, activity = line.split(',', 1)
- # print(time_str)
- # print(activity)
- schedule_dict[time_str.strip()] = activity.strip()
- return schedule_dict
+ try:
+ schedule_dict = json.loads(schedule_text)
+ return schedule_dict
+ except json.JSONDecodeError as e:
+ print(schedule_text)
+ print(f"解析日程失败: {str(e)}")
+ return False
def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间"""
@@ -108,6 +104,8 @@ class ScheduleGenerator:
min_diff = float('inf')
# 检查今天的日程
+ if not self.today_schedule:
+ return "摸鱼"
for time_str in self.today_schedule.keys():
diff = abs(self._time_diff(current_time, time_str))
if closest_time is None or diff < min_diff:
@@ -148,11 +146,14 @@ class ScheduleGenerator:
def print_schedule(self):
"""打印完整的日程安排"""
-
- print("\n=== 今日日程安排 ===")
- for time_str, activity in self.today_schedule.items():
- print(f"时间[{time_str}]: 活动[{activity}]")
- print("==================\n")
+ if not self._parse_schedule(self.today_schedule_text):
+ print("今日日程有误,将在下次运行时重新生成")
+ self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
+ else:
+ print("\n=== 今日日程安排 ===")
+ for time_str, activity in self.today_schedule.items():
+ print(f"时间[{time_str}]: 活动[{activity}]")
+ print("==================\n")
# def main():
# # 使用示例