Merge pull request #9 from XXXxx7258/XXXxx7258-patch-1

Implement utility functions for plugin config handling
pull/1412/head
晴空 2025-12-06 22:53:36 +08:00 committed by GitHub
commit 9188b944bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 92 additions and 1 deletions

View File

@ -1,11 +1,12 @@
from fastapi import APIRouter, HTTPException, Header, Cookie
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path
import json
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
from .token_manager import get_token_manager
from .plugin_progress_ws import update_progress
@ -65,6 +66,88 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
return (0, 0, 0)
# ============ 工具函数(避免在请求内重复定义) ============
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
"""
将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}
若遇到中间节点已存在且非字典记录日志并覆盖为字典
"""
def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
for k, v in src.items():
if k in dst and isinstance(dst[k], dict) and isinstance(v, dict):
_deep_merge(dst[k], v)
else:
dst[k] = v
result: Dict[str, Any] = {}
dotted_items = []
# 先处理非点号键,避免后续展开覆盖已有结构
for k, v in obj.items():
if "." in k:
dotted_items.append((k, v))
else:
result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v
# 再处理点号键
for dotted_key, v in dotted_items:
value = normalize_dotted_keys(v) if isinstance(v, dict) else v
parts = [p for p in dotted_key.split(".") if p]
if not parts:
logger.warning(f"忽略空键路径: '{dotted_key}'")
continue
current = result
# 中间层
for part in parts[:-1]:
if part in current and not isinstance(current[part], dict):
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key}")
current[part] = {}
current = current.setdefault(part, {})
# 最后一层
last_part = parts[-1]
if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict):
_deep_merge(current[last_part], value)
else:
current[last_part] = value
return result
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
"""
根据 schema 将配置中的类型纠正目前只纠正 list-from-str
"""
def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
for key, schema_val in schema_part.items():
if key not in config_part:
continue
value = config_part[key]
if isinstance(schema_val, ConfigField):
if _is_list_type(schema_val.type) and isinstance(value, str):
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(schema_val, dict) and isinstance(value, dict):
coerce_types(schema_val, value)
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
"""
plugin_id plugin_name 查找已加载的插件实例
局部导入 plugin_manager 以规避循环依赖
"""
from src.plugin_system.core.plugin_manager import plugin_manager
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
return instance
return None
# ============ 请求/响应模型 ============
@ -1388,6 +1471,14 @@ async def update_plugin_config(
logger.info(f"更新插件配置: {plugin_id}")
try:
plugin_instance = find_plugin_instance(plugin_id)
# 纠正 WebUI 提交的数据结构(扁平键与字符串列表)
if plugin_instance and isinstance(request.config, dict):
request.config = normalize_dotted_keys(request.config)
if isinstance(plugin_instance.config_schema, dict):
coerce_types(plugin_instance.config_schema, request.config)
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None