diff --git a/.env.prod b/.env.prod new file mode 100644 index 00000000..3d795978 --- /dev/null +++ b/.env.prod @@ -0,0 +1,24 @@ +HOST=127.0.0.1 +PORT=8080 + +COMMAND_START=["/"] + +# 插件配置 +PLUGINS=["src2.plugins.chat"] + +# 默认配置 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot + +MONGODB_USERNAME = "" # 默认空值 +MONGODB_PASSWORD = "" # 默认空值 +MONGODB_AUTH_SOURCE = "" # 默认空值 + +#key and url +CHAT_ANY_WHERE_KEY= +SILICONFLOW_KEY= +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ +DEEP_SEEK_KEY= +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 26510818..c19b9ce3 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ reasoning_content.bat reasoning_window.bat queue_update.txt memory_graph.gml +.env.dev # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index a85fcc4e..1310d487 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ > ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次 > ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语( -关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言! +关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言! ## 开发计划TODO:LIST @@ -41,16 +41,13 @@ - config自动生成和检测 - log别用print - 给发送消息写专门的类 +- 改进表情包发送逻辑l -
- -
- ## 📚 详细文档 - [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成) -### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!) +### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!) #### Linux 使用 Docker Compose 部署 获取项目根目录中的```docker-compose.yml```文件,运行以下命令 diff --git a/bot.py b/bot.py index f9544f40..8741eca7 100644 --- a/bot.py +++ b/bot.py @@ -1,14 +1,62 @@ +import os import nonebot from nonebot.adapters.onebot.v11 import Adapter +from dotenv import load_dotenv +from loguru import logger + +'''彩蛋''' +from colorama import init, Fore +init() +text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午" +rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA] +rainbow_text = "" +for i, char in enumerate(text): + rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char +print(rainbow_text) +'''彩蛋''' + +# 首先加载基础环境变量 +if os.path.exists(".env"): + load_dotenv(".env") + logger.success("成功加载基础环境变量配置") +else: + logger.error("基础环境变量配置文件 .env 不存在") + exit(1) +# 根据 ENVIRONMENT 加载对应的环境配置 +env = os.getenv("ENVIRONMENT") +env_file = f".env.{env}" + +if env_file == ".env.dev" and os.path.exists(env_file): + logger.success("加载开发环境变量配置") + load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量 +elif env_file == ".env.prod" and os.path.exists(env_file): + logger.success("加载环境变量配置") + load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量 +else: + logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.") + exit(1) -# 初始化 NoneBot nonebot.init( - # napcat 默认使用 8080 端口 - websocket_port=8080, - # 设置日志级别 + # 从环境变量中读取配置 + websocket_port=os.getenv("PORT", 8080), + host=os.getenv("HOST", "127.0.0.1"), log_level="INFO", - # 设置超级用户 - superusers={"你的QQ号"} + # 添加自定义配置 + mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"), + mongodb_port=os.getenv("MONGODB_PORT", 27017), + database_name=os.getenv("DATABASE_NAME", "MegBot"), + mongodb_username=os.getenv("MONGODB_USERNAME", ""), + mongodb_password=os.getenv("MONGODB_PASSWORD", ""), + mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""), + # API相关配置 + chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""), + siliconflow_key=os.getenv("SILICONFLOW_KEY", ""), + chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"), + siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"), + deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""), + deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"), + # 插件配置 + plugins=os.getenv("PLUGINS", ["src2.plugins.chat"]) ) # 注册适配器 diff --git a/config/bot_config_toml b/config/bot_config_toml new file mode 100644 index 00000000..83a3c497 --- /dev/null +++ b/config/bot_config_toml @@ -0,0 +1,45 @@ +[bot] +qq = 123456 #填入你的机器人QQ +nickname = "麦麦" #你希望bot被称呼的名字 + +[message] +min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息 +max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃 +emoji_chance = 0.2 # 麦麦使用表情包的概率 + +[emoji] +check_interval = 120 +register_interval = 10 + +[cq_code] +enable_pic_translate = false + + +[response] +api_using = "siliconflow" # 选择大模型API,可选值为siliconflow,deepseek,建议使用siliconflow,因为识图api目前只支持siliconflow的deepseek-vl2模型 +model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率 +model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 +model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 + +[memory] +build_memory_interval = 300 # 记忆构建间隔 + + + +[others] +enable_advance_output = true # 开启后输出更多日志,false关闭true开启 + + +[groups] + +talk_allowed = [ + 123456,12345678 +] #可以回复消息的群 + +talk_frequency_down = [ + 123456,12345678 +] #降低回复频率的群 + +ban_user_id = [ + 123456,12345678 +] #禁止回复消息的QQ号 diff --git a/run_maimai.bat b/run_maimai.bat index 0e1bd7eb..702d39ed 100644 --- a/run_maimai.bat +++ b/run_maimai.bat @@ -2,4 +2,5 @@ call conda activate niuniu cd . REM 执行nb run命令 -nb run \ No newline at end of file +nb run +pause \ No newline at end of file diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index 356be3bd..f12b5979 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -331,9 +331,12 @@ class ReasoningGUI: def main(): """主函数""" Database.initialize( - "127.0.0.1", - 27017, - "MegBot" + host= os.getenv("MONGODB_HOST"), + port= int(os.getenv("MONGODB_PORT")), + db_name= os.getenv("DATABASE_NAME"), + username= os.getenv("MONGODB_USERNAME"), + password= os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE") ) app = ReasoningGUI() diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 6287f8cb..3da9a0b1 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -11,13 +11,18 @@ from .relationship_manager import relationship_manager from ..schedule.schedule_generator import bot_schedule from .willing_manager import willing_manager + # 获取驱动器 driver = get_driver() +config = driver.config Database.initialize( - global_config.MONGODB_HOST, - global_config.MONGODB_PORT, - global_config.DATABASE_NAME + host= config.mongodb_host, + port= int(config.mongodb_port), + db_name= config.database_name, + username= config.mongodb_username, + password= config.mongodb_password, + auth_source= config.mongodb_auth_source ) print("\033[1;32m[初始化数据库完成]\033[0m") @@ -34,7 +39,7 @@ emoji_manager.initialize() print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m") # 创建机器人实例 -chat_bot = ChatBot(global_config) +chat_bot = ChatBot() # 注册消息处理器 group_msg = on_message() # 创建定时任务 diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index cfe41589..dcd536b7 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time from ..memory_system.memory import memory_graph class ChatBot: - def __init__(self, config: BotConfig): - self.config = config + def __init__(self): self.storage = MessageStorage() - self.gpt = LLMResponseGenerator(config) + self.gpt = LLMResponseGenerator() self.bot = None # bot 实例引用 self._started = False @@ -39,11 +38,11 @@ class ChatBot: async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None: """处理收到的群消息""" - if event.group_id not in self.config.talk_allowed_groups: + if event.group_id not in global_config.talk_allowed_groups: return self.bot = bot # 更新 bot 实例 - if event.user_id in self.config.ban_user_id: + if event.user_id in global_config.ban_user_id: return # 打印原始消息内容 @@ -121,7 +120,7 @@ class ChatBot: event.group_id, topic[0] if topic else None, is_mentioned, - self.config, + global_config, event.user_id, message.is_emoji, interested_rate @@ -147,10 +146,14 @@ class ChatBot: thinking_message.interupt=True # 如果生成了回复,发送并记录 - + + ''' + 生成回复后的内容 + + ''' if response: - message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id) + message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) accu_typing_time = 0 for msg in response: print(f"当前消息: {msg}") @@ -161,7 +164,7 @@ class ChatBot: bot_message = Message( group_id=event.group_id, - user_id=self.config.BOT_QQ, + user_id=global_config.BOT_QQ, message_id=think_id, message_based_id=event.message_id, raw_message=msg, @@ -178,7 +181,7 @@ class ChatBot: bot_response_time = tinking_time_point - if random() < self.config.emoji_chance: + if random() < global_config.emoji_chance: emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) if emoji_path: emoji_cq = CQCode.create_emoji_cq(emoji_path) @@ -190,7 +193,7 @@ class ChatBot: bot_message = Message( group_id=event.group_id, - user_id=self.config.BOT_QQ, + user_id=global_config.BOT_QQ, message_id=0, raw_message=emoji_cq, plain_text=emoji_cq, diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index a44a2c58..7c8e77fb 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -7,22 +7,13 @@ import configparser import tomli import sys from loguru import logger -from dotenv import load_dotenv +from nonebot import get_driver @dataclass class BotConfig: - """机器人配置类""" - - # 基础配置 - MONGODB_HOST: str = "mongodb" - MONGODB_PORT: int = 27017 - DATABASE_NAME: str = "MegBot" - MONGODB_USERNAME: Optional[str] = None # 默认空值 - MONGODB_PASSWORD: Optional[str] = None # 默认空值 - MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值 - + """机器人配置类""" BOT_QQ: Optional[int] = 1 BOT_NICKNAME: Optional[str] = None @@ -75,17 +66,7 @@ class BotConfig: if os.path.exists(config_path): with open(config_path, "rb") as f: toml_dict = tomli.load(f) - - # 数据库配置 - if "database" in toml_dict: - db_config = toml_dict["database"] - config.MONGODB_HOST = db_config.get("host", config.MONGODB_HOST) - config.MONGODB_PORT = db_config.get("port", config.MONGODB_PORT) - config.DATABASE_NAME = db_config.get("name", config.DATABASE_NAME) - config.MONGODB_USERNAME = db_config.get("username", config.MONGODB_USERNAME) or None # 空字符串转为 None - config.MONGODB_PASSWORD = db_config.get("password", config.MONGODB_PASSWORD) or None # 空字符串转为 None - config.MONGODB_AUTH_SOURCE = db_config.get("auth_source", config.MONGODB_AUTH_SOURCE) or None # 空字符串转为 None - + if "emoji" in toml_dict: emoji_config = toml_dict["emoji"] config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) @@ -146,20 +127,10 @@ class BotConfig: # 获取配置文件路径 bot_config_path = BotConfig.get_default_config_path() config_dir = os.path.dirname(bot_config_path) -env_path = os.path.join(config_dir, '.env') logger.info(f"尝试从 {bot_config_path} 加载机器人配置") global_config = BotConfig.load_config(config_path=bot_config_path) -# 加载环境变量 - -logger.info(f"尝试从 {env_path} 加载环境变量配置") -if os.path.exists(env_path): - load_dotenv(env_path) - logger.success("成功加载环境变量配置") -else: - logger.error(f"环境变量配置文件不存在: {env_path}") - @dataclass class LLMConfig: """机器人配置类""" @@ -170,10 +141,11 @@ class LLMConfig: DEEP_SEEK_BASE_URL: str = None llm_config = LLMConfig() -llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY') -llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL') -llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY') -llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL') +config = get_driver().config +llm_config.SILICONFLOW_API_KEY = config.siliconflow_key +llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url +llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key +llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url if not global_config.enable_advance_output: diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 92ca20bd..ae5d8a25 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -7,7 +7,7 @@ from PIL import Image import os from random import random from nonebot.adapters.onebot.v11 import Bot -from .config import global_config, llm_config +from .config import global_config import time import asyncio from .utils_image import storage_image,storage_emoji @@ -16,6 +16,10 @@ from .utils_user import get_user_nickname #包含CQ码类 import urllib3 from urllib3.util import create_urllib3_context +from nonebot import get_driver + +driver = get_driver() +config = driver.config # TLS1.3特殊处理 https://github.com/psf/requests/issues/6616 ctx = create_urllib3_context() @@ -179,7 +183,7 @@ class CQCode: """调用AI接口获取表情包描述""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + "Authorization": f"Bearer {config.siliconflow_key}" } payload = { @@ -206,7 +210,7 @@ class CQCode: } response = requests.post( - f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + f"{config.siliconflow_base_url}chat/completions", headers=headers, json=payload, timeout=30 @@ -224,7 +228,7 @@ class CQCode: """调用AI接口获取普通图片描述""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + "Authorization": f"Bearer {config.siliconflow_key}" } payload = { @@ -251,7 +255,7 @@ class CQCode: } response = requests.post( - f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + f"{config.siliconflow_base_url}chat/completions", headers=headers, json=payload, timeout=30 diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index a4352758..c8c9dc81 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -10,10 +10,14 @@ import hashlib from datetime import datetime import base64 import shutil -from .config import global_config, llm_config import asyncio import time +from nonebot import get_driver + +driver = get_driver() +config = driver.config + class EmojiManager: _instance = None @@ -93,7 +97,7 @@ class EmojiManager: # 准备请求数据 headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + "Authorization": f"Bearer {config.siliconflow_key}" } payload = { @@ -115,7 +119,7 @@ class EmojiManager: async with aiohttp.ClientSession() as session: async with session.post( - f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + f"{config.siliconflow_base_url}chat/completions", headers=headers, json=payload ) as response: @@ -249,7 +253,7 @@ class EmojiManager: async with aiohttp.ClientSession() as session: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + "Authorization": f"Bearer {config.siliconflow_key}" } payload = { @@ -276,7 +280,7 @@ class EmojiManager: } async with session.post( - f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + f"{config.siliconflow_base_url}chat/completions", headers=headers, json=payload ) as response: diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 98cab286..fc2fc11f 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -1,40 +1,34 @@ from typing import Dict, Any, List, Optional, Union, Tuple from openai import OpenAI import asyncio -import requests from functools import partial from .message import Message -from .config import BotConfig, global_config +from .config import global_config from ...common.database import Database import random import time -import os import numpy as np -from dotenv import load_dotenv from .relationship_manager import relationship_manager -from ..schedule.schedule_generator import bot_schedule from .prompt_builder import prompt_builder -from .config import llm_config, global_config +from .config import global_config from .utils import process_llm_response +from nonebot import get_driver +driver = get_driver() +config = driver.config -# 获取当前文件的绝对路径 -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) -load_dotenv(os.path.join(root_dir, '.env')) class LLMResponseGenerator: - def __init__(self, config: BotConfig): - self.config = config - if self.config.API_USING == "siliconflow": + def __init__(self): + if global_config.API_USING == "siliconflow": self.client = OpenAI( - api_key=llm_config.SILICONFLOW_API_KEY, - base_url=llm_config.SILICONFLOW_BASE_URL + api_key=config.siliconflow_key, + base_url=config.siliconflow_base_url ) - elif self.config.API_USING == "deepseek": + elif global_config.API_USING == "deepseek": self.client = OpenAI( - api_key=llm_config.DEEP_SEEK_API_KEY, - base_url=llm_config.DEEP_SEEK_BASE_URL + api_key=config.deep_seek_key, + base_url=config.deep_seek_base_url ) self.db = Database.get_instance() @@ -58,6 +52,7 @@ class LLMResponseGenerator: else: self.current_model_type = 'r1_distill' # 默认使用 R1-Distill + print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") if self.current_model_type == 'r1': model_response = await self._generate_r1_response(message) @@ -96,8 +91,9 @@ class LLMResponseGenerator: print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") else: relationship_value = 0.0 + - # 构建prompt + ''' 构建prompt ''' prompt,prompt_check = prompt_builder._build_prompt( message_txt=message.processed_plain_text, sender_name=sender_name, @@ -105,6 +101,7 @@ class LLMResponseGenerator: group_id=message.group_id ) + # 设置默认参数 default_params = { "model": model_name, @@ -121,11 +118,28 @@ class LLMResponseGenerator: "max_tokens": 2048, "temperature": 0.7 } + + default_params_check = { + "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "messages": [{"role": "user", "content": prompt_check}], + "stream": False, + "max_tokens": 1024, + "temperature": 0.7 + } + + default_params_check = { + "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "messages": [{"role": "user", "content": prompt_check}], + "stream": False, + "max_tokens": 1024, + "temperature": 0.7 + } # 更新参数 if model_params: default_params.update(model_params) + def create_completion(): return self.client.chat.completions.create(**default_params) @@ -135,6 +149,7 @@ class LLMResponseGenerator: loop = asyncio.get_event_loop() # 读空气模块 + air = 0 reasoning_content_check='' content_check='' if global_config.enable_kuuki_read: @@ -148,21 +163,26 @@ class LLMResponseGenerator: content_check = response_check.choices[0].message.content print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}") if 'yes' not in content_check.lower(): - self.db.db.reasoning_logs.insert_one({ - 'time': time.time(), - 'group_id': message.group_id, - 'user': sender_name, - 'message': message.processed_plain_text, - 'model': model_name, - 'reasoning_check': reasoning_content_check, - 'response_check': content_check, - 'reasoning': "", - 'response': "", - 'prompt': prompt, - 'prompt_check': prompt_check, - 'model_params': default_params - }) - return None + air = 1 + #稀释读空气的判定 + if air == 1 and random.random() < 0.3: + self.db.db.reasoning_logs.insert_one({ + 'time': time.time(), + 'group_id': message.group_id, + 'user': sender_name, + 'message': message.processed_plain_text, + 'model': model_name, + 'reasoning_check': reasoning_content_check, + 'response_check': content_check, + 'reasoning': "", + 'response': "", + 'prompt': prompt, + 'prompt_check': prompt_check, + 'model_params': default_params + }) + return None + + @@ -206,7 +226,7 @@ class LLMResponseGenerator: async def _generate_r1_response(self, message: Message) -> Optional[str]: """使用 DeepSeek-R1 模型生成回复""" - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": return await self._generate_base_response( message, "deepseek-reasoner", @@ -221,7 +241,7 @@ class LLMResponseGenerator: async def _generate_v3_response(self, message: Message) -> Optional[str]: """使用 DeepSeek-V3 模型生成回复""" - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": return await self._generate_base_response( message, "deepseek-chat", @@ -274,7 +294,7 @@ class LLMResponseGenerator: messages = [{"role": "user", "content": prompt}] loop = asyncio.get_event_loop() - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": model = "deepseek-chat" else: model = "Pro/deepseek-ai/DeepSeek-V3" @@ -311,4 +331,4 @@ class LLMResponseGenerator: return processed_response, emotion_tags # 创建全局实例 -llm_response = LLMResponseGenerator(global_config) \ No newline at end of file +llm_response = LLMResponseGenerator() \ No newline at end of file diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index 3b6c5f79..6ad5226b 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -1,6 +1,5 @@ import time import random -from dotenv import load_dotenv from ..schedule.schedule_generator import bot_schedule import os from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text @@ -10,11 +9,6 @@ from .topic_identifier import topic_identifier from ..memory_system.memory import memory_graph from random import choice -# 获取当前文件的绝对路径 -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) -load_dotenv(os.path.join(root_dir, '.env')) - class PromptBuilder: def __init__(self): @@ -72,12 +66,15 @@ class PromptBuilder: overlapping_second_layer.update(overlap) # 合并所有需要的记忆 - if all_first_layer_items: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") - if overlapping_second_layer: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") + # if all_first_layer_items: + # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") + # if overlapping_second_layer: + # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") - all_memories = all_first_layer_items + list(overlapping_second_layer) + # 使用集合去重 + all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer)) + if all_memories: + print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}") if all_memories: # 只在列表非空时选择随机项 random_item = choice(all_memories) @@ -186,15 +183,19 @@ class PromptBuilder: # prompt += f"{activate_prompt}\n" prompt += f"{prompt_personality}\n" prompt += f"{prompt_ger}\n" - prompt += f"{extra_info}\n" + prompt += f"{extra_info}\n" + + + + '''读空气prompt处理''' - activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。不要太受现在做的事情影响,因为摸鱼是很正常的。" + activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" prompt_personality_check = '' - extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息,无论如何一定要回复,如果是刚刚理会过的人发送消息,且还在与那个人对话中的话一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。" + extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。" if personality_choice < 4/6: # 第一种人格 - prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[0]},{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' elif personality_choice < 1: # 第二种人格 - prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[1]},{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index fc8a915c..57458b23 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -1,14 +1,17 @@ from typing import Optional, Dict, List from openai import OpenAI from .message import Message -from .config import global_config, llm_config import jieba +from nonebot import get_driver + +driver = get_driver() +config = driver.config class TopicIdentifier: def __init__(self): self.client = OpenAI( - api_key=llm_config.SILICONFLOW_API_KEY, - base_url=llm_config.SILICONFLOW_BASE_URL + api_key=config.siliconflow_key, + base_url=config.siliconflow_base_url ) def identify_topic_llm(self, text: str) -> Optional[str]: diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 441eb700..78dd082b 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -4,11 +4,15 @@ from typing import List from .message import Message import requests import numpy as np -from .config import llm_config, global_config +from .config import global_config import re from typing import Dict from collections import Counter import math +from nonebot import get_driver + +driver = get_driver() +config = driver.config def combine_messages(messages: List[Message]) -> str: @@ -64,7 +68,7 @@ def get_embedding(text): "encoding_format": "float" } headers = { - "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}", + "Authorization": f"Bearer {config.siliconflow_key}", "Content-Type": "application/json" } diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 9fe2c40c..68b2fa7f 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -7,6 +7,10 @@ from ...common.database import Database import zlib # 用于 CRC32 import base64 from .config import global_config +from nonebot import get_driver + +driver = get_driver() +config = driver.config def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes: @@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes: # 连接数据库 db = Database( - host=global_config.MONGODB_HOST, - port=global_config.MONGODB_PORT, - db_name=global_config.DATABASE_NAME, - username=global_config.MONGODB_USERNAME, - password=global_config.MONGODB_PASSWORD, - auth_source=global_config.MONGODB_AUTH_SOURCE + host= config.mongodb_host, + port= int(config.mongodb_port), + db_name= config.database_name, + username= config.mongodb_username, + password= config.mongodb_password, + auth_source=config.mongodb_auth_source ) # 检查是否已存在相同哈希值的图片 diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py index 40756b41..d8c2e148 100644 --- a/src/plugins/knowledege/knowledge_library.py +++ b/src/plugins/knowledege/knowledge_library.py @@ -3,6 +3,10 @@ import sys import numpy as np import requests import time +from nonebot import get_driver + +driver = get_driver() +config = driver.config # 添加项目根目录到 Python 路径 root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) @@ -13,9 +17,12 @@ from src.plugins.chat.config import llm_config # 直接配置数据库连接信息 Database.initialize( - "127.0.0.1", # MongoDB 主机 - 27017, # MongoDB 端口 - "MegBot" # 数据库名称 + host= config.mongodb_host, + port= int(config.mongodb_port), + db_name= config.database_name, + username= config.mongodb_username, + password= config.mongodb_password, + auth_source=config.mongodb_auth_source ) class KnowledgeLibrary: diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index a92a47fa..e56de16c 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import os import sys import jieba from llm_module import LLMModel @@ -157,9 +158,12 @@ class Memory_graph: def main(): # 初始化数据库 Database.initialize( - "127.0.0.1", - 27017, - "MegBot" + host= os.getenv("MONGODB_HOST"), + port= int(os.getenv("MONGODB_PORT")), + db_name= os.getenv("DATABASE_NAME"), + username= os.getenv("MONGODB_USERNAME"), + password= os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE") ) memory_graph = Memory_graph() @@ -168,10 +172,12 @@ def main(): memory_graph.load_graph_from_db() # 展示两种不同的可视化方式 print("\n按连接数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=False) + # visualize_graph(memory_graph, color_by_memory=False) + visualize_graph_lite(memory_graph, color_by_memory=False) print("\n按记忆数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=True) + # visualize_graph(memory_graph, color_by_memory=True) + visualize_graph_lite(memory_graph, color_by_memory=True) # memory_graph.save_graph_to_db() @@ -262,7 +268,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() -if __name__ == "__main__": - main() +def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False): + # 设置中文字体 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 + plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + G = memory_graph.G + + # 创建一个新图用于可视化 + H = G.copy() + + # 移除只有一条记忆的节点和连接数少于3的节点 + nodes_to_remove = [] + for node in H.nodes(): + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + degree = H.degree(node) + if memory_count <= 2 or degree <= 2: + nodes_to_remove.append(node) + + H.remove_nodes_from(nodes_to_remove) + + # 如果过滤后没有节点,则返回 + if len(H.nodes()) == 0: + print("过滤后没有符合条件的节点可显示") + return + + # 保存图到本地 + nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式 + + # 根据连接条数或记忆数量设置节点颜色 + node_colors = [] + nodes = list(H.nodes()) # 获取图中实际的节点列表 + + if color_by_memory: + # 计算每个节点的记忆数量 + memory_counts = [] + for node in nodes: + memory_items = H.nodes[node].get('memory_items', []) + if isinstance(memory_items, list): + count = len(memory_items) + else: + count = 1 if memory_items else 0 + memory_counts.append(count) + max_memories = max(memory_counts) if memory_counts else 1 + + for count in memory_counts: + # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 + if max_memories > 0: + intensity = min(1.0, count / max_memories) + color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 + else: + color = (0, 0, 1) # 如果没有记忆,则为蓝色 + node_colors.append(color) + else: + # 使用原来的连接数量着色方案 + max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1 + for node in nodes: + degree = H.degree(node) + if max_degree > 0: + red = min(1.0, degree / max_degree) + blue = 1.0 - red + color = (red, 0, blue) + else: + color = (0, 0, 1) + node_colors.append(color) + + # 绘制图形 + plt.figure(figsize=(12, 8)) + pos = nx.spring_layout(H, k=1, iterations=50) + nx.draw(H, pos, + with_labels=True, + node_color=node_colors, + node_size=2000, + font_size=10, + font_family='SimHei', + font_weight='bold') + + title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') + plt.title(title, fontsize=16, fontfamily='SimHei') + plt.show() + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/plugins/memory_system/llm_module.py b/src/plugins/memory_system/llm_module.py index fa879afd..bd7f60dc 100644 --- a/src/plugins/memory_system/llm_module.py +++ b/src/plugins/memory_system/llm_module.py @@ -1,19 +1,19 @@ import os import requests -from dotenv import load_dotenv from typing import Tuple, Union import time +from nonebot import get_driver -# 加载环境变量 -load_dotenv() +driver = get_driver() +config = driver.config class LLMModel: # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs): def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name self.params = kwargs - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") + self.api_key = config.siliconflow_key + self.base_url = config.siliconflow_base_url def generate_response(self, prompt: str) -> Tuple[str, str]: """根据输入的提示生成模型的响应""" diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/llm_module_memory_make.py index 1abfdb2c..04ab6dbc 100644 --- a/src/plugins/memory_system/llm_module_memory_make.py +++ b/src/plugins/memory_system/llm_module_memory_make.py @@ -1,30 +1,20 @@ import os import requests -from dotenv import load_dotenv from typing import Tuple, Union import time from ..chat.config import BotConfig +from nonebot import get_driver -# 获取当前文件的绝对路径 -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) -env_path = os.path.join(root_dir, 'config', '.env') - -# 加载环境变量 -print(f"尝试从 {env_path} 加载环境变量配置") -if os.path.exists(env_path): - load_dotenv(env_path) - print("成功加载环境变量配置") -else: - print(f"环境变量配置文件不存在: {env_path}") +driver = get_driver() +config = driver.config class LLMModel: # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs): def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name self.params = kwargs - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") + self.api_key = config.siliconflow_key + self.base_url = config.siliconflow_base_url if not self.api_key or not self.base_url: raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index d95712d7..f4962814 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import os import jieba from .llm_module import LLMModel import networkx as nx @@ -197,8 +198,6 @@ class Hippocampus: time_frequency = {'near':1,'mid':2,'far':2} memory_sample = self.get_memory_sample(chat_size,time_frequency) # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}") - - for i, input_text in enumerate(memory_sample, 1): #加载进度可视化 progress = (i / len(memory_sample)) * 100 @@ -206,26 +205,25 @@ class Hippocampus: filled_length = int(bar_length * i // len(memory_sample)) bar = '█' * filled_length + '-' * (bar_length - filled_length) print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") - - # 生成压缩后记忆 - first_memory = set() - first_memory = self.memory_compress(input_text, 2.5) - # 延时防止访问超频 - # time.sleep(60) - #将记忆加入到图谱中 - for topic, memory in first_memory: - topics = segment_text(topic) - if '[' in topic or topic=='': - continue - print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") - for split_topic in topics: - self.memory_graph.add_dot(split_topic,memory) - for split_topic in topics: - for other_split_topic in topics: - if split_topic != other_split_topic: - self.memory_graph.connect_dot(split_topic, other_split_topic) - - self.memory_graph.save_graph_to_db() + if input_text: + # 生成压缩后记忆 + first_memory = set() + first_memory = self.memory_compress(input_text, 2.5) + # 延时防止访问超频 + # time.sleep(5) + #将记忆加入到图谱中 + for topic, memory in first_memory: + topics = segment_text(topic) + print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") + for split_topic in topics: + self.memory_graph.add_dot(split_topic,memory) + for split_topic in topics: + for other_split_topic in topics: + if split_topic != other_split_topic: + self.memory_graph.connect_dot(split_topic, other_split_topic) + else: + print(f"空消息 跳过") + self.memory_graph.save_graph_to_db() def memory_compress(self, input_text, rate=1): information_content = calculate_information_content(input_text) @@ -263,13 +261,19 @@ def topic_what(text, topic): return prompt - +from nonebot import get_driver +driver = get_driver() +config = driver.config + start_time = time.time() Database.initialize( - global_config.MONGODB_HOST, - global_config.MONGODB_PORT, - global_config.DATABASE_NAME + host= config.mongodb_host, + port= int(config.mongodb_port), + db_name= config.database_name, + username= config.mongodb_username, + password= config.mongodb_password, + auth_source=config.mongodb_auth_source ) #创建记忆图 memory_graph = Memory_graph() diff --git a/src/plugins/memory_system/memory_make.py b/src/plugins/memory_system/memory_make.py index 74de9070..02c61945 100644 --- a/src/plugins/memory_system/memory_make.py +++ b/src/plugins/memory_system/memory_make.py @@ -9,12 +9,42 @@ import datetime import random import time import os -from dotenv import load_dotenv # from chat.config import global_config sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from src.common.database import Database # 使用正确的导入语法 from src.plugins.memory_system.llm_module import LLMModel - + +def calculate_information_content(text): + """计算文本的信息量(熵)""" + # 统计字符频率 + char_count = Counter(text) + total_chars = len(text) + + # 计算熵 + entropy = 0 + for count in char_count.values(): + probability = count / total_chars + entropy -= probability * math.log2(probability) + + return entropy + +def get_cloest_chat_from_db(db, length: int, timestamp: str): + """从数据库中获取最接近指定时间戳的聊天记录""" + chat_text = '' + closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) + + if closest_record: + closest_time = closest_record['time'] + group_id = closest_record['group_id'] # 获取groupid + # 获取该时间戳之后的length条消息,且groupid相同 + chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) + for record in chat_record: + time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) + chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' + return chat_text + + return '' + class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 @@ -103,7 +133,8 @@ class Memory_graph: # 从数据库中根据时间戳获取离其最近的聊天记录 chat_text = '' closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 - print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") + + # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") if closest_record: closest_time = closest_record['time'] @@ -192,166 +223,80 @@ class Memory_graph: for edge in edges: self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1)) -def calculate_information_content(text): - - """计算文本的信息量(熵)""" - # 统计字符频率 - char_count = Counter(text) - total_chars = len(text) - - # 计算熵 - entropy = 0 - for count in char_count.values(): - probability = count / total_chars - entropy -= probability * math.log2(probability) - - return entropy - - -# Database.initialize( -# global_config.MONGODB_HOST, -# global_config.MONGODB_PORT, -# global_config.DATABASE_NAME -# ) -# memory_graph = Memory_graph() - -# llm_model = LLMModel() -# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - -# memory_graph.load_graph_from_db() - - - -def main(): - # 获取当前文件的绝对路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) - env_path = os.path.join(root_dir, 'config', '.env') - - # 加载环境变量 - print(f"尝试从 {env_path} 加载环境变量配置") - if os.path.exists(env_path): - load_dotenv(env_path) - print("成功加载环境变量配置") - else: - print(f"环境变量配置文件不存在: {env_path}") - - # 初始化数据库 - Database.initialize( - "127.0.0.1", - 27017, - "MegBot" - ) - - memory_graph = Memory_graph() - # 创建LLM模型实例 - llm_model = LLMModel() - llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - - # 使用当前时间戳进行测试 - current_timestamp = datetime.datetime.now().timestamp() - chat_text = [] - - chat_size =25 - - for _ in range(30): # 循环10次 - random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间 - print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") - chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) - chat_text.append(chat_) # 拼接所有text - # time.sleep(1) - - - - for i, input_text in enumerate(chat_text, 1): +# 海马体 +class Hippocampus: + def __init__(self,memory_graph:Memory_graph): + self.memory_graph = memory_graph + self.llm_model = LLMModel() + self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - progress = (i / len(chat_text)) * 100 - bar_length = 30 - filled_length = int(bar_length * i // len(chat_text)) - bar = '█' * filled_length + '-' * (bar_length - filled_length) - print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})") + def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}): + current_timestamp = datetime.datetime.now().timestamp() + chat_text = [] + #短期:1h 中期:4h 长期:24h + for _ in range(time_frequency.get('near')): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('mid')): # 循环10次 + random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('far')): # 循环10次 + random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + return chat_text + + def build_memory(self,chat_size=12): + #最近消息获取频率 + time_frequency = {'near':1,'mid':2,'far':2} + memory_sample = self.get_memory_sample(chat_size,time_frequency) - # print(input_text) - first_memory = set() - first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) - # time.sleep(5) - - #将记忆加入到图谱中 - for topic, memory in first_memory: - # continue - topics = segment_text(topic) - print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") - for split_topic in topics: - memory_graph.add_dot(split_topic,memory) - for split_topic in topics: - for other_split_topic in topics: - if split_topic != other_split_topic: - memory_graph.connect_dot(split_topic, other_split_topic) - - # memory_graph.store_memory() - - # 展示两种不同的可视化方式 - print("\n按连接数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=False) - - print("\n按记忆数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=True) - - memory_graph.save_graph_to_db() - # memory_graph.load_graph_from_db() - - while True: - query = input("请输入新的查询概念(输入'退出'以结束):") - if query.lower() == '退出': - break - items_list = memory_graph.get_related_item(query) - if items_list: - # print(items_list) - for memory_item in items_list: - print(memory_item) - else: - print("未找到相关记忆。") + #加载进度可视化 + for i, input_text in enumerate(memory_sample, 1): + progress = (i / len(memory_sample)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(memory_sample)) + bar = '█' * filled_length + '-' * (bar_length - filled_length) + print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") + # print(f"第{i}条消息: {input_text}") + if input_text: + # 生成压缩后记忆 + first_memory = set() + first_memory = self.memory_compress(input_text, 2.5) + #将记忆加入到图谱中 + for topic, memory in first_memory: + topics = segment_text(topic) + print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") + for split_topic in topics: + self.memory_graph.add_dot(split_topic,memory) + for split_topic in topics: + for other_split_topic in topics: + if split_topic != other_split_topic: + self.memory_graph.connect_dot(split_topic, other_split_topic) + else: + print(f"空消息 跳过") - while True: - query = input("请输入问题:") - - if query.lower() == '退出': - break - - topic_prompt = find_topic(query, 3) - topic_response = llm_model.generate_response(topic_prompt) + self.memory_graph.save_graph_to_db() + + def memory_compress(self, input_text, rate=1): + information_content = calculate_information_content(input_text) + print(f"文本的信息量(熵): {information_content:.4f} bits") + topic_num = max(1, min(5, int(information_content * rate / 4))) + topic_prompt = find_topic(input_text, topic_num) + topic_response = self.llm_model.generate_response(topic_prompt) # 检查 topic_response 是否为元组 if isinstance(topic_response, tuple): topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 else: topics = topic_response.split(",") - print(topics) - - for keyword in topics: - items_list = memory_graph.get_related_item(keyword) - if items_list: - print(items_list) - -def memory_compress(input_text, llm_model, llm_model_small, rate=1): - information_content = calculate_information_content(input_text) - print(f"文本的信息量(熵): {information_content:.4f} bits") - topic_num = max(1, min(5, int(information_content * rate / 4))) - print(topic_num) - topic_prompt = find_topic(input_text, topic_num) - topic_response = llm_model.generate_response(topic_prompt) - # 检查 topic_response 是否为元组 - if isinstance(topic_response, tuple): - topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 - else: - topics = topic_response.split(",") - print(topics) - compressed_memory = set() - for topic in topics: - topic_what_prompt = topic_what(input_text,topic) - topic_what_response = llm_model_small.generate_response(topic_what_prompt) - compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 - return compressed_memory - + compressed_memory = set() + for topic in topics: + topic_what_prompt = topic_what(input_text,topic) + topic_what_response = self.llm_model_small.generate_response(topic_what_prompt) + compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 + return compressed_memory def segment_text(text): seg_text = list(jieba.cut(text)) @@ -372,18 +317,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): G = memory_graph.G + # 创建一个新图用于可视化 + H = G.copy() + + # 移除只有一条记忆的节点和连接数少于3的节点 + nodes_to_remove = [] + for node in H.nodes(): + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + degree = H.degree(node) + if memory_count <= 1 or degree <= 2: + nodes_to_remove.append(node) + + H.remove_nodes_from(nodes_to_remove) + + # 如果过滤后没有节点,则返回 + if len(H.nodes()) == 0: + print("过滤后没有符合条件的节点可显示") + return + # 保存图到本地 - nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式 + nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式 # 根据连接条数或记忆数量设置节点颜色 node_colors = [] - nodes = list(G.nodes()) # 获取图中实际的节点列表 + nodes = list(H.nodes()) # 获取图中实际的节点列表 if color_by_memory: # 计算每个节点的记忆数量 memory_counts = [] for node in nodes: - memory_items = G.nodes[node].get('memory_items', []) + memory_items = H.nodes[node].get('memory_items', []) if isinstance(memory_items, list): count = len(memory_items) else: @@ -401,9 +365,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): node_colors.append(color) else: # 使用原来的连接数量着色方案 - max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 + max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1 for node in nodes: - degree = G.degree(node) + degree = H.degree(node) if max_degree > 0: red = min(1.0, degree / max_degree) blue = 1.0 - red @@ -414,8 +378,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): # 绘制图形 plt.figure(figsize=(12, 8)) - pos = nx.spring_layout(G, k=1, iterations=50) - nx.draw(G, pos, + pos = nx.spring_layout(H, k=1, iterations=50) + nx.draw(H, pos, with_labels=True, node_color=node_colors, node_size=2000, @@ -427,6 +391,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() +def main(): + # 初始化数据库 + Database.initialize( + host= os.getenv("MONGODB_HOST"), + port= int(os.getenv("MONGODB_PORT")), + db_name= os.getenv("DATABASE_NAME"), + username= os.getenv("MONGODB_USERNAME"), + password= os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE") + ) + + start_time = time.time() + + # 创建记忆图 + memory_graph = Memory_graph() + # 加载数据库中存储的记忆图 + memory_graph.load_graph_from_db() + # 创建海马体 + hippocampus = Hippocampus(memory_graph) + + end_time = time.time() + print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") + + # 构建记忆 + hippocampus.build_memory(chat_size=25) + + # 展示两种不同的可视化方式 + print("\n按连接数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=False) + + print("\n按记忆数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=True) + + # 交互式查询 + while True: + query = input("请输入新的查询概念(输入'退出'以结束):") + if query.lower() == '退出': + break + items_list = memory_graph.get_related_item(query) + if items_list: + for memory_item in items_list: + print(memory_item) + else: + print("未找到相关记忆。") + + while True: + query = input("请输入问题:") + + if query.lower() == '退出': + break + + topic_prompt = find_topic(query, 3) + topic_response = hippocampus.llm_model.generate_response(topic_prompt) + # 检查 topic_response 是否为元组 + if isinstance(topic_response, tuple): + topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 + else: + topics = topic_response.split(",") + print(topics) + + for keyword in topics: + items_list = memory_graph.get_related_item(keyword) + if items_list: + print(items_list) + if __name__ == "__main__": main() diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index 8933ef4f..a712141f 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -2,29 +2,21 @@ import datetime import os from typing import List, Dict from .schedule_llm_module import LLMModel -from dotenv import load_dotenv from ...common.database import Database # 使用正确的导入语法 from ..chat.config import global_config +from nonebot import get_driver +driver = get_driver() +config = driver.config -# import sys -# sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径 -# from src.plugins.schedule.schedule_llm_module import LLMModel -# from src.common.database import Database # 使用正确的导入语法 - -# 获取当前文件的绝对路径 -#TODO: 这个好几个地方用需要封装 -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) -load_dotenv(os.path.join(root_dir, '.env')) Database.initialize( - host= os.getenv("MONGODB_HOST"), - port= int(os.getenv("MONGODB_PORT")), - db_name= os.getenv("DATABASE_NAME"), - username= os.getenv("MONGODB_USERNAME"), - password= os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE") + host= config.mongodb_host, + port= int(config.mongodb_port), + db_name= config.database_name, + username= config.mongodb_username, + password= config.mongodb_password, + auth_source=config.mongodb_auth_source ) class ScheduleGenerator: diff --git a/src/plugins/schedule/schedule_llm_module.py b/src/plugins/schedule/schedule_llm_module.py index 13945afb..408e7d54 100644 --- a/src/plugins/schedule/schedule_llm_module.py +++ b/src/plugins/schedule/schedule_llm_module.py @@ -1,24 +1,24 @@ import os import requests -from dotenv import load_dotenv from typing import Tuple, Union +from nonebot import get_driver -# 加载环境变量 -load_dotenv() +driver = get_driver() +config = driver.config class LLMModel: # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs): def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs): if api_using == "deepseek": - self.api_key = os.getenv("DEEPSEEK_API_KEY") - self.base_url = os.getenv("DEEPSEEK_BASE_URL") + self.api_key = config.deep_seek_key + self.base_url = config.deep_seek_base_url if model_name != "Pro/deepseek-ai/DeepSeek-R1": self.model_name = model_name else: self.model_name = "deepseek-reasoner" else: - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") + self.api_key = config.siliconflow_key + self.base_url = config.siliconflow_base_url self.model_name = model_name self.params = kwargs