mirror of https://github.com/Mai-with-u/MaiBot.git
commit
f8446e6bf7
|
|
@ -4,6 +4,7 @@ TOML 工具函数
|
||||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from tomlkit.items import AoT, Table, Array
|
from tomlkit.items import AoT, Table, Array
|
||||||
|
|
@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
|
def _update_toml_doc(target: Any, source: Any) -> None:
|
||||||
"""格式化 TOML 数据并保存到文件"""
|
"""
|
||||||
|
递归合并字典,将 source 的值更新到 target 中,保留 target 的注释和格式。
|
||||||
|
- 已存在的键:更新值(递归处理嵌套字典)
|
||||||
|
- 新增的键:添加到 target
|
||||||
|
- 跳过 version 字段
|
||||||
|
"""
|
||||||
|
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
for key, value in source.items():
|
||||||
|
if key == "version":
|
||||||
|
continue
|
||||||
|
if key in target:
|
||||||
|
# 已存在的键:递归更新或直接赋值
|
||||||
|
target_value = target[key]
|
||||||
|
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||||
|
_update_toml_doc(target_value, value)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
target[key] = tomlkit.item(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
# 新增的键:添加到 target
|
||||||
|
try:
|
||||||
|
target[key] = tomlkit.item(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
target[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def save_toml_with_format(
|
||||||
|
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
格式化 TOML 数据并保存到文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 要保存的数据(dict 或 tomlkit 文档)
|
||||||
|
file_path: 保存路径
|
||||||
|
multiline_threshold: 数组多行格式化阈值,-1 表示不格式化
|
||||||
|
preserve_comments: 是否保留原文件的注释和格式(默认 True)
|
||||||
|
若为 True 且文件已存在且 data 不是 tomlkit 文档,会先读取原文件,再将 data 合并进去
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from tomlkit import TOMLDocument
|
||||||
|
|
||||||
|
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
|
||||||
|
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
doc = tomlkit.load(f)
|
||||||
|
_update_toml_doc(doc, data)
|
||||||
|
data = doc
|
||||||
|
|
||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
|
output = tomlkit.dumps(formatted)
|
||||||
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
|
output = re.sub(r'\n{3,}', '\n\n', output)
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
tomlkit.dump(formatted, f)
|
f.write(output)
|
||||||
|
|
||||||
|
|
||||||
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||||
"""格式化 TOML 数据并返回字符串"""
|
"""格式化 TOML 数据并返回字符串"""
|
||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
return tomlkit.dumps(formatted)
|
output = tomlkit.dumps(formatted)
|
||||||
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
|
return re.sub(r'\n{3,}', '\n\n', output)
|
||||||
|
|
@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Body
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.toml_utils import save_toml_with_format
|
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
|
|
@ -51,40 +51,6 @@ PathBody = Annotated[dict[str, str], Body()]
|
||||||
router = APIRouter(prefix="/config", tags=["config"])
|
router = APIRouter(prefix="/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
# ===== 辅助函数 =====
|
|
||||||
|
|
||||||
|
|
||||||
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
|
||||||
"""
|
|
||||||
递归合并字典,保留 target 中的注释和格式
|
|
||||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target: 目标字典(tomlkit 对象,包含注释)
|
|
||||||
source: 源字典(普通 dict 或 list)
|
|
||||||
"""
|
|
||||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
|
||||||
if isinstance(source, list):
|
|
||||||
return # 调用者需要直接赋值
|
|
||||||
|
|
||||||
# 如果都是字典,递归合并
|
|
||||||
if isinstance(source, dict) and isinstance(target, dict):
|
|
||||||
for key, value in source.items():
|
|
||||||
if key == "version":
|
|
||||||
continue # 跳过版本号
|
|
||||||
if key in target:
|
|
||||||
target_value = target[key]
|
|
||||||
# 递归处理嵌套字典
|
|
||||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
|
||||||
_update_dict_preserve_comments(target_value, value)
|
|
||||||
else:
|
|
||||||
# 使用 tomlkit.item 保持类型
|
|
||||||
try:
|
|
||||||
target[key] = tomlkit.item(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
target[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
# ===== 架构获取接口 =====
|
# ===== 架构获取接口 =====
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -238,7 +204,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件(格式化数组为多行)
|
# 保存配置文件(自动保留注释和格式)
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
save_toml_with_format(config_data, config_path)
|
save_toml_with_format(config_data, config_path)
|
||||||
|
|
||||||
|
|
@ -261,7 +227,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件(格式化数组为多行)
|
# 保存配置文件(自动保留注释和格式)
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
save_toml_with_format(config_data, config_path)
|
save_toml_with_format(config_data, config_path)
|
||||||
|
|
||||||
|
|
@ -300,7 +266,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||||
config_data[section_name] = section_data
|
config_data[section_name] = section_data
|
||||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||||
# 字典递归合并
|
# 字典递归合并
|
||||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
_update_toml_doc(config_data[section_name], section_data)
|
||||||
else:
|
else:
|
||||||
# 其他类型直接替换
|
# 其他类型直接替换
|
||||||
config_data[section_name] = section_data
|
config_data[section_name] = section_data
|
||||||
|
|
@ -398,7 +364,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||||
config_data[section_name] = section_data
|
config_data[section_name] = section_data
|
||||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||||
# 字典递归合并
|
# 字典递归合并
|
||||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
_update_toml_doc(config_data[section_name], section_data)
|
||||||
else:
|
else:
|
||||||
# 其他类型直接替换
|
# 其他类型直接替换
|
||||||
config_data[section_name] = section_data
|
config_data[section_name] = section_data
|
||||||
|
|
|
||||||
|
|
@ -1420,18 +1420,8 @@ async def update_plugin_config(
|
||||||
shutil.copy(config_path, backup_path)
|
shutil.copy(config_path, backup_path)
|
||||||
logger.info(f"已备份配置文件: {backup_path}")
|
logger.info(f"已备份配置文件: {backup_path}")
|
||||||
|
|
||||||
# 写入新配置(使用 tomlkit 保留注释)
|
# 写入新配置(自动保留注释和格式)
|
||||||
import tomlkit
|
save_toml_with_format(request.config, str(config_path))
|
||||||
|
|
||||||
# 先读取原配置以保留注释和格式
|
|
||||||
existing_doc = tomlkit.document()
|
|
||||||
if config_path.exists():
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
existing_doc = tomlkit.load(f)
|
|
||||||
# 更新值
|
|
||||||
for key, value in request.config.items():
|
|
||||||
existing_doc[key] = value
|
|
||||||
save_toml_with_format(existing_doc, str(config_path))
|
|
||||||
|
|
||||||
logger.info(f"已更新插件配置: {plugin_id}")
|
logger.info(f"已更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue