mirror of https://github.com/Mai-with-u/MaiBot.git
Merge pull request #9 from XXXxx7258/XXXxx7258-patch-1
Implement utility functions for plugin config handlingpull/1412/head
commit
9188b944bf
|
|
@ -1,11 +1,12 @@
|
||||||
from fastapi import APIRouter, HTTPException, Header, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Cookie
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, get_origin
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.toml_utils import save_toml_with_format
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
from .plugin_progress_ws import update_progress
|
from .plugin_progress_ws import update_progress
|
||||||
|
|
@ -65,6 +66,88 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||||
return (0, 0, 0)
|
return (0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 工具函数(避免在请求内重复定义) ============
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}。
|
||||||
|
若遇到中间节点已存在且非字典,记录日志并覆盖为字典。
|
||||||
|
"""
|
||||||
|
def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
|
||||||
|
for k, v in src.items():
|
||||||
|
if k in dst and isinstance(dst[k], dict) and isinstance(v, dict):
|
||||||
|
_deep_merge(dst[k], v)
|
||||||
|
else:
|
||||||
|
dst[k] = v
|
||||||
|
|
||||||
|
result: Dict[str, Any] = {}
|
||||||
|
dotted_items = []
|
||||||
|
|
||||||
|
# 先处理非点号键,避免后续展开覆盖已有结构
|
||||||
|
for k, v in obj.items():
|
||||||
|
if "." in k:
|
||||||
|
dotted_items.append((k, v))
|
||||||
|
else:
|
||||||
|
result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||||
|
|
||||||
|
# 再处理点号键
|
||||||
|
for dotted_key, v in dotted_items:
|
||||||
|
value = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||||
|
parts = [p for p in dotted_key.split(".") if p]
|
||||||
|
if not parts:
|
||||||
|
logger.warning(f"忽略空键路径: '{dotted_key}'")
|
||||||
|
continue
|
||||||
|
current = result
|
||||||
|
# 中间层
|
||||||
|
for part in parts[:-1]:
|
||||||
|
if part in current and not isinstance(current[part], dict):
|
||||||
|
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key}")
|
||||||
|
current[part] = {}
|
||||||
|
current = current.setdefault(part, {})
|
||||||
|
# 最后一层
|
||||||
|
last_part = parts[-1]
|
||||||
|
if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict):
|
||||||
|
_deep_merge(current[last_part], value)
|
||||||
|
else:
|
||||||
|
current[last_part] = value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||||
|
"""
|
||||||
|
def _is_list_type(tp: Any) -> bool:
|
||||||
|
origin = get_origin(tp)
|
||||||
|
return tp is list or origin is list
|
||||||
|
|
||||||
|
for key, schema_val in schema_part.items():
|
||||||
|
if key not in config_part:
|
||||||
|
continue
|
||||||
|
value = config_part[key]
|
||||||
|
if isinstance(schema_val, ConfigField):
|
||||||
|
if _is_list_type(schema_val.type) and isinstance(value, str):
|
||||||
|
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
|
||||||
|
elif isinstance(schema_val, dict) and isinstance(value, dict):
|
||||||
|
coerce_types(schema_val, value)
|
||||||
|
|
||||||
|
|
||||||
|
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
按 plugin_id 或 plugin_name 查找已加载的插件实例。
|
||||||
|
局部导入 plugin_manager 以规避循环依赖。
|
||||||
|
"""
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||||
|
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||||
|
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
|
||||||
|
return instance
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ============ 请求/响应模型 ============
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1388,6 +1471,14 @@ async def update_plugin_config(
|
||||||
logger.info(f"更新插件配置: {plugin_id}")
|
logger.info(f"更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
plugin_instance = find_plugin_instance(plugin_id)
|
||||||
|
|
||||||
|
# 纠正 WebUI 提交的数据结构(扁平键与字符串列表)
|
||||||
|
if plugin_instance and isinstance(request.config, dict):
|
||||||
|
request.config = normalize_dotted_keys(request.config)
|
||||||
|
if isinstance(plugin_instance.config_schema, dict):
|
||||||
|
coerce_types(plugin_instance.config_schema, request.config)
|
||||||
|
|
||||||
# 查找插件目录
|
# 查找插件目录
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins")
|
||||||
plugin_path = None
|
plugin_path = None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue