diff --git a/bot.py b/bot.py
index 7faad3ed..02dee81c 100644
--- a/bot.py
+++ b/bot.py
@@ -1,4 +1,4 @@
-raise RuntimeError("System Not Ready")
+# raise RuntimeError("System Not Ready")
import asyncio
import hashlib
import os
@@ -181,14 +181,14 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
logger.info("正在优雅关闭麦麦...")
# 关闭 WebUI 服务器
- try:
- from src.webui.webui_server import get_webui_server
+ # try:
+ # from src.webui.webui_server import get_webui_server
- webui_server = get_webui_server()
- if webui_server and webui_server._server:
- await webui_server.shutdown()
- except Exception as e:
- logger.warning(f"关闭 WebUI 服务器时出错: {e}")
+ # webui_server = get_webui_server()
+ # if webui_server and webui_server._server:
+ # await webui_server.shutdown()
+ # except Exception as e:
+ # logger.warning(f"关闭 WebUI 服务器时出错: {e}")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
diff --git a/changelogs/mai_next_design.md b/changelogs/mai_next_design.md
index fc870960..dcee4681 100644
--- a/changelogs/mai_next_design.md
+++ b/changelogs/mai_next_design.md
@@ -2,15 +2,12 @@
Version 0.2.2 - 2025-11-05
## 配置文件设计
-- [x] 使用 `toml` 作为配置文件格式
-- [x] 合理使用注释说明当前配置作用(提案)
-- [x] 使用 python 方法作为配置项说明(提案)
- - [x] 取消`bot_config_template.toml`
- - [x] 取消`model_config_template.toml`
-- [x] 配置类中的所有原子项目应该只包含以下类型: `str`, `int`, `float`, `bool`, `list`, `dict`, `set`
- - [ ] 暂时禁止使用 `Union` 类型(尚未支持解析)
- - [ ] 不建议使用`tuple`类型,使用时会发出警告,考虑使用嵌套`dataclass`替代
- - [x] 复杂类型使用嵌套配置类实现
+主体基于`pydantic`的`BaseModel`设计配置基类`ConfigBase`
+要求每个属性必须具有类型注解,且类型注解满足以下要求:
+- 原子类型仅允许使用: `str`, `int`, `float`, `bool`, 以及基于`ConfigBase`的嵌套配置类
+- 复杂类型允许使用: `list`, `dict`, `set`,但其内部类型必须为原子类型或嵌套配置类,不可使用`list[list[int]]`,`list[dict[str, int]]`等写法
+- 禁止使用`Union`, `tuple/Tuple`类型
+ - 但是`Optional`仍然允许使用
### 移除template的方案提案
配置项说明的废案
diff --git a/pytests/config_test/test_config_base.py b/pytests/config_test/test_config_base.py
index e210a82f..fea2fe8e 100644
--- a/pytests/config_test/test_config_base.py
+++ b/pytests/config_test/test_config_base.py
@@ -46,6 +46,10 @@ def patch_attrdoc_post_init():
config_base_module.logger = logging.getLogger("config_base_test_logger")
+class SimpleClass(ConfigBase):
+ a: int = 1
+ b: str = "test"
+
class TestConfigBase:
# ---------------------------------------------------------
@@ -267,6 +271,18 @@ class TestConfigBase:
"不允许嵌套泛型类型",
id="listset-validation-nested-generic-inner-list",
),
+ pytest.param(
+ List[SimpleClass],
+ False,
+ None,
+ id="listset-validation-list-configbase-element_allow",
+ ),
+ pytest.param(
+ Set[SimpleClass],
+ True,
+ "ConfigBase is not Hashable",
+ id="listset-validation-set-configbase-element_reject",
+ )
],
)
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
@@ -319,6 +335,12 @@ class TestConfigBase:
"必须指定键和值的类型参数",
id="dict-validation-missing-args",
),
+ pytest.param(
+ Dict[str, SimpleClass],
+ False,
+ None,
+ id="dict-validation-happy-configbase-value",
+ )
],
)
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
@@ -370,6 +392,22 @@ class TestConfigBase:
# Assert
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
+
+ def test_discourage_any_usage_suppressed_warning(self, caplog):
+ class Sample(ConfigBase):
+ _validate_any: bool = False
+ suppress_any_warning: bool = True
+
+ instance = Sample()
+
+ # Arrange
+ caplog.set_level(logging.WARNING, logger="config_base_test_logger")
+
+ # Act
+ instance._discourage_any_usage("field_z")
+
+ # Assert
+ assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text
# ---------------------------------------------------------
# model_post_init 规则覆盖(错误与边界情况)
diff --git a/src/common/remote.py b/src/common/remote.py
index 5380cd01..098c13d1 100644
--- a/src/common/remote.py
+++ b/src/common/remote.py
@@ -5,7 +5,7 @@ import platform
from src.common.logger import get_logger
from src.common.tcp_connector import get_tcp_connector
-from src.config.config import global_config
+from src.config.config import global_config, MMC_VERSION
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),
- "mmc_version": global_config.MMC_VERSION,
+ "mmc_version": MMC_VERSION,
}
match platform.system():
diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py
deleted file mode 100644
index f450e91f..00000000
--- a/src/config/api_ada_configs.py
+++ /dev/null
@@ -1,139 +0,0 @@
-from dataclasses import dataclass, field
-
-from .config_base import ConfigBase
-
-
-@dataclass
-class APIProvider(ConfigBase):
- """API提供商配置类"""
-
- name: str
- """API提供商名称"""
-
- base_url: str
- """API基础URL"""
-
- api_key: str = field(default_factory=str, repr=False)
- """API密钥列表"""
-
- client_type: str = field(default="openai")
- """客户端类型(如openai/google等,默认为openai)"""
-
- max_retry: int = 2
- """最大重试次数(单个模型API调用失败,最多重试的次数)"""
-
- timeout: int = 10
- """API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)"""
-
- retry_interval: int = 10
- """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
-
- def get_api_key(self) -> str:
- return self.api_key
-
- def __post_init__(self):
- """确保api_key在repr中不被显示"""
- if not self.api_key:
- raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
- if not self.base_url and self.client_type != "gemini":
- raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
- if not self.name:
- raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。")
-
-
-@dataclass
-class ModelInfo(ConfigBase):
- """单个模型信息配置类"""
-
- model_identifier: str
- """模型标识符(用于URL调用)"""
-
- name: str
- """模型名称(用于模块调用)"""
-
- api_provider: str
- """API提供商(如OpenAI、Azure等)"""
-
- price_in: float = field(default=0.0)
- """每M token输入价格"""
-
- price_out: float = field(default=0.0)
- """每M token输出价格"""
-
- temperature: float | None = field(default=None)
- """模型级别温度(可选),会覆盖任务配置中的温度"""
-
- max_tokens: int | None = field(default=None)
- """模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
-
- force_stream_mode: bool = field(default=False)
- """是否强制使用流式输出模式"""
-
- extra_params: dict = field(default_factory=dict)
- """额外参数(用于API调用时的额外配置)"""
-
- def __post_init__(self):
- if not self.model_identifier:
- raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
- if not self.name:
- raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
- if not self.api_provider:
- raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。")
-
-
-@dataclass
-class TaskConfig(ConfigBase):
- """任务配置类"""
-
- model_list: list[str] = field(default_factory=list)
- """任务使用的模型列表"""
-
- max_tokens: int = 1024
- """任务最大输出token数"""
-
- temperature: float = 0.3
- """模型温度"""
-
- slow_threshold: float = 15.0
- """慢请求阈值(秒),超过此值会输出警告日志"""
-
- selection_strategy: str = field(default="balance")
- """模型选择策略:balance(负载均衡)或 random(随机选择)"""
-
-
-@dataclass
-class ModelTaskConfig(ConfigBase):
- """模型配置类"""
-
- utils: TaskConfig
- """组件模型配置"""
-
- replyer: TaskConfig
- """normal_chat首要回复模型模型配置"""
-
- vlm: TaskConfig
- """视觉语言模型配置"""
-
- voice: TaskConfig
- """语音识别模型配置"""
-
- tool_use: TaskConfig
- """专注工具使用模型配置"""
-
- planner: TaskConfig
- """规划模型配置"""
-
- embedding: TaskConfig
- """嵌入模型配置"""
-
- lpmm_entity_extract: TaskConfig
- """LPMM实体提取模型配置"""
-
- lpmm_rdf_build: TaskConfig
- """LPMM RDF构建模型配置"""
-
- def get_task(self, task_name: str) -> TaskConfig:
- """获取指定任务的配置"""
- if hasattr(self, task_name):
- return getattr(self, task_name)
- raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
diff --git a/src/config/config.py b/src/config/config.py
index 0a3ba48c..a6685bc0 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -1,19 +1,12 @@
-import os
+from pathlib import Path
+from typing import TypeVar
+from datetime import datetime
+from typing import Any
+
import tomlkit
-import shutil
import sys
-from datetime import datetime
-from tomlkit import TOMLDocument
-from tomlkit.items import Table, KeyType
-from dataclasses import field, dataclass
-from rich.traceback import install
-from typing import List, Optional
-
-from src.common.logger import get_logger
-from src.common.toml_utils import format_toml_string
-from src.config.config_base import ConfigBase
-from src.config.official_configs import (
+from .official_configs import (
BotConfig,
PersonalityConfig,
ExpressionConfig,
@@ -34,345 +27,112 @@ from src.config.official_configs import (
MemoryConfig,
DebugConfig,
DreamConfig,
- WebUIConfig,
)
+from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
+from .config_base import ConfigBase, Field, AttributeData
+from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
-from .api_ada_configs import (
- ModelTaskConfig,
- ModelInfo,
- APIProvider,
-)
+from src.common.logger import get_logger
+"""
+如果你想要修改配置文件,请递增version的值
-install(extra_lines=3)
+版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
+ 主版本号:MMC版本更新
+ 次版本号:配置文件内容大更新
+ 修订号:配置文件内容小更新
+"""
+PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
+CONFIG_DIR: Path = PROJECT_ROOT / "config"
+BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
+MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
+MMC_VERSION: str = "0.13.0"
+CONFIG_VERSION: str = "8.0.0"
+MODEL_CONFIG_VERSION: str = "1.12.0"
-# 配置主程序日志格式
logger = get_logger("config")
-# 获取当前文件所在目录的父目录的父目录(即MaiBot项目根目录)
-PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
-CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
-TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
-
-# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
-# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
-MMC_VERSION = "0.13.0-snapshot.1"
+T = TypeVar("T", bound="ConfigBase")
-def get_key_comment(toml_table, key):
- # 获取key的注释(如果有)
- if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
- return toml_table.trivia.comment
- if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
- item = toml_table.value.get(key)
- if item is not None and hasattr(item, "trivia"):
- return item.trivia.comment
- if hasattr(toml_table, "keys"):
- for k in toml_table.keys():
- if isinstance(k, KeyType) and k.key == key: # type: ignore
- return k.trivia.comment # type: ignore
- return None
-
-
-def compare_dicts(new, old, path=None, logs=None):
- # 递归比较两个dict,找出新增和删减项,收集注释
- if path is None:
- path = []
- if logs is None:
- logs = []
- # 新增项
- for key in new:
- if key == "version":
- continue
- if key not in old:
- comment = get_key_comment(new, key)
- logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
- elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
- compare_dicts(new[key], old[key], path + [str(key)], logs)
- # 删减项
- for key in old:
- if key == "version":
- continue
- if key not in new:
- comment = get_key_comment(old, key)
- logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
- return logs
-
-
-def get_value_by_path(d, path):
- for k in path:
- if isinstance(d, dict) and k in d:
- d = d[k]
- else:
- return None
- return d
-
-
-def set_value_by_path(d, path, value):
- """设置嵌套字典中指定路径的值"""
- for k in path[:-1]:
- if k not in d or not isinstance(d[k], dict):
- d[k] = {}
- d = d[k]
-
- # 使用 tomlkit.item 来保持 TOML 格式
- try:
- d[path[-1]] = tomlkit.item(value)
- except (TypeError, ValueError):
- # 如果转换失败,直接赋值
- d[path[-1]] = value
-
-
-def compare_default_values(new, old, path=None, logs=None, changes=None):
- # 递归比较两个dict,找出默认值变化项
- if path is None:
- path = []
- if logs is None:
- logs = []
- if changes is None:
- changes = []
- for key in new:
- if key == "version":
- continue
- if key in old:
- if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
- compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
- elif new[key] != old[key]:
- logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
- changes.append((path + [str(key)], old[key], new[key]))
- return logs, changes
-
-
-def _get_version_from_toml(toml_path) -> Optional[str]:
- """从TOML文件中获取版本号"""
- if not os.path.exists(toml_path):
- return None
- with open(toml_path, "r", encoding="utf-8") as f:
- doc = tomlkit.load(f)
- if "inner" in doc and "version" in doc["inner"]: # type: ignore
- return doc["inner"]["version"] # type: ignore
- return None
-
-
-def _version_tuple(v):
- """将版本字符串转换为元组以便比较"""
- if v is None:
- return (0,)
- return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
-
-
-def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
- """
- 将source字典的值更新到target字典中(如果target中存在相同的键)
- """
- for key, value in source.items():
- # 跳过version字段的更新
- if key == "version":
- continue
- if key in target:
- target_value = target[key]
- if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
- _update_dict(target_value, value)
- else:
- try:
- # 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
- target[key] = tomlkit.item(value)
- except (TypeError, ValueError):
- # 如果转换失败,直接赋值
- target[key] = value
-
-
-def _update_config_generic(config_name: str, template_name: str):
- """
- 通用的配置文件更新函数
-
- Args:
- config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
- template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
- """
- # 获取根目录路径
- old_config_dir = os.path.join(CONFIG_DIR, "old")
- compare_dir = os.path.join(TEMPLATE_DIR, "compare")
-
- # 定义文件路径
- template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
- old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
- new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
- compare_path = os.path.join(compare_dir, f"{template_name}.toml")
-
- # 创建compare目录(如果不存在)
- os.makedirs(compare_dir, exist_ok=True)
-
- template_version = _get_version_from_toml(template_path)
- compare_version = _get_version_from_toml(compare_path)
-
- # 检查配置文件是否存在
- if not os.path.exists(old_config_path):
- logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
- os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
- shutil.copy2(template_path, old_config_path) # 复制模板文件
- logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
- # 新创建配置文件,退出
- sys.exit(0)
-
- compare_config = None
- new_config = None
- old_config = None
-
- # 先读取 compare 下的模板(如果有),用于默认值变动检测
- if os.path.exists(compare_path):
- with open(compare_path, "r", encoding="utf-8") as f:
- compare_config = tomlkit.load(f)
-
- # 读取当前模板
- with open(template_path, "r", encoding="utf-8") as f:
- new_config = tomlkit.load(f)
-
- # 检查默认值变化并处理(只有 compare_config 存在时才做)
- if compare_config:
- # 读取旧配置
- with open(old_config_path, "r", encoding="utf-8") as f:
- old_config = tomlkit.load(f)
- logs, changes = compare_default_values(new_config, compare_config)
- if logs:
- logger.info(f"检测到{config_name}模板默认值变动如下:")
- for log in logs:
- logger.info(log)
- # 检查旧配置是否等于旧默认值,如果是则更新为新默认值
- config_updated = False
- for path, old_default, new_default in changes:
- old_value = get_value_by_path(old_config, path)
- if old_value == old_default:
- set_value_by_path(old_config, path, new_default)
- logger.info(
- f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
- )
- config_updated = True
-
- # 如果配置有更新,立即保存到文件
- if config_updated:
- with open(old_config_path, "w", encoding="utf-8") as f:
- f.write(format_toml_string(old_config))
- logger.info(f"已保存更新后的{config_name}配置文件")
- else:
- logger.info(f"未检测到{config_name}模板默认值变动")
-
- # 检查 compare 下没有模板,或新模板版本更高,则复制
- if not os.path.exists(compare_path):
- shutil.copy2(template_path, compare_path)
- logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
- elif _version_tuple(template_version) > _version_tuple(compare_version):
- shutil.copy2(template_path, compare_path)
- logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}")
- else:
- logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
-
- # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
- if old_config is None:
- with open(old_config_path, "r", encoding="utf-8") as f:
- old_config = tomlkit.load(f)
- # new_config 已经读取
-
- # 检查version是否相同
- if old_config and "inner" in old_config and "inner" in new_config:
- old_version = old_config["inner"].get("version") # type: ignore
- new_version = new_config["inner"].get("version") # type: ignore
- if old_version and new_version and old_version == new_version:
- logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
- return
- else:
- logger.info(
- f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
- )
- else:
- logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
-
- # 创建old目录(如果不存在)
- os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
-
- # 移动旧配置文件到old目录
- shutil.move(old_config_path, old_backup_path)
- logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
-
- # 复制模板文件到配置目录
- shutil.copy2(template_path, new_config_path)
- logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
-
- # 输出新增和删减项及注释
- if old_config:
- logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
- if logs := compare_dicts(new_config, old_config):
- for log in logs:
- logger.info(log)
- else:
- logger.info("无新增或删减项")
-
- # 将旧配置的值更新到新配置中
- logger.info(f"开始合并{config_name}新旧配置...")
- _update_dict(new_config, old_config)
-
- # 保存更新后的配置(保留注释和格式,数组多行格式化)
- with open(new_config_path, "w", encoding="utf-8") as f:
- f.write(format_toml_string(new_config))
- logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
-
-
-def update_config():
- """更新bot_config.toml配置文件"""
- _update_config_generic("bot_config", "bot_config_template")
-
-
-def update_model_config():
- """更新model_config.toml配置文件"""
- _update_config_generic("model_config", "model_config_template")
-
-
-@dataclass
class Config(ConfigBase):
"""总配置类"""
- MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
+ bot: BotConfig = Field(default_factory=BotConfig)
+ """机器人配置类"""
- bot: BotConfig
- personality: PersonalityConfig
- relationship: RelationshipConfig
- chat: ChatConfig
- message_receive: MessageReceiveConfig
- emoji: EmojiConfig
- expression: ExpressionConfig
- keyword_reaction: KeywordReactionConfig
- chinese_typo: ChineseTypoConfig
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- webui: WebUIConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- voice: VoiceConfig
- dream: DreamConfig
+ personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
+ """人格配置类"""
+
+ expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
+ """表达配置类"""
+
+ chat: ChatConfig = Field(default_factory=ChatConfig)
+ """聊天配置类"""
+
+ memory: MemoryConfig = Field(default_factory=MemoryConfig)
+ """记忆配置类"""
+
+ relationship: RelationshipConfig = Field(default_factory=RelationshipConfig)
+ """关系配置类"""
+
+ message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
+ """消息接收配置类"""
+
+ dream: DreamConfig = Field(default_factory=DreamConfig)
+ """做梦配置类"""
+
+ tool: ToolConfig = Field(default_factory=ToolConfig)
+ """工具配置类"""
+
+ voice: VoiceConfig = Field(default_factory=VoiceConfig)
+ """语音配置类"""
+
+ emoji: EmojiConfig = Field(default_factory=EmojiConfig)
+ """表情包配置类"""
+
+ keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig)
+ """关键词反应配置类"""
+
+ response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig)
+ """回复后处理配置类"""
+
+ chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig)
+ """中文错别字生成器配置类"""
+
+ response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig)
+ """回复分割器配置类"""
+
+ telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
+ """遥测配置类"""
+
+ debug: DebugConfig = Field(default_factory=DebugConfig)
+ """调试配置类"""
+
+ experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
+ """实验性功能配置类"""
+
+ maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
+ """maim_message配置类"""
+
+ lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
+ """LPMM知识库配置类"""
-@dataclass
-class APIAdapterConfig(ConfigBase):
- """API Adapter配置类"""
+class ModelConfig(ConfigBase):
+ """模型配置类"""
- models: List[ModelInfo]
- """模型列表"""
+ models: list[ModelInfo] = Field(default_factory=list)
+ """模型配置列表"""
- model_task_config: ModelTaskConfig
+ model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig)
"""模型任务配置"""
- api_providers: List[APIProvider] = field(default_factory=list)
+ api_providers: list[APIProvider] = Field(default_factory=list)
"""API提供商列表"""
- def __post_init__(self):
+ def model_post_init(self, context: Any = None):
if not self.models:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
if not self.api_providers:
@@ -388,78 +148,134 @@ class APIAdapterConfig(ConfigBase):
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
- self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
- self.models_dict = {model.name: model for model in self.models}
+ api_providers_dict = {provider.name: provider for provider in self.api_providers}
for model in self.models:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
- if not model.api_provider or model.api_provider not in self.api_providers_dict:
+ if not model.api_provider or model.api_provider not in api_providers_dict:
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
-
- def get_model_info(self, model_name: str) -> ModelInfo:
- """根据模型名称获取模型信息"""
- if not model_name:
- raise ValueError("模型名称不能为空")
- if model_name not in self.models_dict:
- raise KeyError(f"模型 '{model_name}' 不存在")
- return self.models_dict[model_name]
-
- def get_provider(self, provider_name: str) -> APIProvider:
- """根据提供商名称获取API提供商信息"""
- if not provider_name:
- raise ValueError("API提供商名称不能为空")
- if provider_name not in self.api_providers_dict:
- raise KeyError(f"API提供商 '{provider_name}' 不存在")
- return self.api_providers_dict[provider_name]
+ return super().model_post_init(context)
-def load_config(config_path: str) -> Config:
+class ConfigManager:
+ """总配置管理类"""
+
+ def __init__(self):
+ self.bot_config_path: Path = BOT_CONFIG_PATH
+ self.model_config_path: Path = MODEL_CONFIG_PATH
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
+
+ def initialize(self):
+ logger.info(f"MaiCore当前版本: {MMC_VERSION}")
+ logger.info("正在品鉴配置文件...")
+ self.global_config: Config = self._load_global_config()
+ self.model_config: ModelConfig = self._load_model_config()
+ logger.info("非常的新鲜,非常的美味!")
+
+ def _load_global_config(self) -> Config:
+ config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
+ if updated:
+ sys.exit(0) # 先直接退出
+ return config
+
+ def _load_model_config(self) -> ModelConfig:
+ config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
+ if updated:
+ sys.exit(0) # 先直接退出
+ return config
+
+ def get_global_config(self) -> Config:
+ return self.global_config
+
+ def get_model_config(self) -> ModelConfig:
+ return self.model_config
+
+
+def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
+ """生成新的配置文件
+
+ :param config_class: 配置类
+ :param config_path: 配置文件路径
+ :param inner_config_version: 配置文件版本号
"""
- 加载配置文件
- Args:
- config_path: 配置文件路径
- Returns:
- Config对象
- """
- # 读取配置文件
+ config = config_class()
+ write_config_to_file(config, config_path, inner_config_version)
+
+
+def load_config_from_file(
+ config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
+) -> tuple[T, bool]:
+ attribute_data = AttributeData()
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
-
- # 创建Config对象
+ old_ver: str = config_data["inner"]["version"] # type: ignore
+ config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
+ config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
try:
- return Config.from_dict(config_data)
+ updated: bool = False
+ target_config = config_class.from_dict(attribute_data, config_data)
+ if compare_versions(old_ver, new_ver):
+ output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
+ write_config_to_file(target_config, config_path, new_ver, override_repr)
+ updated = True
+ return target_config, updated
except Exception as e:
- logger.critical("配置文件解析失败")
+ logger.critical(f"配置文件{config_path.name}解析失败")
raise e
-def api_ada_load_config(config_path: str) -> APIAdapterConfig:
+def write_config_to_file(
+ config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False
+) -> None:
+ """将配置写入文件
+
+ :param config: 配置对象
+ :param config_path: 配置文件路径
"""
- 加载API适配器配置文件
- Args:
- config_path: 配置文件路径
- Returns:
- APIAdapterConfig对象
- """
- # 读取配置文件
- with open(config_path, "r", encoding="utf-8") as f:
- config_data = tomlkit.load(f)
+ # 创建空TOMLDocument
+ full_config_data = tomlkit.document()
- # 创建APIAdapterConfig对象
- try:
- return APIAdapterConfig.from_dict(config_data)
- except Exception as e:
- logger.critical("API适配器配置文件解析失败")
- raise e
+ # 首先写入配置文件版本信息
+ version_table = tomlkit.table()
+ version_table.add("version", inner_config_version)
+ full_config_data.add("inner", version_table)
+
+ # 递归解析配置项为表格
+ for config_item_name, config_item in type(config).model_fields.items():
+ if not config_item.repr and not override_repr:
+ continue
+ if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
+ continue
+ config_field = getattr(config, config_item_name)
+ if isinstance(config_field, ConfigBase):
+ full_config_data.add(
+ config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr)
+ )
+ elif isinstance(config_field, list):
+ aot = tomlkit.aot()
+ for item in config_field:
+ if not isinstance(item, ConfigBase):
+ raise TypeError("配置写入只支持ConfigBase子类")
+ aot.append(recursive_parse_item_to_table(item, override_repr=override_repr))
+ full_config_data.add(config_item_name, aot)
+ else:
+ raise TypeError("配置写入只支持ConfigBase子类")
+
+ # 备份旧文件
+ if config_path.exists():
+ backup_root = config_path.parent / "old"
+ backup_root.mkdir(parents=True, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml"
+ config_path.replace(backup_path)
+
+ # 写入文件
+ with open(config_path, "w", encoding="utf-8") as f:
+ tomlkit.dump(full_config_data, f)
-# 获取配置文件路径
-logger.info(f"MaiCore当前版本: {MMC_VERSION}")
-update_config()
-update_model_config()
-
-logger.info("正在品鉴配置文件...")
-global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
-model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
-logger.info("非常的新鲜,非常的美味!")
+# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
+config_manager = ConfigManager()
+config_manager.initialize()
+# global_config = config_manager.get_global_config()
diff --git a/src/config/config_base.py b/src/config/config_base.py
index 02332ad6..7baf1bd6 100644
--- a/src/config/config_base.py
+++ b/src/config/config_base.py
@@ -2,27 +2,35 @@ import ast
import inspect
import types
+from dataclasses import dataclass, field
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
-from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set
+from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal
-__all__ = ["ConfigBase", "Field"]
+__all__ = ["ConfigBase", "Field", "AttributeData"]
from src.common.logger import get_logger
logger = get_logger("ConfigBase")
+@dataclass
+class AttributeData:
+ missing_attributes: list[str] = field(default_factory=list)
+ """缺失的属性列表"""
+ redundant_attributes: list[str] = field(default_factory=list)
+ """多余的属性列表"""
+
+
class AttrDocBase:
"""解析字段说明的基类"""
field_docs: dict[str, str] = {}
- def __post_init__(self):
- self.field_docs = self._get_field_docs() # 全局仅获取一次并保留
+ def __post_init__(self, allow_extra_methods: bool = False):
+ self.field_docs = self._get_field_docs(allow_extra_methods) # 全局仅获取一次并保留
- @classmethod
- def _get_field_docs(cls) -> dict[str, str]:
+ def _get_field_docs(self, allow_extra_methods: bool) -> dict[str, str]:
"""
获取字段的说明字符串
@@ -30,11 +38,11 @@ class AttrDocBase:
:return: 字段说明字典,键为字段名,值为说明字符串
"""
# 获取类的源代码文本
- class_source = cls._get_class_source()
+ class_source = self._get_class_source()
# 解析源代码,找到对应的类定义节点
- class_node = cls._find_class_node(class_source)
+ class_node = self._find_class_node(class_source)
# 从类定义节点中提取字段文档
- return cls._extract_field_docs(class_node)
+ return self._extract_field_docs(class_node, allow_extra_methods)
@classmethod
def _get_class_source(cls) -> str:
@@ -57,21 +65,22 @@ class AttrDocBase:
# 如果没有找到匹配的类定义,抛出异常
raise AttributeError(f"Class {cls.__name__} not found in source.")
- @classmethod
- def _extract_field_docs(cls, class_node: ast.ClassDef) -> dict[str, str]:
+ def _extract_field_docs(self, class_node: ast.ClassDef, allow_extra_methods: bool) -> dict[str, str]:
"""从类的 AST 节点中提取字段的文档字符串"""
+ # sourcery skip: merge-nested-ifs
doc_dict: dict[str, str] = {}
class_body = class_node.body # 类属性节点列表
for i in range(len(class_body)):
body_item = class_body[i]
- # 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
- # 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
- if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
- """检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义"""
- raise AttributeError(
- f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
- ) from None
+ if not allow_extra_methods:
+ # 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
+ # 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
+ if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
+ """检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义"""
+ raise AttributeError(
+ f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
+ ) from None
# 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义)
# 并且下一个语句存在
@@ -110,13 +119,53 @@ class AttrDocBase:
class ConfigBase(BaseModel, AttrDocBase):
model_config = ConfigDict(validate_assignment=True, extra="forbid")
_validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True
+ suppress_any_warning: bool = False # 是否抑制 Any 类型使用的警告,默认为 False,仅仅在_validate_any 为 False 时生效
+
+ @classmethod
+ def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
+ """从字典创建配置对象,并收集缺失和多余的属性信息"""
+ class_fields = set(cls.model_fields.keys())
+ class_fields.remove("field_docs") # 忽略 field_docs 字段
+ if "_validate_any" in class_fields:
+ class_fields.remove("_validate_any") # 忽略 _validate_any 字段
+ if "suppress_any_warning" in class_fields:
+            class_fields.remove("suppress_any_warning")  # 忽略 suppress_any_warning 字段
+ for class_field in class_fields:
+ if class_field not in data:
+ attribute_data.missing_attributes.append(class_field) # 记录缺失的属性
+ cleaned_data_list: list[str] = []
+ for data_field in data:
+ if data_field not in class_fields:
+ cleaned_data_list.append(data_field)
+ attribute_data.redundant_attributes.append(data_field) # 记录多余的属性
+ for redundant_field in cleaned_data_list:
+ data.pop(redundant_field) # 移除多余的属性
+ # 对于是ConfigBase子类的字段,递归调用from_dict
+ class_field_infos = dict(cls.model_fields.items())
+ for field_data in data:
+ if info := class_field_infos.get(field_data):
+ field_type = info.annotation
+ if inspect.isclass(field_type) and issubclass(field_type, ConfigBase):
+ data[field_data] = field_type.from_dict(attribute_data, data[field_data])
+ if get_origin(field_type) in {list, List}:
+ elem_type = get_args(field_type)[0]
+ if inspect.isclass(elem_type) and issubclass(elem_type, ConfigBase):
+ data[field_data] = [elem_type.from_dict(attribute_data, item) for item in data[field_data]]
+ # 没有set,因为ConfigBase is not Hashable
+ if get_origin(field_type) in {dict, Dict}:
+ val_type = get_args(field_type)[1]
+ if inspect.isclass(val_type) and issubclass(val_type, ConfigBase):
+ data[field_data] = {
+ key: val_type.from_dict(attribute_data, val) for key, val in data[field_data].items()
+ }
+ return cls(**data)
def _discourage_any_usage(self, field_name: str) -> None:
"""警告使用 Any 类型的字段(可被suppress)"""
if self._validate_any:
raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解")
- else:
- logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议避免使用。")
+ if not self.suppress_any_warning:
+ logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议使用更具体的类型注解以提高类型安全性")
def _get_real_type(self, annotation: type[Any] | Any | None):
"""获取真实类型,处理 dict 等没有参数的情况"""
@@ -157,10 +206,14 @@ class ConfigBase(BaseModel, AttrDocBase):
elem = args[0]
if elem is Any:
self._discourage_any_usage(field_name)
- if get_origin(elem) is not None:
+ elif get_origin(elem) is not None:
raise TypeError(
f"类'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
)
+ elif inspect.isclass(elem) and issubclass(elem, ConfigBase) and origin in (set, Set):
+ raise TypeError(
+ f"类'{type(self).__name__}'字段'{field_name}'中不允许使用 ConfigBase 子类作为 set 元素类型: {annotation}。ConfigBase is not Hashable。"
+ )
def _validate_dict_type(self, annotation: Any | None, field_name: str):
"""验证 dict 类型的使用"""
@@ -215,7 +268,7 @@ class ConfigBase(BaseModel, AttrDocBase):
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
continue
# 只允许 list, set, dict 三类泛型
- if origin_type not in (list, set, dict, List, Set, Dict):
+ if origin_type not in (list, set, dict, List, Set, Dict, Literal):
raise TypeError(
f"仅允许使用list, set, dict三种泛型类型注解,类'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}"
)
diff --git a/src/config/config_utils.py b/src/config/config_utils.py
new file mode 100644
index 00000000..c18bccb1
--- /dev/null
+++ b/src/config/config_utils.py
@@ -0,0 +1,127 @@
+from pydantic.fields import FieldInfo
+from typing import Any, get_args, get_origin, TYPE_CHECKING, Literal, List, Set, Tuple, Dict
+from tomlkit import items
+import tomlkit
+
+from .config_base import ConfigBase
+
+if TYPE_CHECKING:
+ from .config import AttributeData
+
+
+def recursive_parse_item_to_table(
+ config: ConfigBase, is_inline_table: bool = False, override_repr: bool = False
+) -> items.Table | items.InlineTable:
+ # sourcery skip: merge-else-if-into-elif, reintroduce-else
+ """递归解析配置项为表格"""
+ config_table = tomlkit.table()
+ if is_inline_table:
+ config_table = tomlkit.inline_table()
+ for config_item_name, config_item_info in type(config).model_fields.items():
+ if not config_item_info.repr and not override_repr:
+ continue
+ value = getattr(config, config_item_name)
+ if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
+ continue
+ if value is None:
+ continue
+ if isinstance(value, ConfigBase):
+ config_table.add(config_item_name, recursive_parse_item_to_table(value, override_repr=override_repr))
+ else:
+ config_table.add(
+ config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr)
+ )
+ if not is_inline_table:
+ config_table = comment_doc_string(config, config_item_name, config_table)
+ return config_table
+
+
+def comment_doc_string(
+ config: ConfigBase, field_name: str, toml_table: items.Table | items.InlineTable
+) -> items.Table | items.InlineTable:
+ """将配置类中的注释加入toml表格中"""
+ if doc_string := config.field_docs.get(field_name, ""):
+ doc_string_splitted = doc_string.splitlines()
+ if len(doc_string_splitted) == 1 and not doc_string_splitted[0].strip().startswith("_wrap_"):
+ if isinstance(toml_table[field_name], bool):
+ # tomlkit 故意设计的行为,布尔值不能直接添加注释
+ value = toml_table[field_name]
+ item = tomlkit.item(value)
+ item.comment(doc_string_splitted[0])
+ toml_table[field_name] = item
+ else:
+ toml_table[field_name].comment(doc_string_splitted[0])
+ else:
+ if doc_string_splitted[0].strip().startswith("_wrap_"):
+ doc_string_splitted[0] = doc_string_splitted[0].replace("_wrap_", "", 1).strip()
+ for line in doc_string_splitted:
+ toml_table.add(tomlkit.comment(line))
+ toml_table.add(tomlkit.nl())
+ return toml_table
+
+
+def convert_field(config_item_name: str, config_item_info: FieldInfo, value: Any, override_repr: bool = False):
+ # sourcery skip: extract-method
+ """将非可直接表达类转换为toml可表达类"""
+ field_type_origin = get_origin(config_item_info.annotation)
+ field_type_args = get_args(config_item_info.annotation)
+
+ if not field_type_origin: # 基础类型int,bool等直接添加
+ return value
+ elif field_type_origin in {list, set, List, Set}:
+ toml_list = tomlkit.array()
+ if field_type_args and isinstance(field_type_args[0], type) and issubclass(field_type_args[0], ConfigBase):
+ for item in value:
+ toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
+ else:
+ for item in value:
+ toml_list.append(item)
+ return toml_list
+ elif field_type_origin in (tuple, Tuple):
+ toml_list = tomlkit.array()
+ for field_arg, item in zip(field_type_args, value, strict=True):
+ if isinstance(field_arg, type) and issubclass(field_arg, ConfigBase):
+ toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
+ else:
+ toml_list.append(item)
+ return toml_list
+ elif field_type_origin in (dict, Dict):
+ if len(field_type_args) != 2:
+ raise TypeError(f"Expected a dictionary with two type arguments for {config_item_name}")
+ toml_sub_table = tomlkit.inline_table()
+ key_type, value_type = field_type_args
+ if key_type is not str:
+ raise TypeError(f"TOML only supports string keys for tables, got {key_type} for {config_item_name}")
+ for k, v in value.items():
+ if isinstance(value_type, type) and issubclass(value_type, ConfigBase):
+ toml_sub_table.add(k, recursive_parse_item_to_table(v, True, override_repr))
+ else:
+ toml_sub_table.add(k, v)
+ return toml_sub_table
+ elif field_type_origin is Literal:
+ if value not in field_type_args:
+ raise ValueError(f"Value {value} not in Literal options {field_type_args} for {config_item_name}")
+ return value
+ else:
+ raise TypeError(f"Unsupported field type for {config_item_name}: {config_item_info.annotation}")
+
+
+def output_config_changes(attr_data: "AttributeData", logger, old_ver: str, new_ver: str, file_name: str):
+ """输出配置变更信息"""
+ logger.info("-------- 配置文件变更信息 --------")
+ logger.info(f"新增配置数量: {len(attr_data.missing_attributes)}")
+ for attr in attr_data.missing_attributes:
+ logger.info(f"配置文件中新增配置项: {attr}")
+ logger.info(f"移除配置数量: {len(attr_data.redundant_attributes)}")
+ for attr in attr_data.redundant_attributes:
+ logger.warning(f"移除配置项: {attr}")
+ logger.info(
+ f"{file_name}配置文件已经更新. Old: {old_ver} -> New: {new_ver} 建议检查新配置文件中的内容, 以免丢失重要信息"
+ )
+
+
+def compare_versions(old_ver: str, new_ver: str) -> bool:
+ """比较版本号,返回是否有更新"""
+ old_parts = [int(part) for part in old_ver.split(".")]
+ new_parts = [int(part) for part in new_ver.split(".")]
+ return new_parts > old_parts
diff --git a/src/config/model_configs.py b/src/config/model_configs.py
new file mode 100644
index 00000000..374aef59
--- /dev/null
+++ b/src/config/model_configs.py
@@ -0,0 +1,129 @@
+from typing import Any
+from .config_base import ConfigBase, Field
+
+
+class APIProvider(ConfigBase):
+ """API提供商配置类"""
+
+ name: str = ""
+ """API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
+
+ base_url: str = ""
+ """API服务商的BaseURL"""
+
+ api_key: str = Field(default_factory=str, repr=False)
+ """API密钥"""
+
+ client_type: str = Field(default="openai")
+ """客户端类型 (可选: openai/google, 默认为openai)"""
+
+ max_retry: int = Field(default=2)
+ """最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
+
+ timeout: int = 10
+ """API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
+
+ retry_interval: int = 10
+ """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
+
+ def model_post_init(self, context: Any = None):
+ """确保api_key在repr中不被显示"""
+ if not self.api_key:
+ raise ValueError("API密钥不能为空, 请在配置中设置有效的API密钥。")
+ if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url
+ raise ValueError("API基础URL不能为空, 请在配置中设置有效的基础URL。")
+ if not self.name:
+ raise ValueError("API提供商名称不能为空, 请在配置中设置有效的名称。")
+ return super().model_post_init(context)
+
+
+class ModelInfo(ConfigBase):
+ """单个模型信息配置类"""
+ _validate_any: bool = False
+ suppress_any_warning: bool = True
+
+ model_identifier: str = ""
+ """模型标识符 (API服务商提供的模型标识符)"""
+
+ name: str = ""
+ """模型名称 (可随意命名, 在models中需使用这个命名)"""
+
+ api_provider: str = ""
+ """API服务商名称 (对应在api_providers中配置的服务商名称)"""
+
+ price_in: float = Field(default=0.0)
+ """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
+
+ price_out: float = Field(default=0.0)
+ """输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
+
+ temperature: float | None = Field(default=None)
+ """模型级别温度(可选),会覆盖任务配置中的温度"""
+
+ max_tokens: int | None = Field(default=None)
+ """模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
+
+ force_stream_mode: bool = Field(default=False)
+ """强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
+
+ extra_params: dict[str, Any] = Field(default_factory=dict)
+ """额外参数 (用于API调用时的额外配置)"""
+
+ def model_post_init(self, context: Any = None):
+ if not self.model_identifier:
+ raise ValueError("模型标识符不能为空, 请在配置中设置有效的模型标识符。")
+ if not self.name:
+ raise ValueError("模型名称不能为空, 请在配置中设置有效的模型名称。")
+ if not self.api_provider:
+ raise ValueError("API提供商不能为空, 请在配置中设置有效的API提供商。")
+ return super().model_post_init(context)
+
+
+class TaskConfig(ConfigBase):
+ """任务配置类"""
+
+ model_list: list[str] = Field(default_factory=list)
+ """使用的模型列表, 每个元素对应上面的模型名称(name)"""
+
+ max_tokens: int = 1024
+ """任务最大输出token数"""
+
+ temperature: float = 0.3
+ """模型温度"""
+
+ slow_threshold: float = 15.0
+ """慢请求阈值(秒),超过此值会输出警告日志"""
+
+ selection_strategy: str = Field(default="balance")
+ """模型选择策略:balance(负载均衡)或 random(随机选择)"""
+
+
+class ModelTaskConfig(ConfigBase):
+ """模型配置类"""
+
+ utils: TaskConfig = Field(default_factory=TaskConfig)
+ """组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
+
+ replyer: TaskConfig = Field(default_factory=TaskConfig)
+ """首要回复模型配置, 还用于表达器和表达方式学习"""
+
+ vlm: TaskConfig = Field(default_factory=TaskConfig)
+ """视觉模型配置"""
+
+ voice: TaskConfig = Field(default_factory=TaskConfig)
+ """语音识别模型配置"""
+
+ tool_use: TaskConfig = Field(default_factory=TaskConfig)
+ """工具使用模型配置, 需要使用支持工具调用的模型"""
+
+ planner: TaskConfig = Field(default_factory=TaskConfig)
+ """规划模型配置"""
+
+ embedding: TaskConfig = Field(default_factory=TaskConfig)
+ """嵌入模型配置"""
+
+ lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig)
+ """LPMM实体提取模型配置"""
+
+ lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig)
+ """LPMM RDF构建模型配置"""
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index a6652e0e..7291d188 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1,117 +1,106 @@
+from .config_base import ConfigBase, Field
import re
-
-from dataclasses import dataclass, field
-from typing import Literal, Optional
-import time
-
-from src.config.config_base import ConfigBase
+from typing import Optional, Literal
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
-3. 所有新增的class都应在config.py中的Config类中添加字段
-4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
+3. 所有新增的class都应在official_configs.py中的Config类中添加字段
+4. 对于新增的字段,若为可选项,则应在其后添加Field()并设置default_factory或default
+5. 所有的配置项都应该按照如下方法添加字段说明:
+class ExampleConfig(ConfigBase):
+ example_field: str
+ \"""This is an example field\"""
+ - 注释前面增加_wrap_标记可以实现配置文件中注释在配置项前面单独一行显示
"""
-@dataclass
class BotConfig(ConfigBase):
- """QQ机器人配置类"""
+ """机器人配置类"""
- platform: str
- """平台"""
-
- qq_account: str
- """QQ账号"""
-
- nickname: str
- """昵称"""
-
- platforms: list[str] = field(default_factory=lambda: [])
- """其他平台列表"""
-
- alias_names: list[str] = field(default_factory=lambda: [])
+ alias_names: list[str] = Field(default_factory=lambda: [])
"""别名列表"""
-@dataclass
class PersonalityConfig(ConfigBase):
"""人格配置类"""
- personality: str
- """人格"""
+ personality: str = "是一个大二在读女大学生,现在正在上网和群友聊天,有时有点攻击性,有时比较温柔"
+ """人格,建议100字以内,描述人格特质和身份特征"""
- reply_style: str = ""
- """默认表达风格"""
+ reply_style: str = "请不要刻意突出自身学科背景。可以参考贴吧,知乎和微博的回复风格。"
+ """默认表达风格,描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行"""
- multiple_reply_style: list[str] = field(default_factory=lambda: [])
+ multiple_reply_style: list[str] = Field(default_factory=lambda: [])
"""可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style"""
- multiple_probability: float = 0.0
+ multiple_probability: float = 0.3
"""每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)"""
- plan_style: str = ""
- """说话规则,行为风格"""
+ plan_style: str = """1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
+2.如果相同的action已经被执行,请不要重复执行该action
+3.如果有人对你感到厌烦,请减少回复
+4.如果有人在追问你,或者话题没有说完,请你继续回复
+5.请分析哪些对话是和你说的,哪些是其他人之间的互动,不要误认为其他人之间的互动是和你说的"""
+ """_wrap_麦麦的说话规则和行为规则"""
- visual_style: str = ""
- """图片提示词"""
+ visual_style: str = "请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"
+ """_wrap_识图提示词,不建议修改"""
- states: list[str] = field(default_factory=lambda: [])
- """状态列表,用于随机替换personality"""
+ states: list[str] = Field(
+ default_factory=lambda: [
+ "是一个女大学生,喜欢上网聊天,会刷小红书。",
+ "是一个大二心理学生,会刷贴吧和中国知网。",
+ "是一个赛博网友,最近很想吐槽人。",
+ ]
+ )
+ """_wrap_状态列表,用于随机替换personality"""
- state_probability: float = 0.0
+ state_probability: float = 0.3
"""状态概率,每次构建人格时替换personality的概率"""
-@dataclass
class RelationshipConfig(ConfigBase):
"""关系配置类"""
enable_relationship: bool = True
- """是否启用关系系统"""
+ """是否启用关系系统,关系系统被移除,此部分配置暂时无效"""
+
+
+class TalkRulesItem(ConfigBase):
+ platform: str = ""
+ """平台,留空表示全局"""
+
+ id: str = ""
+ """用户ID"""
+
+ rule_type: Literal["group", "private"] = "group"
+ """聊天流类型,group(群聊)或private(私聊)"""
+
+ time: str = ""
+ """时间段,格式为 "HH:MM-HH:MM",支持跨夜区间"""
+
+ value: float = 0.5
+ """聊天频率值,范围0-1"""
-@dataclass
class ChatConfig(ConfigBase):
"""聊天配置类"""
- max_context_size: int = 18
- """上下文长度"""
+ talk_value: float = 1
+ """聊天频率,越小越沉默,范围0-1"""
mentioned_bot_reply: bool = True
"""是否启用提及必回复"""
- at_bot_inevitable_reply: float = 1
- """@bot 必然回复,1为100%回复,0为不额外增幅"""
+ max_context_size: int = 30
+ """上下文长度"""
planner_smooth: float = 3
- """规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-5,0为关闭,必须大于等于0"""
+ """规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0"""
- talk_value: float = 1
- """思考频率"""
-
- enable_talk_value_rules: bool = True
- """是否启用动态发言频率规则"""
-
- talk_value_rules: list[dict] = field(default_factory=lambda: [])
- """
- 思考频率规则列表,支持按聊天流/按日内时段配置。
- 规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
-
- 示例:
- [
- ["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
- ["", "09:00-22:59", 1.0], # 全局规则:白天正常
- ["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
- ["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
- ]
-
- 匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
- 时间区间支持跨夜,例如 "23:00-02:00"。
- """
-
- think_mode: Literal["classic", "deep", "dynamic"] = "classic"
+ think_mode: Literal["classic", "deep", "dynamic"] = "dynamic"
"""
思考模式配置
- classic: 默认think_level为0(轻量回复,不需要思考和回忆)
@@ -125,162 +114,63 @@ class ChatConfig(ConfigBase):
llm_quote: bool = False
"""是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息"""
- def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
- """与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
- try:
- parts = stream_config_str.split(":")
- if len(parts) != 3:
- return None
+ enable_talk_value_rules: bool = True
+ """是否启用动态发言频率规则"""
- platform = parts[0]
- id_str = parts[1]
- stream_type = parts[2]
-
- is_group = stream_type == "group"
-
- from src.chat.message_receive.chat_stream import get_chat_manager
-
- return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
-
- except (ValueError, IndexError):
- return None
-
- def _now_minutes(self) -> int:
- """返回本地时间的分钟数(0-1439)。"""
- lt = time.localtime()
- return lt.tm_hour * 60 + lt.tm_min
-
- def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
- """解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
- try:
- start_str, end_str = [s.strip() for s in range_str.split("-")]
- sh, sm = [int(x) for x in start_str.split(":")]
- eh, em = [int(x) for x in end_str.split(":")]
- return sh * 60 + sm, eh * 60 + em
- except Exception:
- return None
-
- def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
- """
- 判断 now_min 是否在 [start_min, end_min] 区间内。
- 支持跨夜:如果 start > end,则表示跨越午夜。
- """
- if start_min <= end_min:
- return start_min <= now_min <= end_min
- # 跨夜:例如 23:00-02:00
- return now_min >= start_min or now_min <= end_min
-
- def get_talk_value(self, chat_id: Optional[str]) -> float:
- """根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
- if not self.enable_talk_value_rules or not self.talk_value_rules:
- result = self.talk_value
- # 防止返回0值,自动转换为0.0001
- if result == 0:
- return 0.0000001
- return result
-
- now_min = self._now_minutes()
-
- # 1) 先尝试匹配指定 chat 的规则
- if chat_id:
- for rule in self.talk_value_rules:
- if not isinstance(rule, dict):
- continue
- target = rule.get("target", "")
- time_range = rule.get("time", "")
- value = rule.get("value", None)
- if not isinstance(time_range, str):
- continue
- # 跳过全局
- if target == "":
- continue
- config_chat_id = self._parse_stream_config_to_chat_id(str(target))
- if config_chat_id is None or config_chat_id != chat_id:
- continue
- parsed = self._parse_range(time_range)
- if not parsed:
- continue
- start_min, end_min = parsed
- if self._in_range(now_min, start_min, end_min):
- try:
- result = float(value)
- # 防止返回0值,自动转换为0.0001
- if result == 0:
- return 0.0000001
- return result
- except Exception:
- continue
-
- # 2) 再匹配全局规则("")
- for rule in self.talk_value_rules:
- if not isinstance(rule, dict):
- continue
- target = rule.get("target", None)
- time_range = rule.get("time", "")
- value = rule.get("value", None)
- if target != "" or not isinstance(time_range, str):
- continue
- parsed = self._parse_range(time_range)
- if not parsed:
- continue
- start_min, end_min = parsed
- if self._in_range(now_min, start_min, end_min):
- try:
- result = float(value)
- # 防止返回0值,自动转换为0.0001
- if result == 0:
- return 0.0000001
- return result
- except Exception:
- continue
-
- # 3) 未命中规则返回基础值
- result = self.talk_value
- # 防止返回0值,自动转换为0.0001
- if result == 0:
- return 0.0000001
- return result
+ talk_value_rules: list[TalkRulesItem] = Field(
+ default_factory=lambda: [
+ TalkRulesItem(platform="", id="", rule_type="group", time="00:00-08:59", value=0.8),
+ TalkRulesItem(platform="", id="", rule_type="group", time="09:00-18:59", value=1.0),
+ ]
+ )
+ """
+ _wrap_思考频率规则列表,支持按聊天流/按日内时段配置。
+ """
-@dataclass
class MessageReceiveConfig(ConfigBase):
"""消息接收配置类"""
- ban_words: set[str] = field(default_factory=lambda: set())
+ ban_words: set[str] = Field(default_factory=lambda: set())
"""过滤词列表"""
- ban_msgs_regex: set[str] = field(default_factory=lambda: set())
+ ban_msgs_regex: set[str] = Field(default_factory=lambda: set())
"""过滤正则表达式列表"""
+ def model_post_init(self, context: Optional[dict] = None) -> None:
+ for pattern in self.ban_msgs_regex:
+ try:
+ re.compile(pattern)
+ except re.error as e:
+ raise ValueError(f"Invalid regex pattern in ban_msgs_regex: '{pattern}'") from e
+ return super().model_post_init(context)
+
+
+class TargetItem(ConfigBase):
+ platform: str = ""
+ """平台,留空表示全局"""
+
+ id: str = ""
+ """用户ID"""
+
+ rule_type: Literal["group", "private"] = "group"
+ """聊天流类型,group(群聊)或private(私聊)"""
+
-@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
max_agent_iterations: int = 5
- """Agent最多迭代轮数(最低为1)"""
+ """记忆思考深度(最低为1)"""
agent_timeout_seconds: float = 120.0
- """Agent超时时间(秒)"""
+ """最长回忆时间(秒)"""
global_memory: bool = False
"""是否允许记忆检索在聊天记录中进行全局查询(忽略当前chat_id,仅对 search_chat_history 等工具生效)"""
- global_memory_blacklist: list[str] = field(default_factory=lambda: [])
- """
- 全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索
- 格式: ["platform:id:type", ...]
-
- 示例:
- [
- "qq:1919810:private", # 排除特定私聊
- "qq:114514:group", # 排除特定群聊
- ]
-
- 说明:
- - 当启用全局记忆时,黑名单中的聊天流不会被检索
- - 当在黑名单中的聊天流进行查询时,仅使用该聊天流的本地记忆
- """
+ global_memory_blacklist: list[TargetItem] = Field(default_factory=lambda: [])
+ """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
chat_history_topic_check_message_threshold: int = 80
"""聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查"""
@@ -297,234 +187,137 @@ class MemoryConfig(ConfigBase):
chat_history_finalize_message_count: int = 5
"""聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储"""
- def __post_init__(self):
+ chat_history_topic_check_message_threshold: int = 80
+ """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查"""
+
+ chat_history_topic_check_time_hours: float = 8.0
+ """聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查"""
+
+ chat_history_topic_check_min_messages: int = 20
+ """聊天历史话题检查的时间触发模式下的最小消息数阈值"""
+
+ chat_history_finalize_no_update_checks: int = 3
+ """聊天历史话题打包存储的连续无更新检查次数阈值,当话题连续N次检查无新增内容时触发打包存储"""
+
+ chat_history_finalize_message_count: int = 5
+ """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储"""
+
+ def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证配置值"""
if self.max_agent_iterations < 1:
raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}")
if self.agent_timeout_seconds <= 0:
raise ValueError(f"agent_timeout_seconds 必须大于0,当前值: {self.agent_timeout_seconds}")
if self.chat_history_topic_check_message_threshold < 1:
- raise ValueError(f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}")
+ raise ValueError(
+ f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}"
+ )
if self.chat_history_topic_check_time_hours <= 0:
- raise ValueError(f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}")
+ raise ValueError(
+ f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}"
+ )
if self.chat_history_topic_check_min_messages < 1:
- raise ValueError(f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}")
+ raise ValueError(
+ f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}"
+ )
if self.chat_history_finalize_no_update_checks < 1:
- raise ValueError(f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}")
+ raise ValueError(
+ f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}"
+ )
if self.chat_history_finalize_message_count < 1:
- raise ValueError(f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}")
+ raise ValueError(
+ f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}"
+ )
+ return super().model_post_init(context)
+
+
+class LearningItem(ConfigBase):
+ platform: str = ""
+ """平台,留空表示全局"""
+
+ id: str = ""
+ """用户ID"""
+
+ rule_type: Literal["group", "private"] = "group"
+ """聊天流类型,group(群聊)或private(私聊)"""
+
+ use_expression: bool = True
+ """是否启用表达学习"""
+
+ enable_learning: bool = True
+ """是否启用表达优化学习"""
+
+ enable_jargon_learning: bool = False
+ """是否启用jargon学习"""
+
+
+class ExpressionGroup(ConfigBase):
+ """表达互通组配置类,若列表为空代表全局共享"""
+
+ expression_groups: list[TargetItem] = Field(default_factory=lambda: [])
+ """_wrap_表达学习互通组"""
-@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
- learning_list: list[list] = field(default_factory=lambda: [])
- """
- 表达学习配置列表,支持按聊天流配置
- 格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...]
-
- 示例:
- [
- ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
- ["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习
- ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
- ]
-
- 说明:
- - 第一位: chat_stream_id,空字符串表示全局配置
- - 第二位: 是否使用学到的表达 ("enable"/"disable")
- - 第三位: 是否学习表达 ("enable"/"disable")
- - 第四位: 是否启用jargon学习 ("enable"/"disable")
- """
+ learning_list: list[LearningItem] = Field(
+ default_factory=lambda: [
+ LearningItem(
+ platform="",
+ id="",
+ rule_type="group",
+ use_expression=True,
+ enable_learning=True,
+ enable_jargon_learning=True,
+ )
+ ]
+ )
+ """_wrap_表达学习配置列表,支持按聊天流配置"""
- expression_groups: list[list[str]] = field(default_factory=list)
- """
- 表达学习互通组
- 格式: [["qq:12345:group", "qq:67890:private"]]
- """
+ expression_groups: list[ExpressionGroup] = Field(default_factory=list)
+ """_wrap_表达学习互通组"""
- expression_self_reflect: bool = False
+ expression_checked_only: bool = True
+ """是否仅选择已检查且未拒绝的表达方式"""
+
+ expression_self_reflect: bool = True
"""是否启用自动表达优化"""
-
+
+ expression_auto_check_interval: int = 600
+ """表达方式自动检查的间隔时间(秒)"""
+
+ expression_auto_check_count: int = 20
+ """每次自动检查时随机选取的表达方式数量"""
+
+ expression_auto_check_custom_criteria: list[str] = Field(default_factory=list)
+ """表达方式自动检查的额外自定义评估标准"""
+
expression_manual_reflect: bool = False
"""是否启用手动表达优化"""
- manual_reflect_operator_id: str = ""
- """表达反思操作员ID"""
+ manual_reflect_operator_id: Optional[TargetItem] = None
+ """手动表达优化操作员ID"""
- allow_reflect: list[str] = field(default_factory=list)
- """
- 允许进行表达反思的聊天流ID列表
- 格式: ["qq:123456:private", "qq:654321:group", ...]
- 只有在此列表中的聊天流才会提出问题并跟踪
- 如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true)
- """
+ allow_reflect: list[TargetItem] = Field(default_factory=list)
+ """允许进行表达反思的聊天流ID列表,只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是reflect为true)"""
- all_global_jargon: bool = False
- """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id。注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除"""
+ all_global_jargon: bool = True
+ """是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除"""
enable_jargon_explanation: bool = True
"""是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)"""
- jargon_mode: Literal["context", "planner"] = "context"
+ jargon_mode: Literal["context", "planner"] = "planner"
"""
- 黑话解释来源模式:
- - "context": 使用上下文自动匹配黑话并解释(原有模式)
- - "planner": 仅使用 Planner 在 reply 动作中给出的 unknown_words 列表进行黑话检索
- """
-
- expression_checked_only: bool = False
- """
- 是否仅选择已检查且未拒绝的表达方式
- 当设置为 true 时,只有 checked=True 且 rejected=False 的表达方式才会被选择
- 当设置为 false 时,保留旧的筛选原则(仅排除 rejected=True 的表达方式)
+ 黑话解释来源模式
+
+ 可选:
+ - "context":使用上下文自动匹配黑话
+ - "planner":仅使用Planner在reply动作中给出的unknown_words列表
"""
- expression_auto_check_interval: int = 3600
- """
- 表达方式自动检查的间隔时间(单位:秒)
- 默认值:3600秒(1小时)
- """
-
- expression_auto_check_count: int = 10
- """
- 每次自动检查时随机选取的表达方式数量
- 默认值:10条
- """
-
- expression_auto_check_custom_criteria: list[str] = field(default_factory=list)
- """
- 表达方式自动检查的额外自定义评估标准
- 格式: ["标准1", "标准2", "标准3", ...]
- 这些标准会被添加到评估提示词中,作为额外的评估要求
- 默认值:空列表
- """
-
- def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
- """
- 解析流配置字符串并生成对应的 chat_id
-
- Args:
- stream_config_str: 格式为 "platform:id:type" 的字符串
-
- Returns:
- str: 生成的 chat_id,如果解析失败则返回 None
- """
- try:
- parts = stream_config_str.split(":")
- if len(parts) != 3:
- return None
-
- platform = parts[0]
- id_str = parts[1]
- stream_type = parts[2]
-
- # 判断是否为群聊
- is_group = stream_type == "group"
-
- # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
- from src.chat.message_receive.chat_stream import get_chat_manager
-
- return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
-
- except (ValueError, IndexError):
- return None
-
- def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
- """
- 根据聊天流ID获取表达配置
-
- Args:
- chat_stream_id: 聊天流ID,格式为哈希值
-
- Returns:
- tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
- """
- if not self.learning_list:
- # 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习
- return True, True, True
-
- # 优先检查聊天流特定的配置
- if chat_stream_id:
- specific_expression_config = self._get_stream_specific_config(chat_stream_id)
- if specific_expression_config is not None:
- return specific_expression_config
-
- # 检查全局配置(第一个元素为空字符串的配置)
- global_expression_config = self._get_global_config()
- if global_expression_config is not None:
- return global_expression_config
-
- # 如果都没有匹配,返回默认值:启用表达,启用学习,启用jargon学习
- return True, True, True
-
- def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]:
- """
- 获取特定聊天流的表达配置
-
- Args:
- chat_stream_id: 聊天流ID(哈希值)
-
- Returns:
- tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
- """
- for config_item in self.learning_list:
- if not config_item or len(config_item) < 4:
- continue
-
- stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
-
- # 如果是空字符串,跳过(这是全局配置)
- if stream_config_str == "":
- continue
-
- # 解析配置字符串并生成对应的 chat_id
- config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
- if config_chat_id is None:
- continue
-
- # 比较生成的 chat_id
- if config_chat_id != chat_stream_id:
- continue
-
- # 解析配置
- try:
- use_expression: bool = config_item[1].lower() == "enable"
- enable_learning: bool = config_item[2].lower() == "enable"
- enable_jargon_learning: bool = config_item[3].lower() == "enable"
- return use_expression, enable_learning, enable_jargon_learning # type: ignore
- except (ValueError, IndexError):
- continue
-
- return None
-
- def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]:
- """
- 获取全局表达配置
-
- Returns:
- tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
- """
- for config_item in self.learning_list:
- if not config_item or len(config_item) < 4:
- continue
-
- # 检查是否为全局配置(第一个元素为空字符串)
- if config_item[0] == "":
- try:
- use_expression: bool = config_item[1].lower() == "enable"
- enable_learning: bool = config_item[2].lower() == "enable"
- enable_jargon_learning: bool = config_item[3].lower() == "enable"
- return use_expression, enable_learning, enable_jargon_learning # type: ignore
- except (ValueError, IndexError):
- continue
-
- return None
-
-
-@dataclass
class ToolConfig(ConfigBase):
"""工具配置类"""
@@ -532,54 +325,51 @@ class ToolConfig(ConfigBase):
"""是否在聊天中启用工具"""
-@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
enable_asr: bool = False
- """是否启用语音识别"""
+ """是否启用语音识别,启用后麦麦可以识别语音消息"""
-@dataclass
class EmojiConfig(ConfigBase):
"""表情包配置类"""
- emoji_chance: float = 0.6
+ emoji_chance: float = 0.4
"""发送表情包的基础概率"""
- max_reg_num: int = 200
+ max_reg_num: int = 100
"""表情包最大注册数量"""
do_replace: bool = True
- """达到最大注册数量时替换旧表情包"""
+ """达到最大注册数量时替换旧表情包,关闭则达到最大数量时不会继续收集表情包"""
- check_interval: int = 120
+ check_interval: int = 10
"""表情包检查间隔(分钟)"""
steal_emoji: bool = True
- """是否偷取表情包,让麦麦可以发送她保存的这些表情包"""
+ """是否偷取表情包,让麦麦可以将一些表情包据为己有"""
content_filtration: bool = False
- """是否开启表情包过滤"""
+ """是否启用表情包过滤,只有符合该要求的表情包才会被保存"""
filtration_prompt: str = "符合公序良俗"
- """表情包过滤要求"""
+ """表情包过滤要求,只有符合该要求的表情包才会被保存"""
-@dataclass
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""
- keywords: list[str] = field(default_factory=lambda: [])
+ keywords: list[str] = Field(default_factory=lambda: [])
"""关键词列表"""
- regex: list[str] = field(default_factory=lambda: [])
+ regex: list[str] = Field(default_factory=lambda: [])
"""正则表达式列表"""
reaction: str = ""
"""关键词触发的反应"""
- def __post_init__(self):
+ def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证配置"""
if not self.keywords and not self.regex:
raise ValueError("关键词规则必须至少包含keywords或regex中的一个")
@@ -587,33 +377,31 @@ class KeywordRuleConfig(ConfigBase):
if not self.reaction:
raise ValueError("关键词规则必须包含reaction")
- # 验证正则表达式
for pattern in self.regex:
try:
re.compile(pattern)
except re.error as e:
raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e
+ return super().model_post_init(context)
-@dataclass
class KeywordReactionConfig(ConfigBase):
"""关键词配置类"""
- keyword_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
+ keyword_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: [])
"""关键词规则列表"""
- regex_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
+ regex_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: [])
"""正则表达式规则列表"""
- def __post_init__(self):
+ def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证配置"""
- # 验证所有规则
for rule in self.keyword_rules + self.regex_rules:
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
+ return super().model_post_init(context)
-@dataclass
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""
@@ -621,7 +409,6 @@ class ResponsePostProcessConfig(ConfigBase):
"""是否启用回复后处理,包括错别字生成器,回复分割器"""
-@dataclass
class ChineseTypoConfig(ConfigBase):
"""中文错别字配置类"""
@@ -641,27 +428,25 @@ class ChineseTypoConfig(ConfigBase):
"""整词替换概率"""
-@dataclass
class ResponseSplitterConfig(ConfigBase):
"""回复分割器配置类"""
enable: bool = True
"""是否启用回复分割器"""
- max_length: int = 256
+ max_length: int = 512
"""回复允许的最大长度"""
- max_sentence_num: int = 3
+ max_sentence_num: int = 8
"""回复允许的最大句子数"""
enable_kaomoji_protection: bool = False
"""是否启用颜文字保护"""
enable_overflow_return_all: bool = False
- """是否在超出句子数量限制时合并后一次性返回"""
+ """是否在句子数量超出回复允许的最大句子数时一次性返回全部内容"""
-@dataclass
class TelemetryConfig(ConfigBase):
"""遥测配置类"""
@@ -669,36 +454,6 @@ class TelemetryConfig(ConfigBase):
"""是否启用遥测"""
-@dataclass
-class WebUIConfig(ConfigBase):
- """WebUI配置类
-
- 注意: host 和 port 配置已移至环境变量 WEBUI_HOST 和 WEBUI_PORT
- """
-
- enabled: bool = True
- """是否启用WebUI"""
-
- mode: Literal["development", "production"] = "production"
- """运行模式:development(开发) 或 production(生产)"""
-
- anti_crawler_mode: Literal["false", "strict", "loose", "basic"] = "basic"
- """防爬虫模式:false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)"""
-
- allowed_ips: str = "127.0.0.1"
- """IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)"""
-
- trusted_proxies: str = ""
- """信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任"""
-
- trust_xff: bool = False
- """是否启用X-Forwarded-For代理解析(默认false)"""
-
- secure_cookie: bool = False
- """是否启用安全Cookie(仅通过HTTPS传输,默认false)"""
-
-
-@dataclass
class DebugConfig(ConfigBase):
"""调试配置类"""
@@ -718,47 +473,51 @@ class DebugConfig(ConfigBase):
"""是否显示记忆检索相关prompt"""
show_planner_prompt: bool = False
- """是否显示planner相关提示词"""
+ """是否显示planner的prompt和原始返回结果"""
show_lpmm_paragraph: bool = False
"""是否显示lpmm找到的相关文段日志"""
-@dataclass
+class ExtraPromptItem(ConfigBase):
+ platform: str = ""
+ """平台,留空无效"""
+
+ id: str = ""
+ """用户ID,留空无效"""
+
+ rule_type: Literal["group", "private"] = "group"
+ """聊天流类型,group(群聊)或private(私聊)"""
+
+ prompt: str = ""
+ """额外的prompt内容"""
+
+ def model_post_init(self, context: Optional[dict] = None) -> None:
+ if not self.platform or not self.id or not self.prompt:
+ raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空")
+ return super().model_post_init(context)
+
+
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
- private_plan_style: str = ""
- """私聊说话规则,行为风格(实验性功能)"""
+ private_plan_style: str = """
+1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
+2.如果相同的内容已经被执行,请不要重复执行
+3.某句话如果已经被回复过,不要重复回复"""
+ """_wrap_私聊说话规则,行为风格(实验性功能)"""
- chat_prompts: list[str] = field(default_factory=lambda: [])
- """
- 为指定聊天添加额外的prompt配置列表
- 格式: ["platform:id:type:prompt内容", ...]
-
- 示例:
- [
- "qq:114514:group:这是一个摄影群,你精通摄影知识",
- "qq:19198:group:这是一个二次元交流群",
- "qq:114514:private:这是你与好朋友的私聊"
- ]
-
- 说明:
- - platform: 平台名称,如 "qq"
- - id: 群ID或用户ID
- - type: "group" 或 "private"
- - prompt内容: 要添加的额外prompt文本
- """
+ chat_prompts: list[ExtraPromptItem] = Field(default_factory=lambda: [])
+ """_wrap_为指定聊天添加额外的prompt配置列表"""
lpmm_memory: bool = False
"""是否将聊天历史总结导入到LPMM知识库。开启后,chat_history_summarizer总结出的历史记录会同时导入到知识库"""
-@dataclass
class MaimMessageConfig(ConfigBase):
"""maim_message配置类"""
- auth_token: list[str] = field(default_factory=lambda: [])
+ auth_token: list[str] = Field(default_factory=lambda: [])
"""认证令牌,用于旧版API验证,为空则不启用验证"""
enable_api_server: bool = False
@@ -779,11 +538,10 @@ class MaimMessageConfig(ConfigBase):
api_server_key_file: str = ""
"""新版API Server SSL密钥文件路径"""
- api_server_allowed_api_keys: list[str] = field(default_factory=lambda: [])
+ api_server_allowed_api_keys: list[str] = Field(default_factory=lambda: [])
"""新版API Server允许的API Key列表,为空则允许所有连接"""
-@dataclass
class LPMMKnowledgeConfig(ConfigBase):
"""LPMM知识库配置类"""
@@ -791,40 +549,40 @@ class LPMMKnowledgeConfig(ConfigBase):
"""是否启用LPMM知识库"""
lpmm_mode: Literal["classic", "agent"] = "classic"
- """LPMM知识库模式,可选:classic经典模式,agent 模式,结合最新的记忆一同使用"""
+ """LPMM知识库模式,可选:classic经典模式,agent 模式"""
rag_synonym_search_top_k: int = 10
- """RAG同义词搜索的Top K数量"""
+ """同义检索TopK"""
rag_synonym_threshold: float = 0.8
- """RAG同义词搜索的相似度阈值"""
+ """同义阈值,相似度高于该值的关系会被当作同义词"""
info_extraction_workers: int = 3
- """信息提取工作线程数"""
+ """实体抽取同时执行线程数,非Pro模型不要设置超过5"""
qa_relation_search_top_k: int = 10
- """QA关系搜索的Top K数量"""
+ """关系检索TopK"""
qa_relation_threshold: float = 0.75
- """QA关系搜索的相似度阈值"""
+ """关系阈值,相似度高于该值的关系会被认为是相关关系"""
qa_paragraph_search_top_k: int = 1000
- """QA段落搜索的Top K数量"""
+ """段落检索TopK(不能过小,可能影响搜索结果)"""
qa_paragraph_node_weight: float = 0.05
- """QA段落节点权重"""
+ """段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)"""
qa_ent_filter_top_k: int = 10
- """QA实体过滤的Top K数量"""
+ """实体过滤TopK"""
qa_ppr_damping: float = 0.8
- """QA PageRank阻尼系数"""
+ """PPR阻尼系数"""
qa_res_top_k: int = 10
- """QA最终结果的Top K数量"""
+ """最终提供段落TopK"""
embedding_dimension: int = 1024
- """嵌入向量维度,应该与模型的输出维度一致"""
+ """嵌入向量维度,输出维度"""
max_embedding_workers: int = 3
"""嵌入/抽取并发线程数"""
@@ -839,7 +597,6 @@ class LPMMKnowledgeConfig(ConfigBase):
"""是否启用PPR,低配机器可关闭"""
-@dataclass
class DreamConfig(ConfigBase):
"""Dream配置类"""
@@ -849,91 +606,23 @@ class DreamConfig(ConfigBase):
max_iterations: int = 20
"""做梦最大轮次,默认20轮"""
- first_delay_seconds: int = 60
- """程序启动后首次做梦前的延迟时间(秒),默认60秒"""
+ first_delay_seconds: int = 1800
+ """程序启动后首次做梦前的延迟时间(秒),默认1800秒"""
dream_send: str = ""
- """
- 做梦结果推送目标,格式为 "platform:user_id"
- 例如: "qq:123456" 表示在做梦结束后,将梦境文本额外发送给该QQ私聊用户。
- 为空字符串时不推送。
- """
+ """做梦结果推送目标,格式为 "platform:user_id",为空则不发送"""
- dream_time_ranges: list[str] = field(default_factory=lambda: [])
- """
- 做梦时间段配置列表,格式:["HH:MM-HH:MM", ...]
- 如果列表为空,则表示全天允许做梦。
- 如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。
- 时间段外,调度器仍会按间隔检查,但不会进入做梦流程。
-
- 示例:
- [
- "09:00-22:00", # 白天允许做梦
- "23:00-02:00", # 跨夜时间段(23:00到次日02:00)
- ]
-
- 支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。
- """
-
- def _now_minutes(self) -> int:
- """返回本地时间的分钟数(0-1439)。"""
- lt = time.localtime()
- return lt.tm_hour * 60 + lt.tm_min
-
- def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
- """解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
- try:
- start_str, end_str = [s.strip() for s in range_str.split("-")]
- sh, sm = [int(x) for x in start_str.split(":")]
- eh, em = [int(x) for x in end_str.split(":")]
- return sh * 60 + sm, eh * 60 + em
- except Exception:
- return None
-
- def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
- """
- 判断 now_min 是否在 [start_min, end_min] 区间内。
- 支持跨夜:如果 start > end,则表示跨越午夜。
- """
- if start_min <= end_min:
- return start_min <= now_min <= end_min
- # 跨夜:例如 23:00-02:00
- return now_min >= start_min or now_min <= end_min
-
- def is_in_dream_time(self) -> bool:
- """
- 检查当前时间是否在允许做梦的时间段内。
- 如果 dream_time_ranges 为空,则返回 True(全天允许)。
- """
- if not self.dream_time_ranges:
- return True
-
- now_min = self._now_minutes()
-
- for time_range in self.dream_time_ranges:
- if not isinstance(time_range, str):
- continue
- parsed = self._parse_range(time_range)
- if not parsed:
- continue
- start_min, end_min = parsed
- if self._in_range(now_min, start_min, end_min):
- return True
-
- return False
+ dream_time_ranges: list[str] = Field(default_factory=lambda: ["23:00-10:00"])
+ """_wrap_做梦时间段配置列表"""
dream_visible: bool = False
- """
- 做梦结果是否存储到上下文
- - True: 将梦境发送给配置的用户后,也会存储到聊天上下文中,在后续对话中可见
- - False: 仅发送梦境但不存储,不在后续对话上下文中出现
- """
+ """做梦结果发送后是否存储到上下文"""
- def __post_init__(self):
- """验证配置值"""
+ def model_post_init(self, context: Optional[dict] = None) -> None:
if self.interval_minutes < 1:
raise ValueError(f"interval_minutes 必须至少为1,当前值: {self.interval_minutes}")
if self.max_iterations < 1:
raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}")
if self.first_delay_seconds < 0:
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
+ return super().model_post_init(context)
diff --git a/src/plugin_system/apis/constants.py b/src/plugin_system/apis/constants.py
new file mode 100644
index 00000000..88d74dca
--- /dev/null
+++ b/src/plugin_system/apis/constants.py
@@ -0,0 +1,22 @@
+from pathlib import Path
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class _SystemConstants:
+ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve()
+ CONFIG_DIR: Path = PROJECT_ROOT / "config"
+ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
+ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
+ PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute()
+ INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute()
+
+
+_system_constants = _SystemConstants()
+
+PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT
+CONFIG_DIR: Path = _system_constants.CONFIG_DIR
+BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH
+MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH
+PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR
+INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR