mirror of https://github.com/Mai-with-u/MaiBot.git
Implement utility functions for plugin config handling
Added utility functions for normalizing dotted keys and coercing types in plugin configuration.pull/1412/head
parent
d7cf57f48e
commit
a14e5b84c1
|
|
@ -1,11 +1,12 @@
|
|||
from fastapi import APIRouter, HTTPException, Header, Cookie
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from .token_manager import get_token_manager
|
||||
from .plugin_progress_ws import update_progress
|
||||
|
|
@ -65,6 +66,88 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||
return (0, 0, 0)
|
||||
|
||||
|
||||
# ============ 工具函数(避免在请求内重复定义) ============
|
||||
|
||||
|
||||
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}。
|
||||
若遇到中间节点已存在且非字典,记录日志并覆盖为字典。
|
||||
"""
|
||||
def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
|
||||
for k, v in src.items():
|
||||
if k in dst and isinstance(dst[k], dict) and isinstance(v, dict):
|
||||
_deep_merge(dst[k], v)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
result: Dict[str, Any] = {}
|
||||
dotted_items = []
|
||||
|
||||
# 先处理非点号键,避免后续展开覆盖已有结构
|
||||
for k, v in obj.items():
|
||||
if "." in k:
|
||||
dotted_items.append((k, v))
|
||||
else:
|
||||
result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||
|
||||
# 再处理点号键
|
||||
for dotted_key, v in dotted_items:
|
||||
value = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||
parts = [p for p in dotted_key.split(".") if p]
|
||||
if not parts:
|
||||
logger.warning(f"忽略空键路径: '{dotted_key}'")
|
||||
continue
|
||||
current = result
|
||||
# 中间层
|
||||
for part in parts[:-1]:
|
||||
if part in current and not isinstance(current[part], dict):
|
||||
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key}")
|
||||
current[part] = {}
|
||||
current = current.setdefault(part, {})
|
||||
# 最后一层
|
||||
last_part = parts[-1]
|
||||
if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict):
|
||||
_deep_merge(current[last_part], value)
|
||||
else:
|
||||
current[last_part] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
|
||||
"""
|
||||
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||
"""
|
||||
def _is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
|
||||
for key, schema_val in schema_part.items():
|
||||
if key not in config_part:
|
||||
continue
|
||||
value = config_part[key]
|
||||
if isinstance(schema_val, ConfigField):
|
||||
if _is_list_type(schema_val.type) and isinstance(value, str):
|
||||
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(schema_val, dict) and isinstance(value, dict):
|
||||
coerce_types(schema_val, value)
|
||||
|
||||
|
||||
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||
"""
|
||||
按 plugin_id 或 plugin_name 查找已加载的插件实例。
|
||||
局部导入 plugin_manager 以规避循环依赖。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
|
||||
return instance
|
||||
return None
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
|
||||
|
|
@ -1388,6 +1471,14 @@ async def update_plugin_config(
|
|||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
|
||||
# 纠正 WebUI 提交的数据结构(扁平键与字符串列表)
|
||||
if plugin_instance and isinstance(request.config, dict):
|
||||
request.config = normalize_dotted_keys(request.config)
|
||||
if isinstance(plugin_instance.config_schema, dict):
|
||||
coerce_types(plugin_instance.config_schema, request.config)
|
||||
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue