Implement utility functions for plugin config handling

Added utility functions for normalizing dotted keys and coercing types in plugin configuration.
pull/1412/head
晴空 2025-12-06 22:45:38 +08:00 committed by GitHub
parent d7cf57f48e
commit a14e5b84c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 92 additions and 1 deletions

View File

@ -1,11 +1,12 @@
from fastapi import APIRouter, HTTPException, Header, Cookie from fastapi import APIRouter, HTTPException, Header, Cookie
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path from pathlib import Path
import json import json
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
from .token_manager import get_token_manager from .token_manager import get_token_manager
from .plugin_progress_ws import update_progress from .plugin_progress_ws import update_progress
@ -65,6 +66,88 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
return (0, 0, 0) return (0, 0, 0)
# ============ 工具函数(避免在请求内重复定义) ============
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
"""
将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}
若遇到中间节点已存在且非字典记录日志并覆盖为字典
"""
def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
for k, v in src.items():
if k in dst and isinstance(dst[k], dict) and isinstance(v, dict):
_deep_merge(dst[k], v)
else:
dst[k] = v
result: Dict[str, Any] = {}
dotted_items = []
# 先处理非点号键,避免后续展开覆盖已有结构
for k, v in obj.items():
if "." in k:
dotted_items.append((k, v))
else:
result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v
# 再处理点号键
for dotted_key, v in dotted_items:
value = normalize_dotted_keys(v) if isinstance(v, dict) else v
parts = [p for p in dotted_key.split(".") if p]
if not parts:
logger.warning(f"忽略空键路径: '{dotted_key}'")
continue
current = result
# 中间层
for part in parts[:-1]:
if part in current and not isinstance(current[part], dict):
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key}")
current[part] = {}
current = current.setdefault(part, {})
# 最后一层
last_part = parts[-1]
if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict):
_deep_merge(current[last_part], value)
else:
current[last_part] = value
return result
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
"""
根据 schema 将配置中的类型纠正目前只纠正 list-from-str
"""
def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
for key, schema_val in schema_part.items():
if key not in config_part:
continue
value = config_part[key]
if isinstance(schema_val, ConfigField):
if _is_list_type(schema_val.type) and isinstance(value, str):
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(schema_val, dict) and isinstance(value, dict):
coerce_types(schema_val, value)
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
"""
plugin_id plugin_name 查找已加载的插件实例
局部导入 plugin_manager 以规避循环依赖
"""
from src.plugin_system.core.plugin_manager import plugin_manager
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
return instance
return None
# ============ 请求/响应模型 ============ # ============ 请求/响应模型 ============
@ -1388,6 +1471,14 @@ async def update_plugin_config(
logger.info(f"更新插件配置: {plugin_id}") logger.info(f"更新插件配置: {plugin_id}")
try: try:
plugin_instance = find_plugin_instance(plugin_id)
# 纠正 WebUI 提交的数据结构(扁平键与字符串列表)
if plugin_instance and isinstance(request.config, dict):
request.config = normalize_dotted_keys(request.config)
if isinstance(plugin_instance.config_schema, dict):
coerce_types(plugin_instance.config_schema, request.config)
# 查找插件目录 # 查找插件目录
plugins_dir = Path("plugins") plugins_dir = Path("plugins")
plugin_path = None plugin_path = None