feat: 实现MongoDB URI方式连接,并统一数据库连接代码。

pull/157/head
春河晴 2025-03-10 14:48:43 +09:00
parent c9f12446c0
commit 4baa6c6f0a
No known key found for this signature in database
9 changed files with 82 additions and 121 deletions

View File

@ -6,20 +6,44 @@ from pymongo import MongoClient
class Database: class Database:
_instance: Optional["Database"] = None _instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None): def __init__(
if username and password: self,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
):
if uri and uri.startswith("mongodb://"):
# 优先使用URI连接
self.client = MongoClient(uri)
elif username and password:
# 如果有用户名和密码,使用认证连接 # 如果有用户名和密码,使用认证连接
# TODO: 复杂情况直接支持URI吧 self.client = MongoClient(
self.client = MongoClient(host, port, username=username, password=password, authSource=auth_source) host, port, username=username, password=password, authSource=auth_source
)
else: else:
# 否则使用无认证连接 # 否则使用无认证连接
self.client = MongoClient(host, port) self.client = MongoClient(host, port)
self.db = self.client[db_name] self.db = self.client[db_name]
@classmethod @classmethod
def initialize(cls, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None) -> "Database": def initialize(
cls,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> "Database":
if cls._instance is None: if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source) cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance return cls._instance
@classmethod @classmethod

View File

@ -7,7 +7,7 @@ from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from pymongo import MongoClient from ..common.database import Database
import customtkinter as ctk import customtkinter as ctk
from dotenv import load_dotenv from dotenv import load_dotenv
@ -28,38 +28,6 @@ else:
logger.error("未找到环境配置文件") logger.error("未找到环境配置文件")
sys.exit(1) sys.exit(1)
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None):
if username and password:
self.client = MongoClient(
host=host,
port=port,
username=username,
password=password,
authSource=auth_source or 'admin'
)
else:
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance
@classmethod
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
class ReasoningGUI: class ReasoningGUI:
def __init__(self): def __init__(self):
# 记录启动时间戳转换为Unix时间戳 # 记录启动时间戳转换为Unix时间戳
@ -83,7 +51,15 @@ class ReasoningGUI:
except RuntimeError: except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...") logger.warning("数据库未初始化,正在尝试初始化...")
try: try:
Database.initialize("127.0.0.1", 27017, "maimai_bot") Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db self.db = Database.get_instance().db
logger.success("数据库初始化成功") logger.success("数据库初始化成功")
except Exception: except Exception:
@ -359,12 +335,13 @@ class ReasoningGUI:
def main(): def main():
"""主函数""" """主函数"""
Database.initialize( Database.initialize(
host=os.getenv("MONGODB_HOST"), uri=os.getenv("MONGODB_URI"),
port=int(os.getenv("MONGODB_PORT")), host=os.getenv("MONGODB_HOST", "127.0.0.1"),
db_name=os.getenv("DATABASE_NAME"), port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"), username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"), password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE") auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
) )
app = ReasoningGUI() app = ReasoningGUI()

View File

@ -31,6 +31,7 @@ driver = get_driver()
config = driver.config config = driver.config
Database.initialize( Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST, host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT), port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME, db_name=config.DATABASE_NAME,

View File

@ -37,14 +37,7 @@ def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
os.makedirs(images_dir, exist_ok=True) os.makedirs(images_dir, exist_ok=True)
# 连接数据库 # 连接数据库
db = Database( db = Database.get_instance()
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片 # 检查是否已存在相同哈希值的图片
collection = db.db['images'] collection = db.db['images']

View File

@ -19,12 +19,13 @@ from src.common.database import Database
# 从环境变量获取配置 # 从环境变量获取配置
Database.initialize( Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"), host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")), port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"), db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"), username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"), password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin") auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
) )
class KnowledgeLibrary: class KnowledgeLibrary:

View File

@ -162,12 +162,13 @@ class Memory_graph:
def main(): def main():
# 初始化数据库 # 初始化数据库
Database.initialize( Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"), host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")), port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"), db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""), username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD", ""), password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "") auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
) )
memory_graph = Memory_graph() memory_graph = Memory_graph()

View File

@ -8,6 +8,7 @@ import jieba
import networkx as nx import networkx as nx
from loguru import logger from loguru import logger
from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法 from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import ( from ..chat.utils import (
@ -18,7 +19,6 @@ from ..chat.utils import (
) )
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
@ -130,7 +130,7 @@ class Memory_graph:
return None return None
# 海马体 # 海马体
class Hippocampus: class Hippocampus:
def __init__(self, memory_graph: Memory_graph): def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph self.memory_graph = memory_graph
@ -749,15 +749,13 @@ def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
from nonebot import get_driver
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
start_time = time.time() start_time = time.time()
Database.initialize( Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST, host=config.MONGODB_HOST,
port=config.MONGODB_PORT, port=config.MONGODB_PORT,
db_name=config.DATABASE_NAME, db_name=config.DATABASE_NAME,

View File

@ -35,45 +35,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}") logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置") logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text): def calculate_information_content(text):
"""计算文本的信息量(熵)""" """计算文本的信息量(熵)"""
char_count = Counter(text) char_count = Counter(text)
@ -202,7 +163,7 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象 # 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()] return [self.get_dot(node) for node in self.G.nodes()]
# 海马体 # 海马体
class Hippocampus: class Hippocampus:
def __init__(self, memory_graph: Memory_graph): def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph self.memory_graph = memory_graph
@ -941,59 +902,67 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main(): async def main():
# 初始化数据库 # 初始化数据库
logger.info("正在初始化数据库连接...") logger.info("正在初始化数据库连接...")
db = Database.get_instance() Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time() start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图 # 创建记忆图
memory_graph = Memory_graph() memory_graph = Memory_graph()
# 创建海马体 # 创建海马体
hippocampus = Hippocampus(memory_graph) hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据 # 从数据库同步数据
hippocampus.sync_memory_from_db() hippocampus.sync_memory_from_db()
end_time = time.time() end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆 # 构建记忆
if test_pare['do_build_memory']: if test_pare['do_build_memory']:
logger.info("开始构建记忆...") logger.info("开始构建记忆...")
chat_size = 20 chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size) await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time() end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m") logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']: if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...") logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1) await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time() end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']: if test_pare['do_merge_memory']:
logger.info("开始合并记忆...") logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1) await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time() end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']: if test_pare['do_visualize_graph']:
# 展示优化后的图形 # 展示优化后的图形
logger.info("生成记忆图谱可视化...") logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:") print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph) visualize_graph_lite(memory_graph)
if test_pare['do_query']: if test_pare['do_query']:
# 交互式查询 # 交互式查询
while True: while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):") query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出': if query.lower() == '退出':
break break
items_list = memory_graph.get_related_item(query) items_list = memory_graph.get_related_item(query)
if items_list: if items_list:
first_layer, second_layer = items_list first_layer, second_layer = items_list
@ -1008,9 +977,6 @@ async def main():
else: else:
print("未找到相关记忆。") print("未找到相关记忆。")
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@ -14,6 +14,7 @@ driver = get_driver()
config = driver.config config = driver.config
Database.initialize( Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST, host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT), port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME, db_name=config.DATABASE_NAME,
@ -22,7 +23,6 @@ Database.initialize(
auth_source=config.MONGODB_AUTH_SOURCE auth_source=config.MONGODB_AUTH_SOURCE
) )
class ScheduleGenerator: class ScheduleGenerator:
def __init__(self): def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型 # 根据global_config.llm_normal这一字典配置指定模型