diff --git a/bot.py b/bot.py index 7faad3ed..02dee81c 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,4 @@ -raise RuntimeError("System Not Ready") +# raise RuntimeError("System Not Ready") import asyncio import hashlib import os @@ -181,14 +181,14 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression logger.info("正在优雅关闭麦麦...") # 关闭 WebUI 服务器 - try: - from src.webui.webui_server import get_webui_server + # try: + # from src.webui.webui_server import get_webui_server - webui_server = get_webui_server() - if webui_server and webui_server._server: - await webui_server.shutdown() - except Exception as e: - logger.warning(f"关闭 WebUI 服务器时出错: {e}") + # webui_server = get_webui_server() + # if webui_server and webui_server._server: + # await webui_server.shutdown() + # except Exception as e: + # logger.warning(f"关闭 WebUI 服务器时出错: {e}") from src.plugin_system.core.events_manager import events_manager from src.plugin_system.base.component_types import EventType diff --git a/changelogs/mai_next_design.md b/changelogs/mai_next_design.md index fc870960..dcee4681 100644 --- a/changelogs/mai_next_design.md +++ b/changelogs/mai_next_design.md @@ -2,15 +2,12 @@ Version 0.2.2 - 2025-11-05 ## 配置文件设计 -- [x] 使用 `toml` 作为配置文件格式 -- [x] 合理使用注释说明当前配置作用(提案) -- [x] 使用 python 方法作为配置项说明(提案) - - [x] 取消`bot_config_template.toml` - - [x] 取消`model_config_template.toml` -- [x] 配置类中的所有原子项目应该只包含以下类型: `str`, `int`, `float`, `bool`, `list`, `dict`, `set` - - [ ] 暂时禁止使用 `Union` 类型(尚未支持解析) - - [ ] 不建议使用`tuple`类型,使用时会发出警告,考虑使用嵌套`dataclass`替代 - - [x] 复杂类型使用嵌套配置类实现 +主体利用`pydantic`的`BaseModel`进行配置类设计`ConfigBase`类 +要求每个属性必须具有类型注解,且类型注解满足以下要求: +- 原子类型仅允许使用: `str`, `int`, `float`, `bool`, 以及基于`ConfigBase`的嵌套配置类 +- 复杂类型允许使用: `list`, `dict`, `set`,但其内部类型必须为原子类型或嵌套配置类,不可使用`list[list[int]]`,`list[dict[str, int]]`等写法 +- 禁止了使用`Union`, `tuple/Tuple`类型 + - 但是`Optional`仍然允许使用 ### 移除template的方案提案
配置项说明的废案 diff --git a/pytests/config_test/test_config_base.py b/pytests/config_test/test_config_base.py index e210a82f..fea2fe8e 100644 --- a/pytests/config_test/test_config_base.py +++ b/pytests/config_test/test_config_base.py @@ -46,6 +46,10 @@ def patch_attrdoc_post_init(): config_base_module.logger = logging.getLogger("config_base_test_logger") +class SimpleClass(ConfigBase): + a: int = 1 + b: str = "test" + class TestConfigBase: # --------------------------------------------------------- @@ -267,6 +271,18 @@ class TestConfigBase: "不允许嵌套泛型类型", id="listset-validation-nested-generic-inner-list", ), + pytest.param( + List[SimpleClass], + False, + None, + id="listset-validation-list-configbase-element_allow", + ), + pytest.param( + Set[SimpleClass], + True, + "ConfigBase is not Hashable", + id="listset-validation-set-configbase-element_reject", + ) ], ) def test_validate_list_set_type(self, annotation, expect_error, error_fragment): @@ -319,6 +335,12 @@ class TestConfigBase: "必须指定键和值的类型参数", id="dict-validation-missing-args", ), + pytest.param( + Dict[str, SimpleClass], + False, + None, + id="dict-validation-happy-configbase-value", + ) ], ) def test_validate_dict_type(self, annotation, expect_error, error_fragment): @@ -370,6 +392,22 @@ class TestConfigBase: # Assert assert "字段'field_y'中使用了 Any 类型注解" in caplog.text + + def test_discourage_any_usage_suppressed_warning(self, caplog): + class Sample(ConfigBase): + _validate_any: bool = False + suppress_any_warning: bool = True + + instance = Sample() + + # Arrange + caplog.set_level(logging.WARNING, logger="config_base_test_logger") + + # Act + instance._discourage_any_usage("field_z") + + # Assert + assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text # --------------------------------------------------------- # model_post_init 规则覆盖(错误与边界情况) diff --git a/src/common/remote.py b/src/common/remote.py index 5380cd01..098c13d1 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -5,7 +5,7 @@ import platform from src.common.logger import get_logger from src.common.tcp_connector import get_tcp_connector -from src.config.config import global_config +from src.config.config import global_config, MMC_VERSION from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage @@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask): info_dict = { "os_type": "Unknown", "py_version": platform.python_version(), - "mmc_version": global_config.MMC_VERSION, + "mmc_version": MMC_VERSION, } match platform.system(): diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py deleted file mode 100644 index f450e91f..00000000 --- a/src/config/api_ada_configs.py +++ /dev/null @@ -1,139 +0,0 @@ -from dataclasses import dataclass, field - -from .config_base import ConfigBase - - -@dataclass -class APIProvider(ConfigBase): - """API提供商配置类""" - - name: str - """API提供商名称""" - - base_url: str - """API基础URL""" - - api_key: str = field(default_factory=str, repr=False) - """API密钥列表""" - - client_type: str = field(default="openai") - """客户端类型(如openai/google等,默认为openai)""" - - max_retry: int = 2 - """最大重试次数(单个模型API调用失败,最多重试的次数)""" - - timeout: int = 10 - """API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)""" - - retry_interval: int = 10 - """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)""" - - def get_api_key(self) -> str: - return self.api_key - - def __post_init__(self): - """确保api_key在repr中不被显示""" - if not self.api_key: - raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") - if not self.base_url and self.client_type != "gemini": - raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") - if not self.name: - raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。") - - -@dataclass -class ModelInfo(ConfigBase): - """单个模型信息配置类""" - - model_identifier: str - """模型标识符(用于URL调用)""" - - name: str - """模型名称(用于模块调用)""" - - api_provider: str - """API提供商(如OpenAI、Azure等)""" - - price_in: float = field(default=0.0) - """每M token输入价格""" - - price_out: float = field(default=0.0) - """每M token输出价格""" - - temperature: float | None = field(default=None) - """模型级别温度(可选),会覆盖任务配置中的温度""" - - max_tokens: int | None = field(default=None) - """模型级别最大token数(可选),会覆盖任务配置中的max_tokens""" - - force_stream_mode: bool = field(default=False) - """是否强制使用流式输出模式""" - - extra_params: dict = field(default_factory=dict) - """额外参数(用于API调用时的额外配置)""" - - def __post_init__(self): - if not self.model_identifier: - raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。") - if not self.name: - raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。") - if not self.api_provider: - raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。") - - -@dataclass -class TaskConfig(ConfigBase): - """任务配置类""" - - model_list: list[str] = field(default_factory=list) - """任务使用的模型列表""" - - max_tokens: int = 1024 - """任务最大输出token数""" - - temperature: float = 0.3 - """模型温度""" - - slow_threshold: float = 15.0 - """慢请求阈值(秒),超过此值会输出警告日志""" - - selection_strategy: str = field(default="balance") - """模型选择策略:balance(负载均衡)或 random(随机选择)""" - - -@dataclass -class ModelTaskConfig(ConfigBase): - """模型配置类""" - - utils: TaskConfig - """组件模型配置""" - - replyer: TaskConfig - """normal_chat首要回复模型模型配置""" - - vlm: TaskConfig - """视觉语言模型配置""" - - voice: TaskConfig - """语音识别模型配置""" - - tool_use: TaskConfig - """专注工具使用模型配置""" - - planner: TaskConfig - """规划模型配置""" - - embedding: TaskConfig - """嵌入模型配置""" - - lpmm_entity_extract: TaskConfig - """LPMM实体提取模型配置""" - - lpmm_rdf_build: TaskConfig - """LPMM RDF构建模型配置""" - - def get_task(self, task_name: str) -> TaskConfig: - """获取指定任务的配置""" - if hasattr(self, task_name): - return getattr(self, task_name) - raise ValueError(f"任务 '{task_name}' 未找到对应的配置") diff --git a/src/config/config.py b/src/config/config.py index 0a3ba48c..a6685bc0 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,19 +1,12 @@ -import os +from pathlib import Path +from typing import TypeVar +from datetime import datetime +from typing import Any + import tomlkit -import shutil import sys -from datetime import datetime -from tomlkit import TOMLDocument -from tomlkit.items import Table, KeyType -from dataclasses import field, dataclass -from rich.traceback import install -from typing import List, Optional - -from src.common.logger import get_logger -from src.common.toml_utils import format_toml_string -from src.config.config_base import ConfigBase -from src.config.official_configs import ( +from .official_configs import ( BotConfig, PersonalityConfig, ExpressionConfig, @@ -34,345 +27,112 @@ from src.config.official_configs import ( MemoryConfig, DebugConfig, DreamConfig, - WebUIConfig, ) +from .model_configs import ModelInfo, ModelTaskConfig, APIProvider +from .config_base import ConfigBase, Field, AttributeData +from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions -from .api_ada_configs import ( - ModelTaskConfig, - ModelInfo, - APIProvider, -) +from src.common.logger import get_logger +""" +如果你想要修改配置文件,请递增version的值 -install(extra_lines=3) +版本格式:主版本号.次版本号.修订号,版本号递增规则如下: + 主版本号:MMC版本更新 + 次版本号:配置文件内容大更新 + 修订号:配置文件内容小更新 +""" +PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve() +CONFIG_DIR: Path = PROJECT_ROOT / "config" +BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() +MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() +MMC_VERSION: str = "0.13.0" +CONFIG_VERSION: str = "8.0.0" +MODEL_CONFIG_VERSION: str = "1.12.0" -# 配置主程序日志格式 logger = get_logger("config") -# 获取当前文件所在目录的父目录的父目录(即MaiBot项目根目录) -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -CONFIG_DIR = os.path.join(PROJECT_ROOT, "config") -TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") - -# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 -# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.13.0-snapshot.1" +T = TypeVar("T", bound="ConfigBase") -def get_key_comment(toml_table, key): - # 获取key的注释(如果有) - if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): - return toml_table.trivia.comment - if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): - item = toml_table.value.get(key) - if item is not None and hasattr(item, "trivia"): - return item.trivia.comment - if hasattr(toml_table, "keys"): - for k in toml_table.keys(): - if isinstance(k, KeyType) and k.key == key: # type: ignore - return k.trivia.comment # type: ignore - return None - - -def compare_dicts(new, old, path=None, logs=None): - # 递归比较两个dict,找出新增和删减项,收集注释 - if path is None: - path = [] - if logs is None: - logs = [] - # 新增项 - for key in new: - if key == "version": - continue - if key not in old: - comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path + [str(key)], logs) - # 删减项 - for key in old: - if key == "version": - continue - if key not in new: - comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - return logs - - -def get_value_by_path(d, path): - for k in path: - if isinstance(d, dict) and k in d: - d = d[k] - else: - return None - return d - - -def set_value_by_path(d, path, value): - """设置嵌套字典中指定路径的值""" - for k in path[:-1]: - if k not in d or not isinstance(d[k], dict): - d[k] = {} - d = d[k] - - # 使用 tomlkit.item 来保持 TOML 格式 - try: - d[path[-1]] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - d[path[-1]] = value - - -def compare_default_values(new, old, path=None, logs=None, changes=None): - # 递归比较两个dict,找出默认值变化项 - if path is None: - path = [] - if logs is None: - logs = [] - if changes is None: - changes = [] - for key in new: - if key == "version": - continue - if key in old: - if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): - compare_default_values(new[key], old[key], path + [str(key)], logs, changes) - elif new[key] != old[key]: - logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") - changes.append((path + [str(key)], old[key], new[key])) - return logs, changes - - -def _get_version_from_toml(toml_path) -> Optional[str]: - """从TOML文件中获取版本号""" - if not os.path.exists(toml_path): - return None - with open(toml_path, "r", encoding="utf-8") as f: - doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: # type: ignore - return doc["inner"]["version"] # type: ignore - return None - - -def _version_tuple(v): - """将版本字符串转换为元组以便比较""" - if v is None: - return (0,) - return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) - - -def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, (dict, Table)): - _update_dict(target_value, value) - else: - try: - # 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - -def _update_config_generic(config_name: str, template_name: str): - """ - 通用的配置文件更新函数 - - Args: - config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' - template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' - """ - # 获取根目录路径 - old_config_dir = os.path.join(CONFIG_DIR, "old") - compare_dir = os.path.join(TEMPLATE_DIR, "compare") - - # 定义文件路径 - template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml") - old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") - new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") - compare_path = os.path.join(compare_dir, f"{template_name}.toml") - - # 创建compare目录(如果不存在) - os.makedirs(compare_dir, exist_ok=True) - - template_version = _get_version_from_toml(template_path) - compare_version = _get_version_from_toml(compare_path) - - # 检查配置文件是否存在 - if not os.path.exists(old_config_path): - logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") - os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 - shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") - # 新创建配置文件,退出 - sys.exit(0) - - compare_config = None - new_config = None - old_config = None - - # 先读取 compare 下的模板(如果有),用于默认值变动检测 - if os.path.exists(compare_path): - with open(compare_path, "r", encoding="utf-8") as f: - compare_config = tomlkit.load(f) - - # 读取当前模板 - with open(template_path, "r", encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查默认值变化并处理(只有 compare_config 存在时才做) - if compare_config: - # 读取旧配置 - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - logs, changes = compare_default_values(new_config, compare_config) - if logs: - logger.info(f"检测到{config_name}模板默认值变动如下:") - for log in logs: - logger.info(log) - # 检查旧配置是否等于旧默认值,如果是则更新为新默认值 - config_updated = False - for path, old_default, new_default in changes: - old_value = get_value_by_path(old_config, path) - if old_value == old_default: - set_value_by_path(old_config, path, new_default) - logger.info( - f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" - ) - config_updated = True - - # 如果配置有更新,立即保存到文件 - if config_updated: - with open(old_config_path, "w", encoding="utf-8") as f: - f.write(format_toml_string(old_config)) - logger.info(f"已保存更新后的{config_name}配置文件") - else: - logger.info(f"未检测到{config_name}模板默认值变动") - - # 检查 compare 下没有模板,或新模板版本更高,则复制 - if not os.path.exists(compare_path): - shutil.copy2(template_path, compare_path) - logger.info(f"已将{config_name}模板文件复制到: {compare_path}") - elif _version_tuple(template_version) > _version_tuple(compare_version): - shutil.copy2(template_path, compare_path) - logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") - else: - logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") - - # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) - if old_config is None: - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - # new_config 已经读取 - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") # type: ignore - new_version = new_config["inner"].get("version") # type: ignore - if old_version and new_version and old_version == new_version: - logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新") - return - else: - logger.info( - f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" - ) - else: - logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新") - - # 创建old目录(如果不存在) - os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml") - - # 移动旧配置文件到old目录 - shutil.move(old_config_path, old_backup_path) - logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}") - - # 复制模板文件到配置目录 - shutil.copy2(template_path, new_config_path) - logger.info(f"已创建新{config_name}配置文件: {new_config_path}") - - # 输出新增和删减项及注释 - if old_config: - logger.info(f"{config_name}配置项变动如下:\n----------------------------------------") - if logs := compare_dicts(new_config, old_config): - for log in logs: - logger.info(log) - else: - logger.info("无新增或删减项") - - # 将旧配置的值更新到新配置中 - logger.info(f"开始合并{config_name}新旧配置...") - _update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式,数组多行格式化) - with open(new_config_path, "w", encoding="utf-8") as f: - f.write(format_toml_string(new_config)) - logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") - - -def update_config(): - """更新bot_config.toml配置文件""" - _update_config_generic("bot_config", "bot_config_template") - - -def update_model_config(): - """更新model_config.toml配置文件""" - _update_config_generic("model_config", "model_config_template") - - -@dataclass class Config(ConfigBase): """总配置类""" - MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息 + bot: BotConfig = Field(default_factory=BotConfig) + """机器人配置类""" - bot: BotConfig - personality: PersonalityConfig - relationship: RelationshipConfig - chat: ChatConfig - message_receive: MessageReceiveConfig - emoji: EmojiConfig - expression: ExpressionConfig - keyword_reaction: KeywordReactionConfig - chinese_typo: ChineseTypoConfig - response_post_process: ResponsePostProcessConfig - response_splitter: ResponseSplitterConfig - telemetry: TelemetryConfig - webui: WebUIConfig - experimental: ExperimentalConfig - maim_message: MaimMessageConfig - lpmm_knowledge: LPMMKnowledgeConfig - tool: ToolConfig - memory: MemoryConfig - debug: DebugConfig - voice: VoiceConfig - dream: DreamConfig + personality: PersonalityConfig = Field(default_factory=PersonalityConfig) + """人格配置类""" + + expression: ExpressionConfig = Field(default_factory=ExpressionConfig) + """表达配置类""" + + chat: ChatConfig = Field(default_factory=ChatConfig) + """聊天配置类""" + + memory: MemoryConfig = Field(default_factory=MemoryConfig) + """记忆配置类""" + + relationship: RelationshipConfig = Field(default_factory=RelationshipConfig) + """关系配置类""" + + message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig) + """消息接收配置类""" + + dream: DreamConfig = Field(default_factory=DreamConfig) + """做梦配置类""" + + tool: ToolConfig = Field(default_factory=ToolConfig) + """工具配置类""" + + voice: VoiceConfig = Field(default_factory=VoiceConfig) + """语音配置类""" + + emoji: EmojiConfig = Field(default_factory=EmojiConfig) + """表情包配置类""" + + keyword_reaction: KeywordReactionConfig = Field(default_factory=KeywordReactionConfig) + """关键词反应配置类""" + + response_post_process: ResponsePostProcessConfig = Field(default_factory=ResponsePostProcessConfig) + """回复后处理配置类""" + + chinese_typo: ChineseTypoConfig = Field(default_factory=ChineseTypoConfig) + """中文错别字生成器配置类""" + + response_splitter: ResponseSplitterConfig = Field(default_factory=ResponseSplitterConfig) + """回复分割器配置类""" + + telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig) + """遥测配置类""" + + debug: DebugConfig = Field(default_factory=DebugConfig) + """调试配置类""" + + experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig) + """实验性功能配置类""" + + maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig) + """maim_message配置类""" + + lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig) + """LPMM知识库配置类""" -@dataclass -class APIAdapterConfig(ConfigBase): - """API Adapter配置类""" +class ModelConfig(ConfigBase): + """模型配置类""" - models: List[ModelInfo] - """模型列表""" + models: list[ModelInfo] = Field(default_factory=list) + """模型配置列表""" - model_task_config: ModelTaskConfig + model_task_config: ModelTaskConfig = Field(default_factory=ModelTaskConfig) """模型任务配置""" - api_providers: List[APIProvider] = field(default_factory=list) + api_providers: list[APIProvider] = Field(default_factory=list) """API提供商列表""" - def __post_init__(self): + def model_post_init(self, context: Any = None): if not self.models: raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。") if not self.api_providers: @@ -388,78 +148,134 @@ class APIAdapterConfig(ConfigBase): if len(model_names) != len(set(model_names)): raise ValueError("模型名称存在重复,请检查配置文件。") - self.api_providers_dict = {provider.name: provider for provider in self.api_providers} - self.models_dict = {model.name: model for model in self.models} + api_providers_dict = {provider.name: provider for provider in self.api_providers} for model in self.models: if not model.model_identifier: raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空") - if not model.api_provider or model.api_provider not in self.api_providers_dict: + if not model.api_provider or model.api_provider not in api_providers_dict: raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在") - - def get_model_info(self, model_name: str) -> ModelInfo: - """根据模型名称获取模型信息""" - if not model_name: - raise ValueError("模型名称不能为空") - if model_name not in self.models_dict: - raise KeyError(f"模型 '{model_name}' 不存在") - return self.models_dict[model_name] - - def get_provider(self, provider_name: str) -> APIProvider: - """根据提供商名称获取API提供商信息""" - if not provider_name: - raise ValueError("API提供商名称不能为空") - if provider_name not in self.api_providers_dict: - raise KeyError(f"API提供商 '{provider_name}' 不存在") - return self.api_providers_dict[provider_name] + return super().model_post_init(context) -def load_config(config_path: str) -> Config: +class ConfigManager: + """总配置管理类""" + + def __init__(self): + self.bot_config_path: Path = BOT_CONFIG_PATH + self.model_config_path: Path = MODEL_CONFIG_PATH + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + def initialize(self): + logger.info(f"MaiCore当前版本: {MMC_VERSION}") + logger.info("正在品鉴配置文件...") + self.global_config: Config = self._load_global_config() + self.model_config: ModelConfig = self._load_model_config() + logger.info("非常的新鲜,非常的美味!") + + def _load_global_config(self) -> Config: + config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION) + if updated: + sys.exit(0) # 先直接退出 + return config + + def _load_model_config(self) -> ModelConfig: + config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True) + if updated: + sys.exit(0) # 先直接退出 + return config + + def get_global_config(self) -> Config: + return self.global_config + + def get_model_config(self) -> ModelConfig: + return self.model_config + + +def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None: + """生成新的配置文件 + + :param config_class: 配置类 + :param config_path: 配置文件路径 + :param inner_config_version: 配置文件版本号 """ - 加载配置文件 - Args: - config_path: 配置文件路径 - Returns: - Config对象 - """ - # 读取配置文件 + config = config_class() + write_config_to_file(config, config_path, inner_config_version) + + +def load_config_from_file( + config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False +) -> tuple[T, bool]: + attribute_data = AttributeData() with open(config_path, "r", encoding="utf-8") as f: config_data = tomlkit.load(f) - - # 创建Config对象 + old_ver: str = config_data["inner"]["version"] # type: ignore + config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理 + config_data = config_data.unwrap() # 转换为普通字典,方便后续处理 try: - return Config.from_dict(config_data) + updated: bool = False + target_config = config_class.from_dict(attribute_data, config_data) + if compare_versions(old_ver, new_ver): + output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name) + write_config_to_file(target_config, config_path, new_ver, override_repr) + updated = True + return target_config, updated except Exception as e: - logger.critical("配置文件解析失败") + logger.critical(f"配置文件{config_path.name}解析失败") raise e -def api_ada_load_config(config_path: str) -> APIAdapterConfig: +def write_config_to_file( + config: ConfigBase, config_path: Path, inner_config_version: str, override_repr: bool = False +) -> None: + """将配置写入文件 + + :param config: 配置对象 + :param config_path: 配置文件路径 """ - 加载API适配器配置文件 - Args: - config_path: 配置文件路径 - Returns: - APIAdapterConfig对象 - """ - # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: - config_data = tomlkit.load(f) + # 创建空TOMLDocument + full_config_data = tomlkit.document() - # 创建APIAdapterConfig对象 - try: - return APIAdapterConfig.from_dict(config_data) - except Exception as e: - logger.critical("API适配器配置文件解析失败") - raise e + # 首先写入配置文件版本信息 + version_table = tomlkit.table() + version_table.add("version", inner_config_version) + full_config_data.add("inner", version_table) + + # 递归解析配置项为表格 + for config_item_name, config_item in type(config).model_fields.items(): + if not config_item.repr and not override_repr: + continue + if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]: + continue + config_field = getattr(config, config_item_name) + if isinstance(config_field, ConfigBase): + full_config_data.add( + config_item_name, recursive_parse_item_to_table(config_field, override_repr=override_repr) + ) + elif isinstance(config_field, list): + aot = tomlkit.aot() + for item in config_field: + if not isinstance(item, ConfigBase): + raise TypeError("配置写入只支持ConfigBase子类") + aot.append(recursive_parse_item_to_table(item, override_repr=override_repr)) + full_config_data.add(config_item_name, aot) + else: + raise TypeError("配置写入只支持ConfigBase子类") + + # 备份旧文件 + if config_path.exists(): + backup_root = config_path.parent / "old" + backup_root.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = backup_root / f"{config_path.stem}_{timestamp}.toml" + config_path.replace(backup_path) + + # 写入文件 + with open(config_path, "w", encoding="utf-8") as f: + tomlkit.dump(full_config_data, f) -# 获取配置文件路径 -logger.info(f"MaiCore当前版本: {MMC_VERSION}") -update_config() -update_model_config() - -logger.info("正在品鉴配置文件...") -global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) -model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml")) -logger.info("非常的新鲜,非常的美味!") +# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION) +config_manager = ConfigManager() +config_manager.initialize() +# global_config = config_manager.get_global_config() diff --git a/src/config/config_base.py b/src/config/config_base.py index 02332ad6..7baf1bd6 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -2,27 +2,35 @@ import ast import inspect import types +from dataclasses import dataclass, field from pathlib import Path from pydantic import BaseModel, ConfigDict, Field -from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set +from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal -__all__ = ["ConfigBase", "Field"] +__all__ = ["ConfigBase", "Field", "AttributeData"] from src.common.logger import get_logger logger = get_logger("ConfigBase") +@dataclass +class AttributeData: + missing_attributes: list[str] = field(default_factory=list) + """缺失的属性列表""" + redundant_attributes: list[str] = field(default_factory=list) + """多余的属性列表""" + + class AttrDocBase: """解析字段说明的基类""" field_docs: dict[str, str] = {} - def __post_init__(self): - self.field_docs = self._get_field_docs() # 全局仅获取一次并保留 + def __post_init__(self, allow_extra_methods: bool = False): + self.field_docs = self._get_field_docs(allow_extra_methods) # 全局仅获取一次并保留 - @classmethod - def _get_field_docs(cls) -> dict[str, str]: + def _get_field_docs(self, allow_extra_methods: bool) -> dict[str, str]: """ 获取字段的说明字符串 @@ -30,11 +38,11 @@ class AttrDocBase: :return: 字段说明字典,键为字段名,值为说明字符串 """ # 获取类的源代码文本 - class_source = cls._get_class_source() + class_source = self._get_class_source() # 解析源代码,找到对应的类定义节点 - class_node = cls._find_class_node(class_source) + class_node = self._find_class_node(class_source) # 从类定义节点中提取字段文档 - return cls._extract_field_docs(class_node) + return self._extract_field_docs(class_node, allow_extra_methods) @classmethod def _get_class_source(cls) -> str: @@ -57,21 +65,22 @@ class AttrDocBase: # 如果没有找到匹配的类定义,抛出异常 raise AttributeError(f"Class {cls.__name__} not found in source.") - @classmethod - def _extract_field_docs(cls, class_node: ast.ClassDef) -> dict[str, str]: + def _extract_field_docs(self, class_node: ast.ClassDef, allow_extra_methods: bool) -> dict[str, str]: """从类的 AST 节点中提取字段的文档字符串""" + # sourcery skip: merge-nested-ifs doc_dict: dict[str, str] = {} class_body = class_node.body # 类属性节点列表 for i in range(len(class_body)): body_item = class_body[i] - # 检查是否有非 model_post_init 的方法定义,如果有则抛出异常 - # 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法 - if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init": - """检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义""" - raise AttributeError( - f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}" - ) from None + if not allow_extra_methods: + # 检查是否有非 model_post_init 的方法定义,如果有则抛出异常 + # 这个限制确保 AttrDocBase 子类只包含字段定义和 model_post_init 方法 + if isinstance(body_item, ast.FunctionDef) and body_item.name != "model_post_init": + """检验ConfigBase子类中是否有除model_post_init以外的方法,规范配置类的定义""" + raise AttributeError( + f"Methods are not allowed in AttrDocBase subclasses except model_post_init, found {str(body_item.name)}" + ) from None # 检查当前语句是否为带注解的赋值语句 (类型注解的字段定义) # 并且下一个语句存在 @@ -110,13 +119,53 @@ class AttrDocBase: class ConfigBase(BaseModel, AttrDocBase): model_config = ConfigDict(validate_assignment=True, extra="forbid") _validate_any: bool = True # 是否验证 Any 类型的使用,默认为 True + suppress_any_warning: bool = False # 是否抑制 Any 类型使用的警告,默认为 False,仅仅在_validate_any 为 False 时生效 + + @classmethod + def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]): + """从字典创建配置对象,并收集缺失和多余的属性信息""" + class_fields = set(cls.model_fields.keys()) + class_fields.remove("field_docs") # 忽略 field_docs 字段 + if "_validate_any" in class_fields: + class_fields.remove("_validate_any") # 忽略 _validate_any 字段 + if "suppress_any_warning" in class_fields: + class_fields.remove("suppress_any_warning") # 忽略 suppress_any_warning 字 + for class_field in class_fields: + if class_field not in data: + attribute_data.missing_attributes.append(class_field) # 记录缺失的属性 + cleaned_data_list: list[str] = [] + for data_field in data: + if data_field not in class_fields: + cleaned_data_list.append(data_field) + attribute_data.redundant_attributes.append(data_field) # 记录多余的属性 + for redundant_field in cleaned_data_list: + data.pop(redundant_field) # 移除多余的属性 + # 对于是ConfigBase子类的字段,递归调用from_dict + class_field_infos = dict(cls.model_fields.items()) + for field_data in data: + if info := class_field_infos.get(field_data): + field_type = info.annotation + if inspect.isclass(field_type) and issubclass(field_type, ConfigBase): + data[field_data] = field_type.from_dict(attribute_data, data[field_data]) + if get_origin(field_type) in {list, List}: + elem_type = get_args(field_type)[0] + if inspect.isclass(elem_type) and issubclass(elem_type, ConfigBase): + data[field_data] = [elem_type.from_dict(attribute_data, item) for item in data[field_data]] + # 没有set,因为ConfigBase is not Hashable + if get_origin(field_type) in {dict, Dict}: + val_type = get_args(field_type)[1] + if inspect.isclass(val_type) and issubclass(val_type, ConfigBase): + data[field_data] = { + key: val_type.from_dict(attribute_data, val) for key, val in data[field_data].items() + } + return cls(**data) def _discourage_any_usage(self, field_name: str) -> None: """警告使用 Any 类型的字段(可被suppress)""" if self._validate_any: raise TypeError(f"字段'{field_name}'中不允许使用 Any 类型注解") - else: - logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议避免使用。") + if not self.suppress_any_warning: + logger.warning(f"字段'{field_name}'中使用了 Any 类型注解,建议使用更具体的类型注解以提高类型安全性") def _get_real_type(self, annotation: type[Any] | Any | None): """获取真实类型,处理 dict 等没有参数的情况""" @@ -157,10 +206,14 @@ class ConfigBase(BaseModel, AttrDocBase): elem = args[0] if elem is Any: self._discourage_any_usage(field_name) - if get_origin(elem) is not None: + elif get_origin(elem) is not None: raise TypeError( f"类'{type(self).__name__}'字段'{field_name}'中不允许嵌套泛型类型: {annotation},请使用自定义类代替。" ) + elif inspect.isclass(elem) and issubclass(elem, ConfigBase) and origin in (set, Set): + raise TypeError( + f"类'{type(self).__name__}'字段'{field_name}'中不允许使用 ConfigBase 子类作为 set 元素类型: {annotation}。ConfigBase is not Hashable。" + ) def _validate_dict_type(self, annotation: Any | None, field_name: str): """验证 dict 类型的使用""" @@ -215,7 +268,7 @@ class ConfigBase(BaseModel, AttrDocBase): if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore continue # 只允许 list, set, dict 三类泛型 - if origin_type not in (list, set, dict, List, Set, Dict): + if origin_type not in (list, set, dict, List, Set, Dict, Literal): raise TypeError( f"仅允许使用list, set, dict三种泛型类型注解,类'{type(self).__name__}'字段'{field_name}'中使用了: {annotation}" ) diff --git a/src/config/config_utils.py b/src/config/config_utils.py new file mode 100644 index 00000000..c18bccb1 --- /dev/null +++ b/src/config/config_utils.py @@ -0,0 +1,127 @@ +from pydantic.fields import FieldInfo +from typing import Any, get_args, get_origin, TYPE_CHECKING, Literal, List, Set, Tuple, Dict +from tomlkit import items +import tomlkit + +from .config_base import ConfigBase + +if TYPE_CHECKING: + from .config import AttributeData + + +def recursive_parse_item_to_table( + config: ConfigBase, is_inline_table: bool = False, override_repr: bool = False +) -> items.Table | items.InlineTable: + # sourcery skip: merge-else-if-into-elif, reintroduce-else + """递归解析配置项为表格""" + config_table = tomlkit.table() + if is_inline_table: + config_table = tomlkit.inline_table() + for config_item_name, config_item_info in type(config).model_fields.items(): + if not config_item_info.repr and not override_repr: + continue + value = getattr(config, config_item_name) + if config_item_name in ["field_docs", "_validate_any", "suppress_any_warning"]: + continue + if value is None: + continue + if isinstance(value, ConfigBase): + config_table.add(config_item_name, recursive_parse_item_to_table(value, override_repr=override_repr)) + else: + config_table.add( + config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr) + ) + if not is_inline_table: + config_table = comment_doc_string(config, config_item_name, config_table) + return config_table + + +def comment_doc_string( + config: ConfigBase, field_name: str, toml_table: items.Table | items.InlineTable +) -> items.Table | items.InlineTable: + """将配置类中的注释加入toml表格中""" + if doc_string := config.field_docs.get(field_name, ""): + doc_string_splitted = doc_string.splitlines() + if len(doc_string_splitted) == 1 and not doc_string_splitted[0].strip().startswith("_wrap_"): + if isinstance(toml_table[field_name], bool): + # tomlkit 故意设计的行为,布尔值不能直接添加注释 + value = toml_table[field_name] + item = tomlkit.item(value) + item.comment(doc_string_splitted[0]) + toml_table[field_name] = item + else: + toml_table[field_name].comment(doc_string_splitted[0]) + else: + if doc_string_splitted[0].strip().startswith("_wrap_"): + doc_string_splitted[0] = doc_string_splitted[0].replace("_wrap_", "", 1).strip() + for line in doc_string_splitted: + toml_table.add(tomlkit.comment(line)) + toml_table.add(tomlkit.nl()) + return toml_table + + +def convert_field(config_item_name: str, config_item_info: FieldInfo, value: Any, override_repr: bool = False): + # sourcery skip: extract-method + """将非可直接表达类转换为toml可表达类""" + field_type_origin = get_origin(config_item_info.annotation) + field_type_args = get_args(config_item_info.annotation) + + if not field_type_origin: # 基础类型int,bool等直接添加 + return value + elif field_type_origin in {list, set, List, Set}: + toml_list = tomlkit.array() + if field_type_args and isinstance(field_type_args[0], type) and issubclass(field_type_args[0], ConfigBase): + for item in value: + toml_list.append(recursive_parse_item_to_table(item, True, override_repr)) + else: + for item in value: + toml_list.append(item) + return toml_list + elif field_type_origin in (tuple, Tuple): + toml_list = tomlkit.array() + for field_arg, item in zip(field_type_args, value, strict=True): + if isinstance(field_arg, type) and issubclass(field_arg, ConfigBase): + toml_list.append(recursive_parse_item_to_table(item, True, override_repr)) + else: + toml_list.append(item) + return toml_list + elif field_type_origin in (dict, Dict): + if len(field_type_args) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {config_item_name}") + toml_sub_table = tomlkit.inline_table() + key_type, value_type = field_type_args + if key_type is not str: + raise TypeError(f"TOML only supports string keys for tables, got {key_type} for {config_item_name}") + for k, v in value.items(): + if isinstance(value_type, type) and issubclass(value_type, ConfigBase): + toml_sub_table.add(k, recursive_parse_item_to_table(v, True, override_repr)) + else: + toml_sub_table.add(k, v) + return toml_sub_table + elif field_type_origin is Literal: + if value not in field_type_args: + raise ValueError(f"Value {value} not in Literal options {field_type_args} for {config_item_name}") + return value + else: + raise TypeError(f"Unsupported field type for {config_item_name}: {config_item_info.annotation}") + + +def output_config_changes(attr_data: "AttributeData", logger, old_ver: str, new_ver: str, file_name: str): + """输出配置变更信息""" + logger.info("-------- 配置文件变更信息 --------") + logger.info(f"新增配置数量: {len(attr_data.missing_attributes)}") + for attr in attr_data.missing_attributes: + logger.info(f"配置文件中新增配置项: {attr}") + logger.info(f"移除配置数量: {len(attr_data.redundant_attributes)}") + for attr in attr_data.redundant_attributes: + logger.warning(f"移除配置项: {attr}") + logger.info( + f"{file_name}配置文件已经更新. Old: {old_ver} -> New: {new_ver} 建议检查新配置文件中的内容, 以免丢失重要信息" + ) + + +def compare_versions(old_ver: str, new_ver: str) -> bool: + """比较版本号,返回是否有更新""" + old_parts = [int(part) for part in old_ver.split(".")] + new_parts = [int(part) for part in new_ver.split(".")] + return new_parts > old_parts diff --git a/src/config/model_configs.py b/src/config/model_configs.py new file mode 100644 index 00000000..374aef59 --- /dev/null +++ b/src/config/model_configs.py @@ -0,0 +1,129 @@ +from typing import Any +from .config_base import ConfigBase, Field + + +class APIProvider(ConfigBase): + """API提供商配置类""" + + name: str = "" + """API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)""" + + base_url: str = "" + """API服务商的BaseURL""" + + api_key: str = Field(default_factory=str, repr=False) + """API密钥""" + + client_type: str = Field(default="openai") + """客户端类型 (可选: openai/google, 默认为openai)""" + + max_retry: int = Field(default=2) + """最大重试次数 (单个模型API调用失败, 最多重试的次数)""" + + timeout: int = 10 + """API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)""" + + retry_interval: int = 10 + """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" + + def model_post_init(self, context: Any = None): + """确保api_key在repr中不被显示""" + if not self.api_key: + raise ValueError("API密钥不能为空, 请在配置中设置有效的API密钥。") + if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url + raise ValueError("API基础URL不能为空, 请在配置中设置有效的基础URL。") + if not self.name: + raise ValueError("API提供商名称不能为空, 请在配置中设置有效的名称。") + return super().model_post_init(context) + + +class ModelInfo(ConfigBase): + """单个模型信息配置类""" + _validate_any: bool = False + suppress_any_warning: bool = True + + model_identifier: str = "" + """模型标识符 (API服务商提供的模型标识符)""" + + name: str = "" + """模型名称 (可随意命名, 在models中需使用这个命名)""" + + api_provider: str = "" + """API服务商名称 (对应在api_providers中配置的服务商名称)""" + + price_in: float = Field(default=0.0) + """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" + + price_out: float = Field(default=0.0) + """输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" + + temperature: float | None = Field(default=None) + """模型级别温度(可选),会覆盖任务配置中的温度""" + + max_tokens: int | None = Field(default=None) + """模型级别最大token数(可选),会覆盖任务配置中的max_tokens""" + + force_stream_mode: bool = Field(default=False) + """强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)""" + + extra_params: dict[str, Any] = Field(default_factory=dict) + """额外参数 (用于API调用时的额外配置)""" + + def model_post_init(self, context: Any = None): + if not self.model_identifier: + raise ValueError("模型标识符不能为空, 请在配置中设置有效的模型标识符。") + if not self.name: + raise ValueError("模型名称不能为空, 请在配置中设置有效的模型名称。") + if not self.api_provider: + raise ValueError("API提供商不能为空, 请在配置中设置有效的API提供商。") + return super().model_post_init(context) + + +class TaskConfig(ConfigBase): + """任务配置类""" + + model_list: list[str] = Field(default_factory=list) + """使用的模型列表, 每个元素对应上面的模型名称(name)""" + + max_tokens: int = 1024 + """任务最大输出token数""" + + temperature: float = 0.3 + """模型温度""" + + slow_threshold: float = 15.0 + """慢请求阈值(秒),超过此值会输出警告日志""" + + selection_strategy: str = Field(default="balance") + """模型选择策略:balance(负载均衡)或 random(随机选择)""" + + +class ModelTaskConfig(ConfigBase): + """模型配置类""" + + utils: TaskConfig = Field(default_factory=TaskConfig) + """组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型""" + + replyer: TaskConfig = Field(default_factory=TaskConfig) + """首要回复模型配置, 还用于表达器和表达方式学习""" + + vlm: TaskConfig = Field(default_factory=TaskConfig) + """视觉模型配置""" + + voice: TaskConfig = Field(default_factory=TaskConfig) + """语音识别模型配置""" + + tool_use: TaskConfig = Field(default_factory=TaskConfig) + """工具使用模型配置, 需要使用支持工具调用的模型""" + + planner: TaskConfig = Field(default_factory=TaskConfig) + """规划模型配置""" + + embedding: TaskConfig = Field(default_factory=TaskConfig) + """嵌入模型配置""" + + lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig) + """LPMM实体提取模型配置""" + + lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig) + """LPMM RDF构建模型配置""" diff --git a/src/config/official_configs.py b/src/config/official_configs.py index a6652e0e..7291d188 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,117 +1,106 @@ +from .config_base import ConfigBase, Field import re - -from dataclasses import dataclass, field -from typing import Literal, Optional -import time - -from src.config.config_base import ConfigBase +from typing import Optional, Literal """ 须知: 1. 本文件中记录了所有的配置项 2. 所有新增的class都需要继承自ConfigBase -3. 所有新增的class都应在config.py中的Config类中添加字段 -4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +3. 所有新增的class都应在official_configs.py中的Config类中添加字段 +4. 对于新增的字段,若为可选项,则应在其后添加Field()并设置default_factory或default +5. 所有的配置项都应该按照如下方法添加字段说明: +class ExampleConfig(ConfigBase): + example_field: str + \"""This is an example field\""" + - 注释前面增加_warp_标记可以实现配置文件中注释在配置项前面单独一行显示 """ -@dataclass class BotConfig(ConfigBase): - """QQ机器人配置类""" + """机器人配置类""" - platform: str - """平台""" - - qq_account: str - """QQ账号""" - - nickname: str - """昵称""" - - platforms: list[str] = field(default_factory=lambda: []) - """其他平台列表""" - - alias_names: list[str] = field(default_factory=lambda: []) + alias_names: list[str] = Field(default_factory=lambda: []) """别名列表""" -@dataclass class PersonalityConfig(ConfigBase): """人格配置类""" - personality: str - """人格""" + personality: str = "是一个大二在读女大学生,现在正在上网和群友聊天,有时有点攻击性,有时比较温柔" + """人格,建议100字以内,描述人格特质和身份特征""" - reply_style: str = "" - """默认表达风格""" + reply_style: str = "请不要刻意突出自身学科背景。可以参考贴吧,知乎和微博的回复风格。" + """默认表达风格,描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行""" - multiple_reply_style: list[str] = field(default_factory=lambda: []) + multiple_reply_style: list[str] = Field(default_factory=lambda: []) """可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style""" - multiple_probability: float = 0.0 + multiple_probability: float = 0.3 """每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)""" - plan_style: str = "" - """说话规则,行为风格""" + plan_style: str = """1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用 +2.如果相同的action已经被执行,请不要重复执行该action +3.如果有人对你感到厌烦,请减少回复 +4.如果有人在追问你,或者话题没有说完,请你继续回复 +5.请分析哪些对话是和你说的,哪些是其他人之间的互动,不要误认为其他人之间的互动是和你说的""" + """_wrap_麦麦的说话规则和行为规则""" - visual_style: str = "" - """图片提示词""" + visual_style: str = "请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本" + """_wrap_识图提示词,不建议修改""" - states: list[str] = field(default_factory=lambda: []) - """状态列表,用于随机替换personality""" + states: list[str] = Field( + default_factory=lambda: [ + "是一个女大学生,喜欢上网聊天,会刷小红书。", + "是一个大二心理学生,会刷贴吧和中国知网。", + "是一个赛博网友,最近很想吐槽人。", + ] + ) + """_wrap_状态列表,用于随机替换personality""" - state_probability: float = 0.0 + state_probability: float = 0.3 """状态概率,每次构建人格时替换personality的概率""" -@dataclass class RelationshipConfig(ConfigBase): """关系配置类""" enable_relationship: bool = True - """是否启用关系系统""" + """是否启用关系系统,关系系统被移除,此部分配置暂时无效""" + + +class TalkRulesItem(ConfigBase): + platform: str = "" + """平台,留空表示全局""" + + id: str = "" + """用户ID""" + + rule_type: Literal["group", "private"] = "group" + """聊天流类型,group(群聊)或private(私聊)""" + + time: str = "" + """时间段,格式为 "HH:MM-HH:MM",支持跨夜区间""" + + value: float = 0.5 + """聊天频率值,范围0-1""" -@dataclass class ChatConfig(ConfigBase): """聊天配置类""" - max_context_size: int = 18 - """上下文长度""" + talk_value: float = 1 + """聊天频率,越小越沉默,范围0-1""" mentioned_bot_reply: bool = True """是否启用提及必回复""" - at_bot_inevitable_reply: float = 1 - """@bot 必然回复,1为100%回复,0为不额外增幅""" + max_context_size: int = 30 + """上下文长度""" planner_smooth: float = 3 - """规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-5,0为关闭,必须大于等于0""" + """规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0""" - talk_value: float = 1 - """思考频率""" - - enable_talk_value_rules: bool = True - """是否启用动态发言频率规则""" - - talk_value_rules: list[dict] = field(default_factory=lambda: []) - """ - 思考频率规则列表,支持按聊天流/按日内时段配置。 - 规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 } - - 示例: - [ - ["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静 - ["", "09:00-22:59", 1.0], # 全局规则:白天正常 - ["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言 - ["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静 - ] - - 匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\"). - 时间区间支持跨夜,例如 "23:00-02:00"。 - """ - - think_mode: Literal["classic", "deep", "dynamic"] = "classic" + think_mode: Literal["classic", "deep", "dynamic"] = "dynamic" """ 思考模式配置 - classic: 默认think_level为0(轻量回复,不需要思考和回忆) @@ -125,162 +114,63 @@ class ChatConfig(ConfigBase): llm_quote: bool = False """是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息""" - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: - """与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。""" - try: - parts = stream_config_str.split(":") - if len(parts) != 3: - return None + enable_talk_value_rules: bool = True + """是否启用动态发言频率规则""" - platform = parts[0] - id_str = parts[1] - stream_type = parts[2] - - is_group = stream_type == "group" - - from src.chat.message_receive.chat_stream import get_chat_manager - - return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) - - except (ValueError, IndexError): - return None - - def _now_minutes(self) -> int: - """返回本地时间的分钟数(0-1439)。""" - lt = time.localtime() - return lt.tm_hour * 60 + lt.tm_min - - def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]: - """解析 "HH:MM-HH:MM" 到 (start_min, end_min)。""" - try: - start_str, end_str = [s.strip() for s in range_str.split("-")] - sh, sm = [int(x) for x in start_str.split(":")] - eh, em = [int(x) for x in end_str.split(":")] - return sh * 60 + sm, eh * 60 + em - except Exception: - return None - - def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool: - """ - 判断 now_min 是否在 [start_min, end_min] 区间内。 - 支持跨夜:如果 start > end,则表示跨越午夜。 - """ - if start_min <= end_min: - return start_min <= now_min <= end_min - # 跨夜:例如 23:00-02:00 - return now_min >= start_min or now_min <= end_min - - def get_talk_value(self, chat_id: Optional[str]) -> float: - """根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。""" - if not self.enable_talk_value_rules or not self.talk_value_rules: - result = self.talk_value - # 防止返回0值,自动转换为0.0001 - if result == 0: - return 0.0000001 - return result - - now_min = self._now_minutes() - - # 1) 先尝试匹配指定 chat 的规则 - if chat_id: - for rule in self.talk_value_rules: - if not isinstance(rule, dict): - continue - target = rule.get("target", "") - time_range = rule.get("time", "") - value = rule.get("value", None) - if not isinstance(time_range, str): - continue - # 跳过全局 - if target == "": - continue - config_chat_id = self._parse_stream_config_to_chat_id(str(target)) - if config_chat_id is None or config_chat_id != chat_id: - continue - parsed = self._parse_range(time_range) - if not parsed: - continue - start_min, end_min = parsed - if self._in_range(now_min, start_min, end_min): - try: - result = float(value) - # 防止返回0值,自动转换为0.0001 - if result == 0: - return 0.0000001 - return result - except Exception: - continue - - # 2) 再匹配全局规则("") - for rule in self.talk_value_rules: - if not isinstance(rule, dict): - continue - target = rule.get("target", None) - time_range = rule.get("time", "") - value = rule.get("value", None) - if target != "" or not isinstance(time_range, str): - continue - parsed = self._parse_range(time_range) - if not parsed: - continue - start_min, end_min = parsed - if self._in_range(now_min, start_min, end_min): - try: - result = float(value) - # 防止返回0值,自动转换为0.0001 - if result == 0: - return 0.0000001 - return result - except Exception: - continue - - # 3) 未命中规则返回基础值 - result = self.talk_value - # 防止返回0值,自动转换为0.0001 - if result == 0: - return 0.0000001 - return result + talk_value_rules: list[TalkRulesItem] = Field( + default_factory=lambda: [ + TalkRulesItem(platform="", id="", rule_type="group", time="00:00-08:59", value=0.8), + TalkRulesItem(platform="", id="", rule_type="group", time="09:00-18:59", value=1.0), + ] + ) + """ + _wrap_思考频率规则列表,支持按聊天流/按日内时段配置。 + """ -@dataclass class MessageReceiveConfig(ConfigBase): """消息接收配置类""" - ban_words: set[str] = field(default_factory=lambda: set()) + ban_words: set[str] = Field(default_factory=lambda: set()) """过滤词列表""" - ban_msgs_regex: set[str] = field(default_factory=lambda: set()) + ban_msgs_regex: set[str] = Field(default_factory=lambda: set()) """过滤正则表达式列表""" + def model_post_init(self, context: Optional[dict] = None) -> None: + for pattern in self.ban_msgs_regex: + try: + re.compile(pattern) + except re.error as e: + raise ValueError(f"Invalid regex pattern in ban_msgs_regex: '{pattern}'") from e + return super().model_post_init(context) + + +class TargetItem(ConfigBase): + platform: str = "" + """平台,留空表示全局""" + + id: str = "" + """用户ID""" + + rule_type: Literal["group", "private"] = "group" + """聊天流类型,group(群聊)或private(私聊)""" + -@dataclass class MemoryConfig(ConfigBase): """记忆配置类""" max_agent_iterations: int = 5 - """Agent最多迭代轮数(最低为1)""" + """记忆思考深度(最低为1)""" agent_timeout_seconds: float = 120.0 - """Agent超时时间(秒)""" + """最长回忆时间(秒)""" global_memory: bool = False """是否允许记忆检索在聊天记录中进行全局查询(忽略当前chat_id,仅对 search_chat_history 等工具生效)""" - global_memory_blacklist: list[str] = field(default_factory=lambda: []) - """ - 全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索 - 格式: ["platform:id:type", ...] - - 示例: - [ - "qq:1919810:private", # 排除特定私聊 - "qq:114514:group", # 排除特定群聊 - ] - - 说明: - - 当启用全局记忆时,黑名单中的聊天流不会被检索 - - 当在黑名单中的聊天流进行查询时,仅使用该聊天流的本地记忆 - """ + global_memory_blacklist: list[TargetItem] = Field(default_factory=lambda: []) + """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索""" chat_history_topic_check_message_threshold: int = 80 """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查""" @@ -297,234 +187,137 @@ class MemoryConfig(ConfigBase): chat_history_finalize_message_count: int = 5 """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储""" - def __post_init__(self): + chat_history_topic_check_message_threshold: int = 80 + """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查""" + + chat_history_topic_check_time_hours: float = 8.0 + """聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查""" + + chat_history_topic_check_min_messages: int = 20 + """聊天历史话题检查的时间触发模式下的最小消息数阈值""" + + chat_history_finalize_no_update_checks: int = 3 + """聊天历史话题打包存储的连续无更新检查次数阈值,当话题连续N次检查无新增内容时触发打包存储""" + + chat_history_finalize_message_count: int = 5 + """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储""" + + def model_post_init(self, context: Optional[dict] = None) -> None: """验证配置值""" if self.max_agent_iterations < 1: raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}") if self.agent_timeout_seconds <= 0: raise ValueError(f"agent_timeout_seconds 必须大于0,当前值: {self.agent_timeout_seconds}") if self.chat_history_topic_check_message_threshold < 1: - raise ValueError(f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}") + raise ValueError( + f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}" + ) if self.chat_history_topic_check_time_hours <= 0: - raise ValueError(f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}") + raise ValueError( + f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}" + ) if self.chat_history_topic_check_min_messages < 1: - raise ValueError(f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}") + raise ValueError( + f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}" + ) if self.chat_history_finalize_no_update_checks < 1: - raise ValueError(f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}") + raise ValueError( + f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}" + ) if self.chat_history_finalize_message_count < 1: - raise ValueError(f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}") + raise ValueError( + f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}" + ) + return super().model_post_init(context) + + +class LearningItem(ConfigBase): + platform: str = "" + """平台,留空表示全局""" + + id: str = "" + """用户ID""" + + rule_type: Literal["group", "private"] = "group" + """聊天流类型,group(群聊)或private(私聊)""" + + use_expression: bool = True + """是否启用表达学习""" + + enable_learning: bool = True + """是否启用表达优化学习""" + + enable_jargon_learning: bool = False + """是否启用jargon学习""" + + +class ExpressionGroup(ConfigBase): + """表达互通组配置类,若列表为空代表全局共享""" + + expression_groups: list[TargetItem] = Field(default_factory=lambda: []) + """_wrap_表达学习互通组""" -@dataclass class ExpressionConfig(ConfigBase): """表达配置类""" - learning_list: list[list] = field(default_factory=lambda: []) - """ - 表达学习配置列表,支持按聊天流配置 - 格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...] - - 示例: - [ - ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习 - ["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习 - ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习 - ] - - 说明: - - 第一位: chat_stream_id,空字符串表示全局配置 - - 第二位: 是否使用学到的表达 ("enable"/"disable") - - 第三位: 是否学习表达 ("enable"/"disable") - - 第四位: 是否启用jargon学习 ("enable"/"disable") - """ + learning_list: list[LearningItem] = Field( + default_factory=lambda: [ + LearningItem( + platform="", + id="", + rule_type="group", + use_expression=True, + enable_learning=True, + enable_jargon_learning=True, + ) + ] + ) + """_wrap_表达学习配置列表,支持按聊天流配置""" - expression_groups: list[list[str]] = field(default_factory=list) - """ - 表达学习互通组 - 格式: [["qq:12345:group", "qq:67890:private"]] - """ + expression_groups: list[ExpressionGroup] = Field(default_factory=list) + """_wrap_表达学习互通组""" - expression_self_reflect: bool = False + expression_checked_only: bool = True + """是否仅选择已检查且未拒绝的表达方式""" + + expression_self_reflect: bool = True """是否启用自动表达优化""" - + + expression_auto_check_interval: int = 600 + """表达方式自动检查的间隔时间(秒)""" + + expression_auto_check_count: int = 20 + """每次自动检查时随机选取的表达方式数量""" + + expression_auto_check_custom_criteria: list[str] = Field(default_factory=list) + """表达方式自动检查的额外自定义评估标准""" + expression_manual_reflect: bool = False """是否启用手动表达优化""" - manual_reflect_operator_id: str = "" - """表达反思操作员ID""" + manual_reflect_operator_id: Optional[TargetItem] = None + """手动表达优化操作员ID""" - allow_reflect: list[str] = field(default_factory=list) - """ - 允许进行表达反思的聊天流ID列表 - 格式: ["qq:123456:private", "qq:654321:group", ...] - 只有在此列表中的聊天流才会提出问题并跟踪 - 如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true) - """ + allow_reflect: list[TargetItem] = Field(default_factory=list) + """允许进行表达反思的聊天流ID列表,只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是reflect为true)""" - all_global_jargon: bool = False - """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id。注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除""" + all_global_jargon: bool = True + """是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除""" enable_jargon_explanation: bool = True """是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)""" - jargon_mode: Literal["context", "planner"] = "context" + jargon_mode: Literal["context", "planner"] = "planner" """ - 黑话解释来源模式: - - "context": 使用上下文自动匹配黑话并解释(原有模式) - - "planner": 仅使用 Planner 在 reply 动作中给出的 unknown_words 列表进行黑话检索 - """ - - expression_checked_only: bool = False - """ - 是否仅选择已检查且未拒绝的表达方式 - 当设置为 true 时,只有 checked=True 且 rejected=False 的表达方式才会被选择 - 当设置为 false 时,保留旧的筛选原则(仅排除 rejected=True 的表达方式) + 黑话解释来源模式 + + 可选: + - "context":使用上下文自动匹配黑话 + - "planner":仅使用Planner在reply动作中给出的unknown_words列表 """ - expression_auto_check_interval: int = 3600 - """ - 表达方式自动检查的间隔时间(单位:秒) - 默认值:3600秒(1小时) - """ - - expression_auto_check_count: int = 10 - """ - 每次自动检查时随机选取的表达方式数量 - 默认值:10条 - """ - - expression_auto_check_custom_criteria: list[str] = field(default_factory=list) - """ - 表达方式自动检查的额外自定义评估标准 - 格式: ["标准1", "标准2", "标准3", ...] - 这些标准会被添加到评估提示词中,作为额外的评估要求 - 默认值:空列表 - """ - - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: - """ - 解析流配置字符串并生成对应的 chat_id - - Args: - stream_config_str: 格式为 "platform:id:type" 的字符串 - - Returns: - str: 生成的 chat_id,如果解析失败则返回 None - """ - try: - parts = stream_config_str.split(":") - if len(parts) != 3: - return None - - platform = parts[0] - id_str = parts[1] - stream_type = parts[2] - - # 判断是否为群聊 - is_group = stream_type == "group" - - # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑 - from src.chat.message_receive.chat_stream import get_chat_manager - - return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) - - except (ValueError, IndexError): - return None - - def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]: - """ - 根据聊天流ID获取表达配置 - - Args: - chat_stream_id: 聊天流ID,格式为哈希值 - - Returns: - tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习) - """ - if not self.learning_list: - # 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习 - return True, True, True - - # 优先检查聊天流特定的配置 - if chat_stream_id: - specific_expression_config = self._get_stream_specific_config(chat_stream_id) - if specific_expression_config is not None: - return specific_expression_config - - # 检查全局配置(第一个元素为空字符串的配置) - global_expression_config = self._get_global_config() - if global_expression_config is not None: - return global_expression_config - - # 如果都没有匹配,返回默认值:启用表达,启用学习,启用jargon学习 - return True, True, True - - def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]: - """ - 获取特定聊天流的表达配置 - - Args: - chat_stream_id: 聊天流ID(哈希值) - - Returns: - tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None - """ - for config_item in self.learning_list: - if not config_item or len(config_item) < 4: - continue - - stream_config_str = config_item[0] # 例如 "qq:1026294844:group" - - # 如果是空字符串,跳过(这是全局配置) - if stream_config_str == "": - continue - - # 解析配置字符串并生成对应的 chat_id - config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str) - if config_chat_id is None: - continue - - # 比较生成的 chat_id - if config_chat_id != chat_stream_id: - continue - - # 解析配置 - try: - use_expression: bool = config_item[1].lower() == "enable" - enable_learning: bool = config_item[2].lower() == "enable" - enable_jargon_learning: bool = config_item[3].lower() == "enable" - return use_expression, enable_learning, enable_jargon_learning # type: ignore - except (ValueError, IndexError): - continue - - return None - - def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]: - """ - 获取全局表达配置 - - Returns: - tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None - """ - for config_item in self.learning_list: - if not config_item or len(config_item) < 4: - continue - - # 检查是否为全局配置(第一个元素为空字符串) - if config_item[0] == "": - try: - use_expression: bool = config_item[1].lower() == "enable" - enable_learning: bool = config_item[2].lower() == "enable" - enable_jargon_learning: bool = config_item[3].lower() == "enable" - return use_expression, enable_learning, enable_jargon_learning # type: ignore - except (ValueError, IndexError): - continue - - return None - - -@dataclass class ToolConfig(ConfigBase): """工具配置类""" @@ -532,54 +325,51 @@ class ToolConfig(ConfigBase): """是否在聊天中启用工具""" -@dataclass class VoiceConfig(ConfigBase): """语音识别配置类""" enable_asr: bool = False - """是否启用语音识别""" + """是否启用语音识别,启用后麦麦可以识别语音消息""" -@dataclass class EmojiConfig(ConfigBase): """表情包配置类""" - emoji_chance: float = 0.6 + emoji_chance: float = 0.4 """发送表情包的基础概率""" - max_reg_num: int = 200 + max_reg_num: int = 100 """表情包最大注册数量""" do_replace: bool = True - """达到最大注册数量时替换旧表情包""" + """达到最大注册数量时替换旧表情包,关闭则达到最大数量时不会继续收集表情包""" - check_interval: int = 120 + check_interval: int = 10 """表情包检查间隔(分钟)""" steal_emoji: bool = True - """是否偷取表情包,让麦麦可以发送她保存的这些表情包""" + """是否偷取表情包,让麦麦可以将一些表情包据为己有""" content_filtration: bool = False - """是否开启表情包过滤""" + """是否启用表情包过滤,只有符合该要求的表情包才会被保存""" filtration_prompt: str = "符合公序良俗" - """表情包过滤要求""" + """表情包过滤要求,只有符合该要求的表情包才会被保存""" -@dataclass class KeywordRuleConfig(ConfigBase): """关键词规则配置类""" - keywords: list[str] = field(default_factory=lambda: []) + keywords: list[str] = Field(default_factory=lambda: []) """关键词列表""" - regex: list[str] = field(default_factory=lambda: []) + regex: list[str] = Field(default_factory=lambda: []) """正则表达式列表""" reaction: str = "" """关键词触发的反应""" - def __post_init__(self): + def model_post_init(self, context: Optional[dict] = None) -> None: """验证配置""" if not self.keywords and not self.regex: raise ValueError("关键词规则必须至少包含keywords或regex中的一个") @@ -587,33 +377,31 @@ class KeywordRuleConfig(ConfigBase): if not self.reaction: raise ValueError("关键词规则必须包含reaction") - # 验证正则表达式 for pattern in self.regex: try: re.compile(pattern) except re.error as e: raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e + return super().model_post_init(context) -@dataclass class KeywordReactionConfig(ConfigBase): """关键词配置类""" - keyword_rules: list[KeywordRuleConfig] = field(default_factory=lambda: []) + keyword_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: []) """关键词规则列表""" - regex_rules: list[KeywordRuleConfig] = field(default_factory=lambda: []) + regex_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: []) """正则表达式规则列表""" - def __post_init__(self): + def model_post_init(self, context: Optional[dict] = None) -> None: """验证配置""" - # 验证所有规则 for rule in self.keyword_rules + self.regex_rules: if not isinstance(rule, KeywordRuleConfig): raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}") + return super().model_post_init(context) -@dataclass class ResponsePostProcessConfig(ConfigBase): """回复后处理配置类""" @@ -621,7 +409,6 @@ class ResponsePostProcessConfig(ConfigBase): """是否启用回复后处理,包括错别字生成器,回复分割器""" -@dataclass class ChineseTypoConfig(ConfigBase): """中文错别字配置类""" @@ -641,27 +428,25 @@ class ChineseTypoConfig(ConfigBase): """整词替换概率""" -@dataclass class ResponseSplitterConfig(ConfigBase): """回复分割器配置类""" enable: bool = True """是否启用回复分割器""" - max_length: int = 256 + max_length: int = 512 """回复允许的最大长度""" - max_sentence_num: int = 3 + max_sentence_num: int = 8 """回复允许的最大句子数""" enable_kaomoji_protection: bool = False """是否启用颜文字保护""" enable_overflow_return_all: bool = False - """是否在超出句子数量限制时合并后一次性返回""" + """是否在句子数量超出回复允许的最大句子数时一次性返回全部内容""" -@dataclass class TelemetryConfig(ConfigBase): """遥测配置类""" @@ -669,36 +454,6 @@ class TelemetryConfig(ConfigBase): """是否启用遥测""" -@dataclass -class WebUIConfig(ConfigBase): - """WebUI配置类 - - 注意: host 和 port 配置已移至环境变量 WEBUI_HOST 和 WEBUI_PORT - """ - - enabled: bool = True - """是否启用WebUI""" - - mode: Literal["development", "production"] = "production" - """运行模式:development(开发) 或 production(生产)""" - - anti_crawler_mode: Literal["false", "strict", "loose", "basic"] = "basic" - """防爬虫模式:false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)""" - - allowed_ips: str = "127.0.0.1" - """IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)""" - - trusted_proxies: str = "" - """信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任""" - - trust_xff: bool = False - """是否启用X-Forwarded-For代理解析(默认false)""" - - secure_cookie: bool = False - """是否启用安全Cookie(仅通过HTTPS传输,默认false)""" - - -@dataclass class DebugConfig(ConfigBase): """调试配置类""" @@ -718,47 +473,51 @@ class DebugConfig(ConfigBase): """是否显示记忆检索相关prompt""" show_planner_prompt: bool = False - """是否显示planner相关提示词""" + """是否显示planner的prompt和原始返回结果""" show_lpmm_paragraph: bool = False """是否显示lpmm找到的相关文段日志""" -@dataclass +class ExtraPromptItem(ConfigBase): + platform: str = "" + """平台,留空无效""" + + id: str = "" + """用户ID,留空无效""" + + rule_type: Literal["group", "private"] = "group" + """聊天流类型,group(群聊)或private(私聊)""" + + prompt: str = "" + """额外的prompt内容""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + if not self.platform or not self.id or not self.prompt: + raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空") + return super().model_post_init(context) + + class ExperimentalConfig(ConfigBase): """实验功能配置类""" - private_plan_style: str = "" - """私聊说话规则,行为风格(实验性功能)""" + private_plan_style: str = """ +1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用 +2.如果相同的内容已经被执行,请不要重复执行 +3.某句话如果已经被回复过,不要重复回复""" + """_wrap_私聊说话规则,行为风格(实验性功能)""" - chat_prompts: list[str] = field(default_factory=lambda: []) - """ - 为指定聊天添加额外的prompt配置列表 - 格式: ["platform:id:type:prompt内容", ...] - - 示例: - [ - "qq:114514:group:这是一个摄影群,你精通摄影知识", - "qq:19198:group:这是一个二次元交流群", - "qq:114514:private:这是你与好朋友的私聊" - ] - - 说明: - - platform: 平台名称,如 "qq" - - id: 群ID或用户ID - - type: "group" 或 "private" - - prompt内容: 要添加的额外prompt文本 - """ + chat_prompts: list[ExtraPromptItem] = Field(default_factory=lambda: []) + """_wrap_为指定聊天添加额外的prompt配置列表""" lpmm_memory: bool = False """是否将聊天历史总结导入到LPMM知识库。开启后,chat_history_summarizer总结出的历史记录会同时导入到知识库""" -@dataclass class MaimMessageConfig(ConfigBase): """maim_message配置类""" - auth_token: list[str] = field(default_factory=lambda: []) + auth_token: list[str] = Field(default_factory=lambda: []) """认证令牌,用于旧版API验证,为空则不启用验证""" enable_api_server: bool = False @@ -779,11 +538,10 @@ class MaimMessageConfig(ConfigBase): api_server_key_file: str = "" """新版API Server SSL密钥文件路径""" - api_server_allowed_api_keys: list[str] = field(default_factory=lambda: []) + api_server_allowed_api_keys: list[str] = Field(default_factory=lambda: []) """新版API Server允许的API Key列表,为空则允许所有连接""" -@dataclass class LPMMKnowledgeConfig(ConfigBase): """LPMM知识库配置类""" @@ -791,40 +549,40 @@ class LPMMKnowledgeConfig(ConfigBase): """是否启用LPMM知识库""" lpmm_mode: Literal["classic", "agent"] = "classic" - """LPMM知识库模式,可选:classic经典模式,agent 模式,结合最新的记忆一同使用""" + """LPMM知识库模式,可选:classic经典模式,agent 模式""" rag_synonym_search_top_k: int = 10 - """RAG同义词搜索的Top K数量""" + """同义检索TopK""" rag_synonym_threshold: float = 0.8 - """RAG同义词搜索的相似度阈值""" + """同义阈值,相似度高于该值的关系会被当作同义词""" info_extraction_workers: int = 3 - """信息提取工作线程数""" + """实体抽取同时执行线程数,非Pro模型不要设置超过5""" qa_relation_search_top_k: int = 10 - """QA关系搜索的Top K数量""" + """关系检索TopK""" qa_relation_threshold: float = 0.75 - """QA关系搜索的相似度阈值""" + """关系阈值,相似度高于该值的关系会被认为是相关关系""" qa_paragraph_search_top_k: int = 1000 - """QA段落搜索的Top K数量""" + """段落检索TopK(不能过小,可能影响搜索结果)""" qa_paragraph_node_weight: float = 0.05 - """QA段落节点权重""" + """段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)""" qa_ent_filter_top_k: int = 10 - """QA实体过滤的Top K数量""" + """实体过滤TopK""" qa_ppr_damping: float = 0.8 - """QA PageRank阻尼系数""" + """PPR阻尼系数""" qa_res_top_k: int = 10 - """QA最终结果的Top K数量""" + """最终提供段落TopK""" embedding_dimension: int = 1024 - """嵌入向量维度,应该与模型的输出维度一致""" + """嵌入向量维度,输出维度""" max_embedding_workers: int = 3 """嵌入/抽取并发线程数""" @@ -839,7 +597,6 @@ class LPMMKnowledgeConfig(ConfigBase): """是否启用PPR,低配机器可关闭""" -@dataclass class DreamConfig(ConfigBase): """Dream配置类""" @@ -849,91 +606,23 @@ class DreamConfig(ConfigBase): max_iterations: int = 20 """做梦最大轮次,默认20轮""" - first_delay_seconds: int = 60 - """程序启动后首次做梦前的延迟时间(秒),默认60秒""" + first_delay_seconds: int = 1800 + """程序启动后首次做梦前的延迟时间(秒),默认1800秒""" dream_send: str = "" - """ - 做梦结果推送目标,格式为 "platform:user_id" - 例如: "qq:123456" 表示在做梦结束后,将梦境文本额外发送给该QQ私聊用户。 - 为空字符串时不推送。 - """ + """做梦结果推送目标,格式为 "platform:user_id,为空则不发送""" - dream_time_ranges: list[str] = field(default_factory=lambda: []) - """ - 做梦时间段配置列表,格式:["HH:MM-HH:MM", ...] - 如果列表为空,则表示全天允许做梦。 - 如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。 - 时间段外,调度器仍会按间隔检查,但不会进入做梦流程。 - - 示例: - [ - "09:00-22:00", # 白天允许做梦 - "23:00-02:00", # 跨夜时间段(23:00到次日02:00) - ] - - 支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。 - """ - - def _now_minutes(self) -> int: - """返回本地时间的分钟数(0-1439)。""" - lt = time.localtime() - return lt.tm_hour * 60 + lt.tm_min - - def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]: - """解析 "HH:MM-HH:MM" 到 (start_min, end_min)。""" - try: - start_str, end_str = [s.strip() for s in range_str.split("-")] - sh, sm = [int(x) for x in start_str.split(":")] - eh, em = [int(x) for x in end_str.split(":")] - return sh * 60 + sm, eh * 60 + em - except Exception: - return None - - def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool: - """ - 判断 now_min 是否在 [start_min, end_min] 区间内。 - 支持跨夜:如果 start > end,则表示跨越午夜。 - """ - if start_min <= end_min: - return start_min <= now_min <= end_min - # 跨夜:例如 23:00-02:00 - return now_min >= start_min or now_min <= end_min - - def is_in_dream_time(self) -> bool: - """ - 检查当前时间是否在允许做梦的时间段内。 - 如果 dream_time_ranges 为空,则返回 True(全天允许)。 - """ - if not self.dream_time_ranges: - return True - - now_min = self._now_minutes() - - for time_range in self.dream_time_ranges: - if not isinstance(time_range, str): - continue - parsed = self._parse_range(time_range) - if not parsed: - continue - start_min, end_min = parsed - if self._in_range(now_min, start_min, end_min): - return True - - return False + dream_time_ranges: list[str] = Field(default_factory=lambda: ["23:00-10:00"]) + """_wrap_做梦时间段配置列表""" dream_visible: bool = False - """ - 做梦结果是否存储到上下文 - - True: 将梦境发送给配置的用户后,也会存储到聊天上下文中,在后续对话中可见 - - False: 仅发送梦境但不存储,不在后续对话上下文中出现 - """ + """做梦结果发送后是否存储到上下文""" - def __post_init__(self): - """验证配置值""" + def model_post_init(self, context: Optional[dict] = None) -> None: if self.interval_minutes < 1: raise ValueError(f"interval_minutes 必须至少为1,当前值: {self.interval_minutes}") if self.max_iterations < 1: raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}") if self.first_delay_seconds < 0: raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}") + return super().model_post_init(context) diff --git a/src/plugin_system/apis/constants.py b/src/plugin_system/apis/constants.py new file mode 100644 index 00000000..88d74dca --- /dev/null +++ b/src/plugin_system/apis/constants.py @@ -0,0 +1,22 @@ +from pathlib import Path +from dataclasses import dataclass + + +@dataclass(frozen=True) +class _SystemConstants: + PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve() + CONFIG_DIR: Path = PROJECT_ROOT / "config" + BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() + MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() + PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute() + INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute() + + +_system_constants = _SystemConstants() + +PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT +CONFIG_DIR: Path = _system_constants.CONFIG_DIR +BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH +MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH +PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR +INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR