mirror of https://github.com/Mai-with-u/MaiBot.git
feat:自动检测并升级旧版配置文件
parent
049027a48f
commit
3a5fd9d3e3
|
|
@ -2,10 +2,13 @@ from pathlib import Path
|
|||
from typing import TypeVar
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
import copy
|
||||
|
||||
import tomlkit
|
||||
import sys
|
||||
|
||||
from .legacy_migration import try_migrate_legacy_bot_config_dict
|
||||
|
||||
from .official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
|
|
@ -48,7 +51,7 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
|||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "0.13.0"
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.0.0"
|
||||
MODEL_CONFIG_VERSION: str = "1.12.0"
|
||||
|
||||
|
|
@ -216,9 +219,28 @@ def load_config_from_file(
|
|||
old_ver: str = config_data["inner"]["version"] # type: ignore
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
original_data: dict[str, Any] = copy.deepcopy(config_data)
|
||||
try:
|
||||
updated: bool = False
|
||||
target_config = config_class.from_dict(attribute_data, config_data)
|
||||
try:
|
||||
target_config = config_class.from_dict(attribute_data, config_data)
|
||||
except TypeError as e:
|
||||
# 可拔插的旧配置修复(仅针对 bot_config.toml 的已知结构变更)
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
# 基于未被部分构造污染的 original_data 做迁移尝试
|
||||
mig = try_migrate_legacy_bot_config_dict(original_data)
|
||||
if mig.migrated:
|
||||
logger.warning(
|
||||
f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。"
|
||||
f"建议稍后检查并保存生成的新配置文件。"
|
||||
)
|
||||
migrated_data = mig.data
|
||||
target_config = config_class.from_dict(attribute_data, migrated_data)
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
if compare_versions(old_ver, new_ver):
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from pydantic.fields import FieldInfo
|
||||
from typing import Any, get_args, get_origin, TYPE_CHECKING, Literal, List, Set, Tuple, Dict
|
||||
from typing import Any, get_args, get_origin, TYPE_CHECKING, Literal, List, Set, Tuple, Dict, Union
|
||||
import types
|
||||
from tomlkit import items
|
||||
import tomlkit
|
||||
|
||||
|
|
@ -66,7 +67,30 @@ def convert_field(config_item_name: str, config_item_info: FieldInfo, value: Any
|
|||
field_type_origin = get_origin(config_item_info.annotation)
|
||||
field_type_args = get_args(config_item_info.annotation)
|
||||
|
||||
if not field_type_origin: # 基础类型int,bool等直接添加
|
||||
# 处理 Optional[T] / Union[T, None] / PEP604 的 T | None
|
||||
if field_type_origin in (Union, types.UnionType):
|
||||
# 只处理 "某类型 + None" 的情况,等价于 Optional[T]
|
||||
non_none_args = tuple(a for a in field_type_args if a is not type(None))
|
||||
if len(non_none_args) == 1:
|
||||
inner = non_none_args[0]
|
||||
inner_origin = get_origin(inner)
|
||||
inner_args = get_args(inner)
|
||||
# Optional[基础类型] 直接按基础类型处理
|
||||
if inner_origin is None and isinstance(inner, type) and inner in (int, float, str, bool):
|
||||
return value
|
||||
# Optional[Literal[...]] 的情况
|
||||
if inner_origin is Literal:
|
||||
if value not in inner_args:
|
||||
raise ValueError(f"Value {value} not in Literal options {inner_args} for {config_item_name}")
|
||||
return value
|
||||
# 其它 Optional[...],后续按去掉 None 的泛型再走一遍逻辑
|
||||
field_type_origin = inner_origin
|
||||
field_type_args = inner_args
|
||||
else:
|
||||
# 复杂 Union 不支持写回,只能报错
|
||||
raise TypeError(f"Unsupported Union type for {config_item_name}: {config_item_info.annotation}")
|
||||
|
||||
if not field_type_origin: # 基础类型 int,bool,str,float 等直接添加
|
||||
return value
|
||||
elif field_type_origin in {list, set, List, Set}:
|
||||
toml_list = tomlkit.array()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,281 @@
|
|||
"""
|
||||
legacy_migration.py
|
||||
|
||||
一个“可随时拔掉”的旧配置兼容层:
|
||||
- 仅在配置解析失败时尝试修复旧格式数据(7.x -> 8.x 这一类结构性变更)
|
||||
- 不依赖 Pydantic / ConfigBase,仅对 dict 做最小转换
|
||||
- 成功则返回(修复后的 dict, True),失败则返回(原 dict, False)
|
||||
|
||||
设计目标:与现有 config 加载逻辑的接触点尽可能小,未来不需要时可一键移除。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("legacy_migration")
|
||||
|
||||
|
||||
# 方便未来快速关闭/移除
|
||||
ENABLE_LEGACY_MIGRATION: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationResult:
|
||||
data: dict[str, Any]
|
||||
migrated: bool
|
||||
reason: str = ""
|
||||
|
||||
|
||||
def _as_dict(x: Any) -> Optional[dict[str, Any]]:
|
||||
return x if isinstance(x, dict) else None
|
||||
|
||||
|
||||
def _as_list(x: Any) -> Optional[list[Any]]:
|
||||
return x if isinstance(x, list) else None
|
||||
|
||||
|
||||
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type" -> {platform,item_id,rule_type}
|
||||
返回 None 表示无法解析。
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
parts = s.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform, item_id, rule_type = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type}
|
||||
|
||||
|
||||
def _parse_quad_prompt(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type:prompt" -> {platform,item_id,rule_type,prompt}
|
||||
prompt 允许包含冒号,因此只切前三个冒号。
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
parts = s.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
platform, item_id, rule_type, prompt = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
if not prompt:
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type, "prompt": prompt}
|
||||
|
||||
|
||||
def _parse_enable_disable(v: Any) -> Optional[bool]:
|
||||
"""
|
||||
兼容旧值 "enable"/"disable" 以及 bool。
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
vv = v.strip().lower()
|
||||
if vv == "enable":
|
||||
return True
|
||||
if vv == "disable":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
learning_list = [
|
||||
["", "enable", "enable", "enable"],
|
||||
["qq:1919810:group", "enable", "enable", "enable"],
|
||||
]
|
||||
新:
|
||||
[[expression.learning_list]]
|
||||
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
|
||||
"""
|
||||
ll = _as_list(expr.get("learning_list"))
|
||||
if ll is None:
|
||||
return False
|
||||
|
||||
# 如果已经是新格式(列表里是 dict),跳过
|
||||
if ll and all(isinstance(i, dict) for i in ll):
|
||||
return False
|
||||
|
||||
migrated_items: list[dict[str, Any]] = []
|
||||
for row in ll:
|
||||
r = _as_list(row)
|
||||
if r is None or len(r) < 4:
|
||||
# 行结构不对,无法安全迁移
|
||||
return False
|
||||
|
||||
target_raw = r[0]
|
||||
use_expression = _parse_enable_disable(r[1])
|
||||
enable_learning = _parse_enable_disable(r[2])
|
||||
enable_jargon_learning = _parse_enable_disable(r[3])
|
||||
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
|
||||
return False
|
||||
|
||||
# 旧格式中 target 允许为空字符串:表示全局;新结构必须有三元组字段
|
||||
if target_raw == "" or target_raw is None:
|
||||
target = {"platform": "", "item_id": "", "rule_type": "group"}
|
||||
else:
|
||||
target = _parse_triplet_target(str(target_raw))
|
||||
if target is None:
|
||||
return False
|
||||
|
||||
migrated_items.append(
|
||||
{
|
||||
"platform": target["platform"],
|
||||
"item_id": target["item_id"],
|
||||
"rule_type": target["rule_type"],
|
||||
"use_expression": use_expression,
|
||||
"enable_learning": enable_learning,
|
||||
"enable_jargon_learning": enable_jargon_learning,
|
||||
}
|
||||
)
|
||||
|
||||
expr["learning_list"] = migrated_items
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_expression_groups(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
expression_groups = [
|
||||
["qq:1:group","qq:2:group"],
|
||||
["qq:3:group"],
|
||||
]
|
||||
新:
|
||||
expression_groups = [
|
||||
{ expression_groups = [ {platform="qq", item_id="1", rule_type="group"}, ... ] },
|
||||
{ expression_groups = [ ... ] },
|
||||
]
|
||||
"""
|
||||
eg = _as_list(expr.get("expression_groups"))
|
||||
if eg is None:
|
||||
return False
|
||||
|
||||
# 已经是新格式(列表里是 dict 且包含 expression_groups),跳过
|
||||
if eg and all(isinstance(i, dict) for i in eg):
|
||||
return False
|
||||
|
||||
migrated: list[dict[str, Any]] = []
|
||||
for group in eg:
|
||||
g = _as_list(group)
|
||||
if g is None:
|
||||
return False
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in g:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
migrated.append({"expression_groups": targets})
|
||||
|
||||
expr["expression_groups"] = migrated
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type" 迁移为 list[{platform,item_id,rule_type}]
|
||||
用于:memory.global_memory_blacklist / expression.allow_reflect 等。
|
||||
"""
|
||||
raw = _as_list(parent.get(key))
|
||||
if raw is None:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
return False
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
parent[key] = targets
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type:prompt" 迁移为 list[{platform,item_id,rule_type,prompt}]
|
||||
用于:experimental.chat_prompts
|
||||
"""
|
||||
raw = _as_list(exp.get(key))
|
||||
if raw is None:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
return False
|
||||
items: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_quad_prompt(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
items.append(parsed)
|
||||
exp[key] = items
|
||||
return True
|
||||
|
||||
|
||||
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""
|
||||
尝试对“总配置 bot_config.toml”的 dict(已 unwrap)进行旧格式修复。
|
||||
仅做我们明确知道的结构性变更;其它字段不动。
|
||||
"""
|
||||
if not ENABLE_LEGACY_MIGRATION:
|
||||
return MigrationResult(data=data, migrated=False, reason="disabled")
|
||||
|
||||
migrated_any = False
|
||||
reasons: list[str] = []
|
||||
|
||||
expr = _as_dict(data.get("expression"))
|
||||
if expr is not None:
|
||||
if _migrate_expression_learning_list(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.learning_list")
|
||||
if _migrate_expression_groups(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.expression_groups")
|
||||
# allow_reflect: 旧 list[str] -> 新 list[TargetItem]
|
||||
if _migrate_target_item_list(expr, "allow_reflect"):
|
||||
migrated_any = True
|
||||
reasons.append("expression.allow_reflect")
|
||||
# manual_reflect_operator_id: 旧 str -> 新 Optional[TargetItem]
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and mroi.strip():
|
||||
parsed = _parse_triplet_target(mroi.strip())
|
||||
if parsed is not None:
|
||||
expr["manual_reflect_operator_id"] = parsed
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id")
|
||||
|
||||
mem = _as_dict(data.get("memory"))
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
exp = _as_dict(data.get("experimental"))
|
||||
if exp is not None:
|
||||
if _migrate_extra_prompt_list(exp, "chat_prompts"):
|
||||
migrated_any = True
|
||||
reasons.append("experimental.chat_prompts")
|
||||
|
||||
# ExpressionConfig 中的 manual_reflect_operator_id:
|
||||
# 旧版本可能是 ""(字符串),新版本期望 Optional[TargetItem]。
|
||||
# 空字符串视为未配置,转换为 None/删除键以避免校验错误。
|
||||
expr = _as_dict(data.get("expression"))
|
||||
if expr is not None:
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and not mroi.strip():
|
||||
expr.pop("manual_reflect_operator_id", None)
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
Loading…
Reference in New Issue