From 12e738427ea9359d7e5d91fe0efad739c84802f6 Mon Sep 17 00:00:00 2001 From: cuckoo711 <3038604221@qq.com> Date: Thu, 7 Aug 2025 05:30:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(database):=20=E6=94=AF=E6=8C=81=20MySQL=20?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 MySQL 数据库支持,配置文件中增加数据库类型、主机、端口、用户名、密码等设置项 - 修改数据库连接逻辑,支持 SQLite 和 MySQL两种数据库类型 - 更新数据库模型,使用 CharField 替代 TextField,限制字段长度 -增加数据库表前缀配置,方便多项目共用数据库 --- src/common/database/database.py | 57 ++++++++++++-------- src/common/database/database_model.py | 78 ++++++++++++++------------- src/config/config.py | 2 + src/config/official_configs.py | 21 ++++++++ template/bot_config_template.toml | 13 ++++- 5 files changed, 111 insertions(+), 60 deletions(-) diff --git a/src/common/database/database.py b/src/common/database/database.py index ca361481..8b4ce381 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,9 +1,11 @@ import os from pymongo import MongoClient -from peewee import SqliteDatabase +from peewee import MySQLDatabase, SqliteDatabase from pymongo.database import Database from rich.traceback import install +from src.config.config import global_config + install(extra_lines=3) _client = None @@ -57,26 +59,39 @@ class DBWrapper: return get_db()[key] # type: ignore +def create_peewee_database(): + data_base_config = global_config.data_base + + if data_base_config.db_type == "mysql": + return MySQLDatabase( + data_base_config.database_name, + user=data_base_config.username, + password=data_base_config.password, + host=data_base_config.host, + port=int(data_base_config.port), + charset='utf8mb4' + ) + elif data_base_config.db_type == "sqlite": + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + _DB_DIR = os.path.join(ROOT_PATH, "data") + _DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") + os.makedirs(_DB_DIR, exist_ok=True) + return SqliteDatabase( + _DB_FILE, + pragmas={ + "journal_mode": "wal", # WAL模式提高并发性能 + "cache_size": -64 * 1000, # 64MB缓存 + "foreign_keys": 1, + "ignore_check_constraints": 0, + "synchronous": 0, # 异步写入提高性能 + "busy_timeout": 1000, # 1秒超时而不是3秒 + }, ) + else: + raise ValueError(f"Unsupported PEEWEE_DB_TYPE: {data_base_config.db_type}") + + # 全局数据库访问点 -memory_db: Database = DBWrapper() # type: ignore - -# 定义数据库文件路径 -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -_DB_DIR = os.path.join(ROOT_PATH, "data") -_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") - -# 确保数据库目录存在 -os.makedirs(_DB_DIR, exist_ok=True) +memory_db: Database | DBWrapper = DBWrapper() # 全局 Peewee SQLite 数据库访问点 -db = SqliteDatabase( - _DB_FILE, - pragmas={ - "journal_mode": "wal", # WAL模式提高并发性能 - "cache_size": -64 * 1000, # 64MB缓存 - "foreign_keys": 1, - "ignore_check_constraints": 0, - "synchronous": 0, # 异步写入提高性能 - "busy_timeout": 1000, # 1秒超时而不是3秒 - }, -) +db = create_peewee_database() diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index d2b3acce..d954a422 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,9 +1,11 @@ -from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField +from peewee import CharField, Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField from .database import db import datetime from src.common.logger import get_logger +table_prefix = "" # 数据库表前缀 logger = get_logger("database_model") +logger.info(f"正在加载数据库模型...数据库表前缀为: {table_prefix}") # 请在此处定义您的数据库实例。 # 您需要取消注释并配置适合您的数据库的部分。 # 例如,对于 SQLite: @@ -34,7 +36,7 @@ class ChatStreams(BaseModel): # stream_id: "a544edeb1a9b73e3e1d77dff36e41264" # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。 - stream_id = TextField(unique=True, index=True) + stream_id = CharField(max_length=64, unique=True) # create_time: 1746096761.4490178 (时间戳,精确到小数点后7位) # DoubleField 用于存储浮点数,适合此类时间戳。 @@ -70,7 +72,7 @@ class ChatStreams(BaseModel): # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, # 请取消注释并在下面设置数据库实例: # database = db - table_name = "chat_streams" # 可选:明确指定数据库中的表名 + table_name = table_prefix + "chat_streams" # 可选:明确指定数据库中的表名 class LLMUsage(BaseModel): @@ -78,9 +80,9 @@ class LLMUsage(BaseModel): 用于存储 API 使用日志数据的模型。 """ - model_name = TextField(index=True) # 添加索引 - user_id = TextField(index=True) # 添加索引 - request_type = TextField(index=True) # 添加索引 + model_name = CharField(max_length=64, index=True) # 添加索引 + user_id = CharField(max_length=64, index=True) # 添加索引 + request_type = CharField(max_length=64, index=True) # 添加索引 endpoint = TextField() prompt_tokens = IntegerField() completion_tokens = IntegerField() @@ -92,15 +94,15 @@ class LLMUsage(BaseModel): class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # database = db - table_name = "llm_usage" + table_name = table_prefix + "llm_usage" class Emoji(BaseModel): """表情包""" - full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名) + full_path = CharField(max_length=512, unique=True) # 文件的完整路径 (包括文件名) format = TextField() # 图片格式 - emoji_hash = TextField(index=True) # 表情包的哈希值 + emoji_hash = CharField(max_length=64, index=True) # 表情包的哈希值 description = TextField() # 表情包的描述 query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数) is_registered = BooleanField(default=False) # 是否已注册 @@ -114,7 +116,7 @@ class Emoji(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = "emoji" + table_name = table_prefix + "emoji" class Messages(BaseModel): @@ -122,10 +124,10 @@ class Messages(BaseModel): 用于存储消息数据的模型。 """ - message_id = TextField(index=True) # 消息 ID (更改自 IntegerField) + message_id = CharField(max_length=128, index=True) # 消息 ID (更改自 IntegerField) time = DoubleField() # 消息时间戳 - chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + chat_id = CharField(max_length=128, index=True) # 对应的 ChatStreams stream_id reply_to = TextField(null=True) @@ -165,7 +167,7 @@ class Messages(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = "messages" + table_name = table_prefix + "messages" class ActionRecords(BaseModel): @@ -183,13 +185,13 @@ class ActionRecords(BaseModel): action_build_into_prompt = BooleanField(default=False) action_prompt_display = TextField() - chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + chat_id = CharField(max_length=128, index=True) # 对应的 ChatStreams stream_id chat_info_stream_id = TextField() chat_info_platform = TextField() class Meta: # database = db # 继承自 BaseModel - table_name = "action_records" + table_name = table_prefix + "action_records" class Images(BaseModel): @@ -198,9 +200,9 @@ class Images(BaseModel): """ image_id = TextField(default="") # 图片唯一ID - emoji_hash = TextField(index=True) # 图像的哈希值 + emoji_hash = CharField(max_length=64, index=True) # 图像的哈希值 description = TextField(null=True) # 图像的描述 - path = TextField(unique=True) # 图像文件的路径 + path = CharField(max_length=512, unique=True) # 图像文件的路径 # base64 = TextField() # 图片的base64编码 count = IntegerField(default=1) # 图片被引用的次数 timestamp = FloatField() # 时间戳 @@ -208,7 +210,7 @@ class Images(BaseModel): vlm_processed = BooleanField(default=False) # 是否已经过VLM处理 class Meta: - table_name = "images" + table_name = table_prefix + "images" class ImageDescriptions(BaseModel): @@ -217,13 +219,13 @@ class ImageDescriptions(BaseModel): """ type = TextField() # 类型,例如 "emoji" - image_description_hash = TextField(index=True) # 图像的哈希值 + image_description_hash = CharField(max_length=64, index=True) # 图像的哈希值 description = TextField() # 图像的描述 timestamp = FloatField() # 时间戳 class Meta: # database = db # 继承自 BaseModel - table_name = "image_descriptions" + table_name = table_prefix + "image_descriptions" class OnlineTime(BaseModel): @@ -232,14 +234,14 @@ class OnlineTime(BaseModel): """ # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) - timestamp = TextField(default=datetime.datetime.now) # 时间戳 + timestamp = CharField(max_length=64, default=datetime.datetime.now) # 时间戳 duration = IntegerField() # 时长,单位分钟 start_timestamp = DateTimeField(default=datetime.datetime.now) end_timestamp = DateTimeField(index=True) class Meta: # database = db # 继承自 BaseModel - table_name = "online_time" + table_name = table_prefix + "online_time" class PersonInfo(BaseModel): @@ -247,11 +249,11 @@ class PersonInfo(BaseModel): 用于存储个人信息数据的模型。 """ - person_id = TextField(unique=True, index=True) # 个人唯一ID + person_id = CharField(max_length=64, unique=True) # 个人唯一ID person_name = TextField(null=True) # 个人名称 (允许为空) name_reason = TextField(null=True) # 名称设定的原因 platform = TextField() # 平台 - user_id = TextField(index=True) # 用户ID + user_id = CharField(max_length=64, index=True) # 用户ID nickname = TextField() # 用户昵称 impression = TextField(null=True) # 个人印象 short_impression = TextField(null=True) # 个人印象的简短描述 @@ -266,11 +268,11 @@ class PersonInfo(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = "person_info" + table_name = table_prefix + "person_info" class Memory(BaseModel): - memory_id = TextField(index=True) + memory_id = CharField(max_length=128, index=True) chat_id = TextField(null=True) memory_text = TextField(null=True) keywords = TextField(null=True) @@ -278,7 +280,7 @@ class Memory(BaseModel): last_view_time = FloatField(null=True) class Meta: - table_name = "memory" + table_name = table_prefix + "memory" class Expression(BaseModel): @@ -290,16 +292,16 @@ class Expression(BaseModel): style = TextField() count = FloatField() last_active_time = FloatField() - chat_id = TextField(index=True) + chat_id = CharField(max_length=128, index=True) type = TextField() create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据 class Meta: - table_name = "expression" + table_name = table_prefix + "expression" class ThinkingLog(BaseModel): - chat_id = TextField(index=True) + chat_id = CharField(max_length=128, index=True) trigger_text = TextField(null=True) response_text = TextField(null=True) @@ -319,7 +321,7 @@ class ThinkingLog(BaseModel): created_at = DateTimeField(default=datetime.datetime.now) class Meta: - table_name = "thinking_logs" + table_name = table_prefix + "thinking_logs" class GraphNodes(BaseModel): @@ -327,14 +329,14 @@ class GraphNodes(BaseModel): 用于存储记忆图节点的模型 """ - concept = TextField(unique=True, index=True) # 节点概念 + concept = CharField(max_length=128, unique=True) # 节点概念 memory_items = TextField() # JSON格式存储的记忆列表 hash = TextField() # 节点哈希值 created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 class Meta: - table_name = "graph_nodes" + table_name = table_prefix + "graph_nodes" class GraphEdges(BaseModel): @@ -342,15 +344,15 @@ class GraphEdges(BaseModel): 用于存储记忆图边的模型 """ - source = TextField(index=True) # 源节点 - target = TextField(index=True) # 目标节点 + source = CharField(max_length=128, index=True) # 源节点 + target = CharField(max_length=128, index=True) # 目标节点 strength = IntegerField() # 连接强度 hash = TextField() # 边哈希值 created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 class Meta: - table_name = "graph_edges" + table_name = table_prefix + "graph_edges" def create_tables(): @@ -400,7 +402,7 @@ def initialize_database(): GraphEdges, ActionRecords, # 添加 ActionRecords 到初始化列表 ] - + del_extra = False # 是否删除多余字段 try: with db: # 管理 table_exists 检查的连接 for model in models: @@ -452,6 +454,8 @@ def initialize_database(): logger.error(f"添加字段 '{field_name}' 失败: {e}") # 检查并删除多余字段(新增逻辑) + if not del_extra: + continue extra_fields = existing_columns - model_fields if extra_fields: logger.warning(f"表 '{table_name}' 存在多余字段: {extra_fields}") diff --git a/src/config/config.py b/src/config/config.py index 368adaa5..6ba8ba92 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -14,6 +14,7 @@ from src.common.logger import get_logger from src.config.config_base import ConfigBase from src.config.official_configs import ( BotConfig, + DataBaseConfig, PersonalityConfig, ExpressionConfig, ChatConfig, @@ -348,6 +349,7 @@ class Config(ConfigBase): debug: DebugConfig custom_prompt: CustomPromptConfig voice: VoiceConfig + data_base: DataBaseConfig @dataclass diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 8f34a184..1d00c957 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -598,3 +598,24 @@ class LPMMKnowledgeConfig(ConfigBase): embedding_dimension: int = 1024 """嵌入向量维度,应该与模型的输出维度一致""" + +class DataBaseConfig(ConfigBase): + """数据库配置类""" + + db_type: Literal["sqlite", "mysql"] = "sqlite" + """数据库类型,支持sqlite、mysql""" + + host: str = "127.0.0.1" + """数据库主机地址""" + + port: int = 3306 + """数据库端口号""" + + username: str = "" + """数据库用户名""" + + password: str = "" + """数据库密码""" + + database_name: str = "MaiBot" + """数据库名称""" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index fae41f82..a9ea9b2c 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.0.0" +version = "6.1.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -231,4 +231,13 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 enable = true [experimental] #实验性功能 -enable_friend_chat = false # 是否启用好友聊天 \ No newline at end of file +enable_friend_chat = false # 是否启用好友聊天 + +[data_base] #数据库配置 +# 数据库类型,可选:sqlite, mysql +db_type = "sqlite" # 数据库类型 +host = "" # 数据库主机地址,如果是sqlite则不需要填写 +port = 3306 # 数据库端口,如果是sqlite则不需要填写 +user = "" # 数据库用户名,如果是sqlite则不需要填写 +password = "" # 数据库密码,如果是sqlite则不需要填写 +database = "MaiBot" # 数据库名称,如果是sqlite则不需要填写 \ No newline at end of file