From 4baa6c6f0aa2986a8fa6a7f7425a2dcb315d6478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=A5=E6=B2=B3=E6=99=B4?= Date: Mon, 10 Mar 2025 14:48:43 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0MongoDB=20URI?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E8=BF=9E=E6=8E=A5=EF=BC=8C=E5=B9=B6=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/database.py | 36 ++++++-- src/gui/reasoning_gui.py | 53 ++++-------- src/plugins/chat/__init__.py | 1 + src/plugins/chat/utils_image.py | 9 +- src/plugins/knowledege/knowledge_library.py | 5 +- src/plugins/memory_system/draw_memory.py | 7 +- src/plugins/memory_system/memory.py | 8 +- .../memory_system/memory_manual_build.py | 82 ++++++------------- src/plugins/schedule/schedule_generator.py | 2 +- 9 files changed, 82 insertions(+), 121 deletions(-) diff --git a/src/common/database.py b/src/common/database.py index 45ac05da..f0954b07 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -6,20 +6,44 @@ from pymongo import MongoClient class Database: _instance: Optional["Database"] = None - def __init__(self, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None): - if username and password: + def __init__( + self, + host: str, + port: int, + db_name: str, + username: Optional[str] = None, + password: Optional[str] = None, + auth_source: Optional[str] = None, + uri: Optional[str] = None, + ): + if uri and uri.startswith("mongodb://"): + # 优先使用URI连接 + self.client = MongoClient(uri) + elif username and password: # 如果有用户名和密码,使用认证连接 - # TODO: 复杂情况直接支持URI吧 - self.client = MongoClient(host, port, username=username, password=password, authSource=auth_source) + self.client = MongoClient( + host, port, username=username, password=password, authSource=auth_source + ) else: # 否则使用无认证连接 self.client = MongoClient(host, port) self.db = self.client[db_name] @classmethod - def initialize(cls, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None) -> "Database": + def initialize( + cls, + host: str, + port: int, + db_name: str, + username: Optional[str] = None, + password: Optional[str] = None, + auth_source: Optional[str] = None, + uri: Optional[str] = None, + ) -> "Database": if cls._instance is None: - cls._instance = cls(host, port, db_name, username, password, auth_source) + cls._instance = cls( + host, port, db_name, username, password, auth_source, uri + ) return cls._instance @classmethod diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index 514a95df..dd62e063 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import Dict, List from loguru import logger from typing import Optional -from pymongo import MongoClient +from ..common.database import Database import customtkinter as ctk from dotenv import load_dotenv @@ -28,38 +28,6 @@ else: logger.error("未找到环境配置文件") sys.exit(1) - -class Database: - _instance: Optional["Database"] = None - - def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, - auth_source: str = None): - if username and password: - self.client = MongoClient( - host=host, - port=port, - username=username, - password=password, - authSource=auth_source or 'admin' - ) - else: - self.client = MongoClient(host, port) - self.db = self.client[db_name] - - @classmethod - def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, - auth_source: str = None) -> "Database": - if cls._instance is None: - cls._instance = cls(host, port, db_name, username, password, auth_source) - return cls._instance - - @classmethod - def get_instance(cls) -> "Database": - if cls._instance is None: - raise RuntimeError("Database not initialized") - return cls._instance - - class ReasoningGUI: def __init__(self): # 记录启动时间戳,转换为Unix时间戳 @@ -83,7 +51,15 @@ class ReasoningGUI: except RuntimeError: logger.warning("数据库未初始化,正在尝试初始化...") try: - Database.initialize("127.0.0.1", 27017, "maimai_bot") + Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), + ) self.db = Database.get_instance().db logger.success("数据库初始化成功") except Exception: @@ -359,12 +335,13 @@ class ReasoningGUI: def main(): """主函数""" Database.initialize( - host=os.getenv("MONGODB_HOST"), - port=int(os.getenv("MONGODB_PORT")), - db_name=os.getenv("DATABASE_NAME"), + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), username=os.getenv("MONGODB_USERNAME"), password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE") + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), ) app = ReasoningGUI() diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index bd71be01..36d558d1 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -31,6 +31,7 @@ driver = get_driver() config = driver.config Database.initialize( + uri=config.MONGODB_URI, host=config.MONGODB_HOST, port=int(config.MONGODB_PORT), db_name=config.DATABASE_NAME, diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 8a8b3ce5..7e57560c 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -37,14 +37,7 @@ def storage_compress_image(base64_data: str, max_size: int = 200) -> str: os.makedirs(images_dir, exist_ok=True) # 连接数据库 - db = Database( - host=config.mongodb_host, - port=int(config.mongodb_port), - db_name=config.database_name, - username=config.mongodb_username, - password=config.mongodb_password, - auth_source=config.mongodb_auth_source - ) + db = Database.get_instance() # 检查是否已存在相同哈希值的图片 collection = db.db['images'] diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py index 48107696..99e2f842 100644 --- a/src/plugins/knowledege/knowledge_library.py +++ b/src/plugins/knowledege/knowledge_library.py @@ -19,12 +19,13 @@ from src.common.database import Database # 从环境变量获取配置 Database.initialize( + uri=os.getenv("MONGODB_URI"), host=os.getenv("MONGODB_HOST", "127.0.0.1"), port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "maimai"), + db_name=os.getenv("DATABASE_NAME", "MegBot"), username=os.getenv("MONGODB_USERNAME"), password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin") + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), ) class KnowledgeLibrary: diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index 6da330d9..ffe2ba42 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -162,12 +162,13 @@ class Memory_graph: def main(): # 初始化数据库 Database.initialize( + uri=os.getenv("MONGODB_URI"), host=os.getenv("MONGODB_HOST", "127.0.0.1"), port=int(os.getenv("MONGODB_PORT", "27017")), db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME", ""), - password=os.getenv("MONGODB_PASSWORD", ""), - auth_source=os.getenv("MONGODB_AUTH_SOURCE", "") + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), ) memory_graph = Memory_graph() diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 9b325b36..b894aa6f 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -8,6 +8,7 @@ import jieba import networkx as nx from loguru import logger +from nonebot import get_driver from ...common.database import Database # 使用正确的导入语法 from ..chat.config import global_config from ..chat.utils import ( @@ -18,7 +19,6 @@ from ..chat.utils import ( ) from ..models.utils_model import LLM_request - class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 @@ -130,7 +130,7 @@ class Memory_graph: return None -# 海马体 +# 海马体 class Hippocampus: def __init__(self, memory_graph: Memory_graph): self.memory_graph = memory_graph @@ -749,15 +749,13 @@ def segment_text(text): seg_text = list(jieba.cut(text)) return seg_text - -from nonebot import get_driver - driver = get_driver() config = driver.config start_time = time.time() Database.initialize( + uri=config.MONGODB_URI, host=config.MONGODB_HOST, port=config.MONGODB_PORT, db_name=config.DATABASE_NAME, diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py index 3c120f21..3a2961b6 100644 --- a/src/plugins/memory_system/memory_manual_build.py +++ b/src/plugins/memory_system/memory_manual_build.py @@ -35,45 +35,6 @@ else: logger.warning(f"未找到环境变量文件: {env_path}") logger.info("将使用默认配置") -class Database: - _instance = None - db = None - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def __init__(self): - if not Database.db: - Database.initialize( - host=os.getenv("MONGODB_HOST"), - port=int(os.getenv("MONGODB_PORT")), - db_name=os.getenv("DATABASE_NAME"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE") - ) - - @classmethod - def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"): - try: - if username and password: - uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}" - else: - uri = f"mongodb://{host}:{port}" - - client = pymongo.MongoClient(uri) - cls.db = client[db_name] - # 测试连接 - client.server_info() - logger.success("MongoDB连接成功!") - - except Exception as e: - logger.error(f"初始化MongoDB失败: {str(e)}") - raise - def calculate_information_content(text): """计算文本的信息量(熵)""" char_count = Counter(text) @@ -202,7 +163,7 @@ class Memory_graph: # 返回所有节点对应的 Memory_dot 对象 return [self.get_dot(node) for node in self.G.nodes()] -# 海马体 +# 海马体 class Hippocampus: def __init__(self, memory_graph: Memory_graph): self.memory_graph = memory_graph @@ -941,59 +902,67 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal async def main(): # 初始化数据库 logger.info("正在初始化数据库连接...") - db = Database.get_instance() + Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), + ) start_time = time.time() - + test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} - + # 创建记忆图 memory_graph = Memory_graph() - + # 创建海马体 hippocampus = Hippocampus(memory_graph) - + # 从数据库同步数据 hippocampus.sync_memory_from_db() - + end_time = time.time() logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") - + # 构建记忆 if test_pare['do_build_memory']: logger.info("开始构建记忆...") chat_size = 20 await hippocampus.operation_build_memory(chat_size=chat_size) - + end_time = time.time() logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m") - + if test_pare['do_forget_topic']: logger.info("开始遗忘记忆...") await hippocampus.operation_forget_topic(percentage=0.1) - + end_time = time.time() logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") - + if test_pare['do_merge_memory']: logger.info("开始合并记忆...") await hippocampus.operation_merge_memory(percentage=0.1) - + end_time = time.time() logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") - + if test_pare['do_visualize_graph']: # 展示优化后的图形 logger.info("生成记忆图谱可视化...") print("\n生成优化后的记忆图谱:") visualize_graph_lite(memory_graph) - + if test_pare['do_query']: # 交互式查询 while True: query = input("\n请输入新的查询概念(输入'退出'以结束):") if query.lower() == '退出': break - + items_list = memory_graph.get_related_item(query) if items_list: first_layer, second_layer = items_list @@ -1008,9 +977,6 @@ async def main(): else: print("未找到相关记忆。") - if __name__ == "__main__": import asyncio asyncio.run(main()) - - diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index fc07a152..b968d43c 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -14,6 +14,7 @@ driver = get_driver() config = driver.config Database.initialize( + uri=config.MONGODB_URI, host=config.MONGODB_HOST, port=int(config.MONGODB_PORT), db_name=config.DATABASE_NAME, @@ -22,7 +23,6 @@ Database.initialize( auth_source=config.MONGODB_AUTH_SOURCE ) - class ScheduleGenerator: def __init__(self): # 根据global_config.llm_normal这一字典配置指定模型