mirror of https://github.com/Mai-with-u/MaiBot.git
解决ConfigBase问题,更严格测试,实际测试
parent
fd46d8a302
commit
9186d14100
16
bot.py
16
bot.py
|
|
@ -1,4 +1,4 @@
|
|||
raise RuntimeError("System Not Ready")
|
||||
# raise RuntimeError("System Not Ready")
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
|
|
@ -181,14 +181,14 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
|||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
# 关闭 WebUI 服务器
|
||||
try:
|
||||
from src.webui.webui_server import get_webui_server
|
||||
# try:
|
||||
# from src.webui.webui_server import get_webui_server
|
||||
|
||||
webui_server = get_webui_server()
|
||||
if webui_server and webui_server._server:
|
||||
await webui_server.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
||||
# webui_server = get_webui_server()
|
||||
# if webui_server and webui_server._server:
|
||||
# await webui_server.shutdown()
|
||||
# except Exception as e:
|
||||
# logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
|
|
|||
|
|
@ -2,15 +2,12 @@
|
|||
Version 0.2.2 - 2025-11-05
|
||||
|
||||
## 配置文件设计
|
||||
- [x] 使用 `toml` 作为配置文件格式
|
||||
- [x] <del>合理使用注释说明当前配置作用</del>(提案)
|
||||
- [x] 使用 python 方法作为配置项说明(提案)
|
||||
- [x] 取消`bot_config_template.toml`
|
||||
- [x] 取消`model_config_template.toml`
|
||||
- [x] 配置类中的所有原子项目应该只包含以下类型: `str`, `int`, `float`, `bool`, `list`, `dict`, `set`
|
||||
- [ ] 暂时禁止使用 `Union` 类型(尚未支持解析)
|
||||
- [ ] 不建议使用`tuple`类型,使用时会发出警告,考虑使用嵌套`dataclass`替代
|
||||
- [x] 复杂类型使用嵌套配置类实现
|
||||
主体利用`pydantic`的`BaseModel`进行配置类设计`ConfigBase`类
|
||||
要求每个属性必须具有类型注解,且类型注解满足以下要求:
|
||||
- 原子类型仅允许使用: `str`, `int`, `float`, `bool`, 以及基于`ConfigBase`的嵌套配置类
|
||||
- 复杂类型允许使用: `list`, `dict`, `set`,但其内部类型必须为原子类型或嵌套配置类,不可使用`list[list[int]]`,`list[dict[str, int]]`等写法
|
||||
- 禁止了使用`Union`, `tuple/Tuple`类型
|
||||
- 但是`Optional`仍然允许使用
|
||||
### 移除template的方案提案
|
||||
<details>
|
||||
<summary>配置项说明的废案</summary>
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ def patch_attrdoc_post_init():
|
|||
|
||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||
|
||||
class SimpleClass(ConfigBase):
|
||||
a: int = 1
|
||||
b: str = "test"
|
||||
|
||||
|
||||
class TestConfigBase:
|
||||
# ---------------------------------------------------------
|
||||
|
|
@ -267,6 +271,18 @@ class TestConfigBase:
|
|||
"不允许嵌套泛型类型",
|
||||
id="listset-validation-nested-generic-inner-list",
|
||||
),
|
||||
pytest.param(
|
||||
List[SimpleClass],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-list-configbase-element_allow",
|
||||
),
|
||||
pytest.param(
|
||||
Set[SimpleClass],
|
||||
True,
|
||||
"ConfigBase is not Hashable",
|
||||
id="listset-validation-set-configbase-element_reject",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -319,6 +335,12 @@ class TestConfigBase:
|
|||
"必须指定键和值的类型参数",
|
||||
id="dict-validation-missing-args",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, SimpleClass],
|
||||
False,
|
||||
None,
|
||||
id="dict-validation-happy-configbase-value",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||
|
|
@ -370,6 +392,22 @@ class TestConfigBase:
|
|||
|
||||
# Assert
|
||||
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
|
||||
|
||||
def test_discourage_any_usage_suppressed_warning(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = False
|
||||
suppress_any_warning: bool = True
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Arrange
|
||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
||||
|
||||
# Act
|
||||
instance._discourage_any_usage("field_z")
|
||||
|
||||
# Assert
|
||||
assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# model_post_init 规则覆盖(错误与边界情况)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import platform
|
|||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.tcp_connector import get_tcp_connector
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, MMC_VERSION
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||
info_dict = {
|
||||
"os_type": "Unknown",
|
||||
"py_version": platform.python_version(),
|
||||
"mmc_version": global_config.MMC_VERSION,
|
||||
"mmc_version": MMC_VERSION,
|
||||
}
|
||||
|
||||
match platform.system():
|
||||
|
|
|
|||
|
|
@ -1,139 +0,0 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIProvider(ConfigBase):
|
||||
"""API提供商配置类"""
|
||||
|
||||
name: str
|
||||
"""API提供商名称"""
|
||||
|
||||
base_url: str
|
||||
"""API基础URL"""
|
||||
|
||||
api_key: str = field(default_factory=str, repr=False)
|
||||
"""API密钥列表"""
|
||||
|
||||
client_type: str = field(default="openai")
|
||||
"""客户端类型(如openai/google等,默认为openai)"""
|
||||
|
||||
max_retry: int = 2
|
||||
"""最大重试次数(单个模型API调用失败,最多重试的次数)"""
|
||||
|
||||
timeout: int = 10
|
||||
"""API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)"""
|
||||
|
||||
retry_interval: int = 10
|
||||
"""重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.api_key
|
||||
|
||||
def __post_init__(self):
|
||||
"""确保api_key在repr中不被显示"""
|
||||
if not self.api_key:
|
||||
raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
|
||||
if not self.base_url and self.client_type != "gemini":
|
||||
raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
|
||||
if not self.name:
|
||||
raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo(ConfigBase):
|
||||
"""单个模型信息配置类"""
|
||||
|
||||
model_identifier: str
|
||||
"""模型标识符(用于URL调用)"""
|
||||
|
||||
name: str
|
||||
"""模型名称(用于模块调用)"""
|
||||
|
||||
api_provider: str
|
||||
"""API提供商(如OpenAI、Azure等)"""
|
||||
|
||||
price_in: float = field(default=0.0)
|
||||
"""每M token输入价格"""
|
||||
|
||||
price_out: float = field(default=0.0)
|
||||
"""每M token输出价格"""
|
||||
|
||||
temperature: float | None = field(default=None)
|
||||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||||
|
||||
max_tokens: int | None = field(default=None)
|
||||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||||
|
||||
force_stream_mode: bool = field(default=False)
|
||||
"""是否强制使用流式输出模式"""
|
||||
|
||||
extra_params: dict = field(default_factory=dict)
|
||||
"""额外参数(用于API调用时的额外配置)"""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.model_identifier:
|
||||
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
|
||||
if not self.name:
|
||||
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
|
||||
if not self.api_provider:
|
||||
raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskConfig(ConfigBase):
|
||||
"""任务配置类"""
|
||||
|
||||
model_list: list[str] = field(default_factory=list)
|
||||
"""任务使用的模型列表"""
|
||||
|
||||
max_tokens: int = 1024
|
||||
"""任务最大输出token数"""
|
||||
|
||||
temperature: float = 0.3
|
||||
"""模型温度"""
|
||||
|
||||
slow_threshold: float = 15.0
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
|
||||
selection_strategy: str = field(default="balance")
|
||||
"""模型选择策略:balance(负载均衡)或 random(随机选择)"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig
|
||||
"""组件模型配置"""
|
||||
|
||||
replyer: TaskConfig
|
||||
"""normal_chat首要回复模型模型配置"""
|
||||
|
||||
vlm: TaskConfig
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
voice: TaskConfig
|
||||
"""语音识别模型配置"""
|
||||
|
||||
tool_use: TaskConfig
|
||||
"""专注工具使用模型配置"""
|
||||
|
||||
planner: TaskConfig
|
||||
"""规划模型配置"""
|
||||
|
||||
embedding: TaskConfig
|
||||
"""嵌入模型配置"""
|
||||
|
||||
lpmm_entity_extract: TaskConfig
|
||||
"""LPMM实体提取模型配置"""
|
||||
|
||||
lpmm_rdf_build: TaskConfig
|
||||
"""LPMM RDF构建模型配置"""
|
||||
|
||||
def get_task(self, task_name: str) -> TaskConfig:
|
||||
"""获取指定任务的配置"""
|
||||
if hasattr(self, task_name):
|
||||
return getattr(self, task_name)
|
||||
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
|
||||
|
|
@ -1,19 +1,12 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table, KeyType
|
||||
from dataclasses import field, dataclass
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import format_toml_string
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
from .official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
ExpressionConfig,
|
||||
|
|
@ -34,345 +27,112 @@ from src.config.official_configs import (
|
|||
MemoryConfig,
|
||||
DebugConfig,
|
||||
DreamConfig,
|
||||
WebUIConfig,
|
||||
)
|
||||
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
|
||||
from .config_base import ConfigBase, Field, AttributeData
|
||||
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
|
||||
|
||||
from .api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
"""
|
||||
如果你想要修改配置文件,请递增version的值
|
||||
|
||||
install(extra_lines=3)
|
||||
版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
|
||||
主版本号:MMC版本更新
|
||||
次版本号:配置文件内容大更新
|
||||
修订号:配置文件内容小更新
|
||||
"""
|
||||
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "0.13.0"
|
||||
CONFIG_VERSION: str = "8.0.0"
|
||||
MODEL_CONFIG_VERSION: str = "1.12.0"
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("config")
|
||||
|
||||
# 获取当前文件所在目录的父目录的父目录(即MaiBot项目根目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
|
||||
TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.13.0-snapshot.1"
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
||||
return k.trivia.comment # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def get_value_by_path(d, path):
|
||||
for k in path:
|
||||
if isinstance(d, dict) and k in d:
|
||||
d = d[k]
|
||||
else:
|
||||
return None
|
||||
return d
|
||||
|
||||
|
||||
def set_value_by_path(d, path, value):
|
||||
"""设置嵌套字典中指定路径的值"""
|
||||
for k in path[:-1]:
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
|
||||
# 使用 tomlkit.item 来保持 TOML 格式
|
||||
try:
|
||||
d[path[-1]] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
d[path[-1]] = value
|
||||
|
||||
|
||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
# 递归比较两个dict,找出默认值变化项
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if changes is None:
|
||||
changes = []
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key in old:
|
||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||
elif new[key] != old[key]:
|
||||
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
||||
changes.append((path + [str(key)], old[key], new[key]))
|
||||
return logs, changes
|
||||
|
||||
|
||||
def _get_version_from_toml(toml_path) -> Optional[str]:
|
||||
"""从TOML文件中获取版本号"""
|
||||
if not os.path.exists(toml_path):
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||
return doc["inner"]["version"] # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def _version_tuple(v):
|
||||
"""将版本字符串转换为元组以便比较"""
|
||||
if v is None:
|
||||
return (0,)
|
||||
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
|
||||
|
||||
|
||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||
_update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_config_generic(config_name: str, template_name: str):
|
||||
"""
|
||||
通用的配置文件更新函数
|
||||
|
||||
Args:
|
||||
config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
|
||||
template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
|
||||
"""
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
||||
|
||||
# 定义文件路径
|
||||
template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
|
||||
old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||
new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||
compare_path = os.path.join(compare_dir, f"{template_name}.toml")
|
||||
|
||||
# 创建compare目录(如果不存在)
|
||||
os.makedirs(compare_dir, exist_ok=True)
|
||||
|
||||
template_version = _get_version_from_toml(template_path)
|
||||
compare_version = _get_version_from_toml(compare_path)
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
|
||||
shutil.copy2(template_path, old_config_path) # 复制模板文件
|
||||
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
|
||||
# 新创建配置文件,退出
|
||||
sys.exit(0)
|
||||
|
||||
compare_config = None
|
||||
new_config = None
|
||||
old_config = None
|
||||
|
||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||
if os.path.exists(compare_path):
|
||||
with open(compare_path, "r", encoding="utf-8") as f:
|
||||
compare_config = tomlkit.load(f)
|
||||
|
||||
# 读取当前模板
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||
if compare_config:
|
||||
# 读取旧配置
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
logs, changes = compare_default_values(new_config, compare_config)
|
||||
if logs:
|
||||
logger.info(f"检测到{config_name}模板默认值变动如下:")
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
||||
config_updated = False
|
||||
for path, old_default, new_default in changes:
|
||||
old_value = get_value_by_path(old_config, path)
|
||||
if old_value == old_default:
|
||||
set_value_by_path(old_config, path, new_default)
|
||||
logger.info(
|
||||
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
config_updated = True
|
||||
|
||||
# 如果配置有更新,立即保存到文件
|
||||
if config_updated:
|
||||
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(format_toml_string(old_config))
|
||||
logger.info(f"已保存更新后的{config_name}配置文件")
|
||||
else:
|
||||
logger.info(f"未检测到{config_name}模板默认值变动")
|
||||
|
||||
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
||||
if not os.path.exists(compare_path):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
|
||||
elif _version_tuple(template_version) > _version_tuple(compare_version):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}")
|
||||
else:
|
||||
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
|
||||
|
||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||
if old_config is None:
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
# new_config 已经读取
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||
)
|
||||
else:
|
||||
logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建old目录(如果不存在)
|
||||
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(old_config_path, old_backup_path)
|
||||
logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
|
||||
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
|
||||
if logs := compare_dicts(new_config, old_config):
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
else:
|
||||
logger.info("无新增或删减项")
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info(f"开始合并{config_name}新旧配置...")
|
||||
_update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式,数组多行格式化)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(format_toml_string(new_config))
|
||||
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
def update_config():
|
||||
"""更新bot_config.toml配置文件"""
|
||||
_update_config_generic("bot_config", "bot_config_template")
|
||||
|
||||
|
||||
def update_model_config():
|
||||
"""更新model_config.toml配置文件"""
|
||||
_update_config_generic("model_config", "model_config_template")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
|
||||
bot: BotConfig = Field(default_factory=BotConfig)
|
||||
"""机器人配置类"""
|
||||
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
relationship: RelationshipConfig
|
||||
chat: ChatConfig
|
||||
message_receive: MessageReceiveConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
keyword_reaction: KeywordReactionConfig
|
||||
chinese_typo: ChineseTypoConfig
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
response_splitter: ResponseSplitterConfig
|
||||
telemetry: TelemetryConfig
|
||||
webui: WebUIConfig
|
||||
experimental: ExperimentalConfig
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
memory: MemoryConfig
|
||||
debug: DebugConfig
|
||||
voice: VoiceConfig
|
||||
dream: DreamConfig
|
||||
personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
|
||||
"""人格配置类"""
|
||||
|
||||
expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
|
||||
"""表达配置类"""
|
||||
|
||||
chat: ChatConfig = Field(default_factory=ChatConfig)
|
||||
"""聊天配置类"""
|
||||
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig)
|
||||
"""记忆配置类"""
|
||||
|
||||
relationship: RelationshipConfig = Field(default_factory=RelationshipConfig)
|
||||
"""关系配置类"""
|
||||
|
||||
message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
|
||||
"""消息接收配置类"""
|
||||
|
||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||
"""做梦配置类"""
|
||||
|
||||
tool: ToolConfig = Field(default_factory=ToolConfig)
|
||||
"""工具配置类"""
|
||||
|
||||
voice: VoiceConfig = Field(default_factory=VoiceConfig)
|
||||
"""语音配置类"""
|
||||
|
||||
emoji: EmojiConfig = Field(default_factory=EmojiConfig)
|
||||
"""表情包配置类"""
|
||||
|
||||
keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig)
|
||||
"""关键词反应配置类"""
|
||||
|
||||
response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig)
|
||||
"""回复后处理配置类"""
|
||||
|
||||
chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig)
|
||||
"""中文错别字生成器配置类"""
|
||||
|
||||
response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig)
|
||||
"""回复分割器配置类"""
|
||||
|
||||
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
|
||||
"""遥测配置类"""
|
||||
|
||||
debug: DebugConfig = Field(default_factory=DebugConfig)
|
||||
"""调试配置类"""
|
||||
|
||||
experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
|
||||
"""实验性功能配置类"""
|
||||
|
||||
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
|
||||
"""maim_message配置类"""
|
||||
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
|
||||
"""LPMM知识库配置类"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIAdapterConfig(ConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
class ModelConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
models: List[ModelInfo]
|
||||
"""模型列表"""
|
||||
models: list[ModelInfo] = Field(default_factory=list)
|
||||
"""模型配置列表"""
|
||||
|
||||
model_task_config: ModelTaskConfig
|
||||
model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig)
|
||||
"""模型任务配置"""
|
||||
|
||||
api_providers: List[APIProvider] = field(default_factory=list)
|
||||
api_providers: list[APIProvider] = Field(default_factory=list)
|
||||
"""API提供商列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
def model_post_init(self, context: Any = None):
|
||||
if not self.models:
|
||||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||||
if not self.api_providers:
|
||||
|
|
@ -388,78 +148,134 @@ class APIAdapterConfig(ConfigBase):
|
|||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
|
||||
for model in self.models:
|
||||
if not model.model_identifier:
|
||||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||||
if not model.api_provider or model.api_provider not in self.api_providers_dict:
|
||||
if not model.api_provider or model.api_provider not in api_providers_dict:
|
||||
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
|
||||
|
||||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||||
"""根据模型名称获取模型信息"""
|
||||
if not model_name:
|
||||
raise ValueError("模型名称不能为空")
|
||||
if model_name not in self.models_dict:
|
||||
raise KeyError(f"模型 '{model_name}' 不存在")
|
||||
return self.models_dict[model_name]
|
||||
|
||||
def get_provider(self, provider_name: str) -> APIProvider:
|
||||
"""根据提供商名称获取API提供商信息"""
|
||||
if not provider_name:
|
||||
raise ValueError("API提供商名称不能为空")
|
||||
if provider_name not in self.api_providers_dict:
|
||||
raise KeyError(f"API提供商 '{provider_name}' 不存在")
|
||||
return self.api_providers_dict[provider_name]
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
class ConfigManager:
|
||||
"""总配置管理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.bot_config_path: Path = BOT_CONFIG_PATH
|
||||
self.model_config_path: Path = MODEL_CONFIG_PATH
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def initialize(self):
|
||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||
logger.info("正在品鉴配置文件...")
|
||||
self.global_config: Config = self._load_global_config()
|
||||
self.model_config: ModelConfig = self._load_model_config()
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
||||
def _load_global_config(self) -> Config:
|
||||
config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
return config
|
||||
|
||||
def _load_model_config(self) -> ModelConfig:
|
||||
config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
return config
|
||||
|
||||
def get_global_config(self) -> Config:
|
||||
return self.global_config
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
|
||||
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
|
||||
"""生成新的配置文件
|
||||
|
||||
:param config_class: 配置类
|
||||
:param config_path: 配置文件路径
|
||||
:param inner_config_version: 配置文件版本号
|
||||
"""
|
||||
加载配置文件
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
Returns:
|
||||
Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
config = config_class()
|
||||
write_config_to_file(config, config_path, inner_config_version)
|
||||
|
||||
|
||||
def load_config_from_file(
|
||||
config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
|
||||
) -> tuple[T, bool]:
|
||||
attribute_data = AttributeData()
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
old_ver: str = config_data["inner"]["version"] # type: ignore
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
updated: bool = False
|
||||
target_config = config_class.from_dict(attribute_data, config_data)
|
||||
if compare_versions(old_ver, new_ver):
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
updated = True
|
||||
return target_config, updated
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
logger.critical(f"配置文件{config_path.name}解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
def write_config_to_file(
|
||||
config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False
|
||||
) -> None:
|
||||
"""将配置写入文件
|
||||
|
||||
:param config: 配置对象
|
||||
:param config_path: 配置文件路径
|
||||
"""
|
||||
加载API适配器配置文件
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
Returns:
|
||||
APIAdapterConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
# 创建空TOMLDocument
|
||||
full_config_data = tomlkit.document()
|
||||
|
||||
# 创建APIAdapterConfig对象
|
||||
try:
|
||||
return APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("API适配器配置文件解析失败")
|
||||
raise e
|
||||
# 首先写入配置文件版本信息
|
||||
version_table = tomlkit.table()
|
||||
version_table.add("version", inner_config_version)
|
||||
full_config_data.add("inner", version_table)
|
||||
|
||||
# 递归解析配置项为表格
|
||||
for config_item_name, config_item in type(config).model_fields.items():
|
||||
if not config_item.repr and not override_repr:
|
||||
continue
|
||||
if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
|
||||
continue
|
||||
config_field = getattr(config, config_item_name)
|
||||
if isinstance(config_field, ConfigBase):
|
||||
full_config_data.add(
|
||||
config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr)
|
||||
)
|
||||
elif isinstance(config_field, list):
|
||||
aot = tomlkit.aot()
|
||||
for item in config_field:
|
||||
if not isinstance(item, ConfigBase):
|
||||
raise TypeError("配置写入只支持ConfigBase子类")
|
||||
aot.append(recursive_parse_item_to_table(item, override_repr=override_repr))
|
||||
full_config_data.add(config_item_name, aot)
|
||||
else:
|
||||
raise TypeError("配置写入只支持ConfigBase子类")
|
||||
|
||||
# 备份旧文件
|
||||
if config_path.exists():
|
||||
backup_root = config_path.parent / "old"
|
||||
backup_root.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml"
|
||||
config_path.replace(backup_path)
|
||||
|
||||
# 写入文件
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(full_config_data, f)
|
||||
|
||||
|
||||
# 获取配置文件路径
|
||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||
update_config()
|
||||
update_model_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
||||
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
|
||||
config_manager = ConfigManager()
|
||||
config_manager.initialize()
|
||||
# global_config = config_manager.get_global_config()
|
||||
|
|
|
|||
|
|
@ -2,27 +2,35 @@ import ast
|
|||
import inspect
|
||||
import types
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set
|
||||
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal
|
||||
|
||||
__all__ = ["ConfigBase", "Field"]
|
||||
__all__ = ["ConfigBase", "Field", "AttributeData"]
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("ConfigBase")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttributeData:
|
||||
missing_attributes: list[str] = field(default_factory=list)
|
||||
"""缺失的属性列表"""
|
||||
redundant_attributes: list[str] = field(default_factory=list)
|
||||
"""多余的属性列表"""
|
||||
|
||||
|
||||
class AttrDocBase:
|
||||
"""解析字段说明的基类"""
|
||||
|
||||
field_docs: dict[str, str] = {}
|
||||
|
||||
def __post_init__(self):
|
||||
self.field_docs = self._get_field_docs() # 全局仅获取一次并保留
|
||||
def __post_init__(self, allow_extra_methods: bool = False):
|
||||
self.field_docs = self._get_field_docs(allow_extra_methods) # 全局仅获取一次并保留
|
||||
|
||||
@classmethod
|
||||
def _get_field_docs(cls) -> dict[str, str]:
|
||||
def _get_field_docs(self, allow_extra_methods: bool) -> dict[str, str]:
|
||||
"""
|
||||
获取字段的说明字符串
|
||||
|
||||
|
|
@ -30,11 +38,11 @@ class AttrDocBase:
|
|||
:return: 字段说明字典,键为字段名,值为说明字符串
|
||||
"""
|
||||
# 获取类的源代码文本
|
||||
class_source = cls._get_class_source()
|
||||
class_source = self._get_class_source()
|
||||
# 解析源代码,找到对应的类定义节点
|
||||
class_node = cls._find_class_node(class_source)
|
||||
class_node = self._find_class_node(class_source)
|
||||
# 从类定义节点中提取字段文档
|
||||
return cls._extract_field_docs(class_node)
|
||||
return self._extract_field_docs(class_node, allow_extra_methods)
|
||||
|
||||
@classmethod
|
||||
def _get_class_source(cls) -> str:
|
||||
|
|
@ -57,21 +65,22 @@ class AttrDocBase:
|
|||
# 如果没有找到匹配的类定义,抛出异常
|
||||
raise AttributeError(f"Class {cls.__name__} not found in source.")
|
||||
|
||||
@classmethod
|
||||
def _extract_field_docs(cls, class_node: ast.ClassDef) -> dict[str, str]:
|
||||
def _extract_field_docs(self, class_node: ast.ClassDef, allow_extra_methods: bool) -> dict[str, str]:
|
||||
"""从类的 AST 节点中提取字段的文档字符串"""
|
||||
# sourcery skip: merge-nested-ifs
|
||||
doc_dict: dict[str, str] = {}
|
||||
class_body = class_node.body # 类属性节点列表
|
||||
for i in range(len(class_body)):
|
||||
body_item = class_body[i]
|
||||
|
||||
# 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
|
||||
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
|
||||
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
|
||||
"""检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义"""
|
||||
raise AttributeError(
|
||||
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
|
||||
) from None
|
||||
if not allow_extra_methods:
|
||||
# 检查是否有非 model_post_init 的方法定义,如果有则抛出异常
|
||||
# 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法
|
||||
if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init":
|
||||
"""检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义"""
|
||||
raise AttributeError(
|
||||
f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}"
|
||||
) from None
|
||||
|
||||
# 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义)
|
||||
# 并且下一个语句存在
|
||||
|
|
@ -110,13 +119,53 @@ class AttrDocBase:
|
|||
class ConfigBase(BaseModel, AttrDocBase):
|
||||
model_config = ConfigDict(validate_assignment=True, extra="forbid")
|
||||
_validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True
|
||||
suppress_any_warning: bool = False # 是否抑制 Any 类型使用的警告,默认为 False,仅仅在_validate_any 为 False 时生效
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
||||
"""从字典创建配置对象,并收集缺失和多余的属性信息"""
|
||||
class_fields = set(cls.model_fields.keys())
|
||||
class_fields.remove("field_docs") # 忽略 field_docs 字段
|
||||
if "_validate_any" in class_fields:
|
||||
class_fields.remove("_validate_any") # 忽略 _validate_any 字段
|
||||
if "suppress_any_warning" in class_fields:
|
||||
class_fields.remove("suppress_any_warning") # 忽略 suppress_any_warning 字
|
||||
for class_field in class_fields:
|
||||
if class_field not in data:
|
||||
attribute_data.missing_attributes.append(class_field) # 记录缺失的属性
|
||||
cleaned_data_list: list[str] = []
|
||||
for data_field in data:
|
||||
if data_field not in class_fields:
|
||||
cleaned_data_list.append(data_field)
|
||||
attribute_data.redundant_attributes.append(data_field) # 记录多余的属性
|
||||
for redundant_field in cleaned_data_list:
|
||||
data.pop(redundant_field) # 移除多余的属性
|
||||
# 对于是ConfigBase子类的字段,递归调用from_dict
|
||||
class_field_infos = dict(cls.model_fields.items())
|
||||
for field_data in data:
|
||||
if info := class_field_infos.get(field_data):
|
||||
field_type = info.annotation
|
||||
if inspect.isclass(field_type) and issubclass(field_type, ConfigBase):
|
||||
data[field_data] = field_type.from_dict(attribute_data, data[field_data])
|
||||
if get_origin(field_type) in {list, List}:
|
||||
elem_type = get_args(field_type)[0]
|
||||
if inspect.isclass(elem_type) and issubclass(elem_type, ConfigBase):
|
||||
data[field_data] = [elem_type.from_dict(attribute_data, item) for item in data[field_data]]
|
||||
# 没有set,因为ConfigBase is not Hashable
|
||||
if get_origin(field_type) in {dict, Dict}:
|
||||
val_type = get_args(field_type)[1]
|
||||
if inspect.isclass(val_type) and issubclass(val_type, ConfigBase):
|
||||
data[field_data] = {
|
||||
key: val_type.from_dict(attribute_data, val) for key, val in data[field_data].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
def _discourage_any_usage(self, field_name: str) -> None:
|
||||
"""警告使用 Any 类型的字段(可被suppress)"""
|
||||
if self._validate_any:
|
||||
raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解")
|
||||
else:
|
||||
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议避免使用。")
|
||||
if not self.suppress_any_warning:
|
||||
logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议使用更具体的类型注解以提高类型安全性")
|
||||
|
||||
def _get_real_type(self, annotation: type[Any] | Any | None):
|
||||
"""获取真实类型,处理 dict 等没有参数的情况"""
|
||||
|
|
@ -157,10 +206,14 @@ class ConfigBase(BaseModel, AttrDocBase):
|
|||
elem = args[0]
|
||||
if elem is Any:
|
||||
self._discourage_any_usage(field_name)
|
||||
if get_origin(elem) is not None:
|
||||
elif get_origin(elem) is not None:
|
||||
raise TypeError(
|
||||
f"类'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。"
|
||||
)
|
||||
elif inspect.isclass(elem) and issubclass(elem, ConfigBase) and origin in (set, Set):
|
||||
raise TypeError(
|
||||
f"类'{type(self).__name__}'字段'{field_name}'中不允许使用 ConfigBase 子类作为 set 元素类型: {annotation}。ConfigBase is not Hashable。"
|
||||
)
|
||||
|
||||
def _validate_dict_type(self, annotation: Any | None, field_name: str):
|
||||
"""验证 dict 类型的使用"""
|
||||
|
|
@ -215,7 +268,7 @@ class ConfigBase(BaseModel, AttrDocBase):
|
|||
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
|
||||
continue
|
||||
# 只允许 list, set, dict 三类泛型
|
||||
if origin_type not in (list, set, dict, List, Set, Dict):
|
||||
if origin_type not in (list, set, dict, List, Set, Dict, Literal):
|
||||
raise TypeError(
|
||||
f"仅允许使用list, set, dict三种泛型类型注解,类'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,127 @@
|
|||
from pydantic.fields import FieldInfo
|
||||
from typing import Any, get_args, get_origin, TYPE_CHECKING, Literal, List, Set, Tuple, Dict
|
||||
from tomlkit import items
|
||||
import tomlkit
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import AttributeData
|
||||
|
||||
|
||||
def recursive_parse_item_to_table(
|
||||
config: ConfigBase, is_inline_table: bool = False, override_repr: bool = False
|
||||
) -> items.Table | items.InlineTable:
|
||||
# sourcery skip: merge-else-if-into-elif, reintroduce-else
|
||||
"""递归解析配置项为表格"""
|
||||
config_table = tomlkit.table()
|
||||
if is_inline_table:
|
||||
config_table = tomlkit.inline_table()
|
||||
for config_item_name, config_item_info in type(config).model_fields.items():
|
||||
if not config_item_info.repr and not override_repr:
|
||||
continue
|
||||
value = getattr(config, config_item_name)
|
||||
if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]:
|
||||
continue
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, ConfigBase):
|
||||
config_table.add(config_item_name, recursive_parse_item_to_table(value, override_repr=override_repr))
|
||||
else:
|
||||
config_table.add(
|
||||
config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr)
|
||||
)
|
||||
if not is_inline_table:
|
||||
config_table = comment_doc_string(config, config_item_name, config_table)
|
||||
return config_table
|
||||
|
||||
|
||||
def comment_doc_string(
|
||||
config: ConfigBase, field_name: str, toml_table: items.Table | items.InlineTable
|
||||
) -> items.Table | items.InlineTable:
|
||||
"""将配置类中的注释加入toml表格中"""
|
||||
if doc_string := config.field_docs.get(field_name, ""):
|
||||
doc_string_splitted = doc_string.splitlines()
|
||||
if len(doc_string_splitted) == 1 and not doc_string_splitted[0].strip().startswith("_wrap_"):
|
||||
if isinstance(toml_table[field_name], bool):
|
||||
# tomlkit 故意设计的行为,布尔值不能直接添加注释
|
||||
value = toml_table[field_name]
|
||||
item = tomlkit.item(value)
|
||||
item.comment(doc_string_splitted[0])
|
||||
toml_table[field_name] = item
|
||||
else:
|
||||
toml_table[field_name].comment(doc_string_splitted[0])
|
||||
else:
|
||||
if doc_string_splitted[0].strip().startswith("_wrap_"):
|
||||
doc_string_splitted[0] = doc_string_splitted[0].replace("_wrap_", "", 1).strip()
|
||||
for line in doc_string_splitted:
|
||||
toml_table.add(tomlkit.comment(line))
|
||||
toml_table.add(tomlkit.nl())
|
||||
return toml_table
|
||||
|
||||
|
||||
def convert_field(config_item_name: str, config_item_info: FieldInfo, value: Any, override_repr: bool = False):
|
||||
# sourcery skip: extract-method
|
||||
"""将非可直接表达类转换为toml可表达类"""
|
||||
field_type_origin = get_origin(config_item_info.annotation)
|
||||
field_type_args = get_args(config_item_info.annotation)
|
||||
|
||||
if not field_type_origin: # 基础类型int,bool等直接添加
|
||||
return value
|
||||
elif field_type_origin in {list, set, List, Set}:
|
||||
toml_list = tomlkit.array()
|
||||
if field_type_args and isinstance(field_type_args[0], type) and issubclass(field_type_args[0], ConfigBase):
|
||||
for item in value:
|
||||
toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
|
||||
else:
|
||||
for item in value:
|
||||
toml_list.append(item)
|
||||
return toml_list
|
||||
elif field_type_origin in (tuple, Tuple):
|
||||
toml_list = tomlkit.array()
|
||||
for field_arg, item in zip(field_type_args, value, strict=True):
|
||||
if isinstance(field_arg, type) and issubclass(field_arg, ConfigBase):
|
||||
toml_list.append(recursive_parse_item_to_table(item, True, override_repr))
|
||||
else:
|
||||
toml_list.append(item)
|
||||
return toml_list
|
||||
elif field_type_origin in (dict, Dict):
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {config_item_name}")
|
||||
toml_sub_table = tomlkit.inline_table()
|
||||
key_type, value_type = field_type_args
|
||||
if key_type is not str:
|
||||
raise TypeError(f"TOML only supports string keys for tables, got {key_type} for {config_item_name}")
|
||||
for k, v in value.items():
|
||||
if isinstance(value_type, type) and issubclass(value_type, ConfigBase):
|
||||
toml_sub_table.add(k, recursive_parse_item_to_table(v, True, override_repr))
|
||||
else:
|
||||
toml_sub_table.add(k, v)
|
||||
return toml_sub_table
|
||||
elif field_type_origin is Literal:
|
||||
if value not in field_type_args:
|
||||
raise ValueError(f"Value {value} not in Literal options {field_type_args} for {config_item_name}")
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Unsupported field type for {config_item_name}: {config_item_info.annotation}")
|
||||
|
||||
|
||||
def output_config_changes(attr_data: "AttributeData", logger, old_ver: str, new_ver: str, file_name: str):
|
||||
"""输出配置变更信息"""
|
||||
logger.info("-------- 配置文件变更信息 --------")
|
||||
logger.info(f"新增配置数量: {len(attr_data.missing_attributes)}")
|
||||
for attr in attr_data.missing_attributes:
|
||||
logger.info(f"配置文件中新增配置项: {attr}")
|
||||
logger.info(f"移除配置数量: {len(attr_data.redundant_attributes)}")
|
||||
for attr in attr_data.redundant_attributes:
|
||||
logger.warning(f"移除配置项: {attr}")
|
||||
logger.info(
|
||||
f"{file_name}配置文件已经更新. Old: {old_ver} -> New: {new_ver} 建议检查新配置文件中的内容, 以免丢失重要信息"
|
||||
)
|
||||
|
||||
|
||||
def compare_versions(old_ver: str, new_ver: str) -> bool:
|
||||
"""比较版本号,返回是否有更新"""
|
||||
old_parts = [int(part) for part in old_ver.split(".")]
|
||||
new_parts = [int(part) for part in new_ver.split(".")]
|
||||
return new_parts > old_parts
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
from typing import Any
|
||||
from .config_base import ConfigBase, Field
|
||||
|
||||
|
||||
class APIProvider(ConfigBase):
|
||||
"""API提供商配置类"""
|
||||
|
||||
name: str = ""
|
||||
"""API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
|
||||
|
||||
base_url: str = ""
|
||||
"""API服务商的BaseURL"""
|
||||
|
||||
api_key: str = Field(default_factory=str, repr=False)
|
||||
"""API密钥"""
|
||||
|
||||
client_type: str = Field(default="openai")
|
||||
"""客户端类型 (可选: openai/google, 默认为openai)"""
|
||||
|
||||
max_retry: int = Field(default=2)
|
||||
"""最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
|
||||
|
||||
timeout: int = 10
|
||||
"""API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
|
||||
|
||||
retry_interval: int = 10
|
||||
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
"""确保api_key在repr中不被显示"""
|
||||
if not self.api_key:
|
||||
raise ValueError("API密钥不能为空, 请在配置中设置有效的API密钥。")
|
||||
if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url
|
||||
raise ValueError("API基础URL不能为空, 请在配置中设置有效的基础URL。")
|
||||
if not self.name:
|
||||
raise ValueError("API提供商名称不能为空, 请在配置中设置有效的名称。")
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class ModelInfo(ConfigBase):
|
||||
"""单个模型信息配置类"""
|
||||
_validate_any: bool = False
|
||||
suppress_any_warning: bool = True
|
||||
|
||||
model_identifier: str = ""
|
||||
"""模型标识符 (API服务商提供的模型标识符)"""
|
||||
|
||||
name: str = ""
|
||||
"""模型名称 (可随意命名, 在models中需使用这个命名)"""
|
||||
|
||||
api_provider: str = ""
|
||||
"""API服务商名称 (对应在api_providers中配置的服务商名称)"""
|
||||
|
||||
price_in: float = Field(default=0.0)
|
||||
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||||
|
||||
price_out: float = Field(default=0.0)
|
||||
"""输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||||
|
||||
temperature: float | None = Field(default=None)
|
||||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||||
|
||||
max_tokens: int | None = Field(default=None)
|
||||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||||
|
||||
force_stream_mode: bool = Field(default=False)
|
||||
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
|
||||
|
||||
extra_params: dict[str, Any] = Field(default_factory=dict)
|
||||
"""额外参数 (用于API调用时的额外配置)"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
if not self.model_identifier:
|
||||
raise ValueError("模型标识符不能为空, 请在配置中设置有效的模型标识符。")
|
||||
if not self.name:
|
||||
raise ValueError("模型名称不能为空, 请在配置中设置有效的模型名称。")
|
||||
if not self.api_provider:
|
||||
raise ValueError("API提供商不能为空, 请在配置中设置有效的API提供商。")
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class TaskConfig(ConfigBase):
|
||||
"""任务配置类"""
|
||||
|
||||
model_list: list[str] = Field(default_factory=list)
|
||||
"""使用的模型列表, 每个元素对应上面的模型名称(name)"""
|
||||
|
||||
max_tokens: int = 1024
|
||||
"""任务最大输出token数"""
|
||||
|
||||
temperature: float = 0.3
|
||||
"""模型温度"""
|
||||
|
||||
slow_threshold: float = 15.0
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
|
||||
selection_strategy: str = Field(default="balance")
|
||||
"""模型选择策略:balance(负载均衡)或 random(随机选择)"""
|
||||
|
||||
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
replyer: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""首要回复模型配置, 还用于表达器和表达方式学习"""
|
||||
|
||||
vlm: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""视觉模型配置"""
|
||||
|
||||
voice: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""语音识别模型配置"""
|
||||
|
||||
tool_use: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""工具使用模型配置, 需要使用支持工具调用的模型"""
|
||||
|
||||
planner: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""规划模型配置"""
|
||||
|
||||
embedding: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""嵌入模型配置"""
|
||||
|
||||
lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""LPMM实体提取模型配置"""
|
||||
|
||||
lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig)
|
||||
"""LPMM RDF构建模型配置"""
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,22 @@
|
|||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SystemConstants:
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute()
|
||||
INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute()
|
||||
|
||||
|
||||
_system_constants = _SystemConstants()
|
||||
|
||||
PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT
|
||||
CONFIG_DIR: Path = _system_constants.CONFIG_DIR
|
||||
BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH
|
||||
MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH
|
||||
PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR
|
||||
INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR
|
||||
Loading…
Reference in New Issue