Merge pull request #1398 from Ronifue/dev

fix: 修复保存toml时的空行累计bug和注释丢失问题
pull/1414/head
墨梓柒 2025-12-03 10:50:58 +08:00 committed by GitHub
commit f8446e6bf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 55 deletions

View File

@ -4,6 +4,7 @@ TOML 工具函数
提供 TOML 文件的格式化保存功能确保数组等元素以美观的多行格式输出 提供 TOML 文件的格式化保存功能确保数组等元素以美观的多行格式输出
""" """
import re
from typing import Any from typing import Any
import tomlkit import tomlkit
from tomlkit.items import AoT, Table, Array from tomlkit.items import AoT, Table, Array
@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj return obj
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None: def _update_toml_doc(target: Any, source: Any) -> None:
"""格式化 TOML 数据并保存到文件""" """
递归合并字典 source 的值更新到 target 保留 target 的注释和格式
- 已存在的键更新值递归处理嵌套字典
- 新增的键添加到 target
- 跳过 version 字段
"""
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
return
for key, value in source.items():
if key == "version":
continue
if key in target:
# 已存在的键:递归更新或直接赋值
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, dict):
_update_toml_doc(target_value, value)
else:
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
else:
# 新增的键:添加到 target
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
def save_toml_with_format(
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
) -> None:
"""
格式化 TOML 数据并保存到文件
Args:
data: 要保存的数据dict tomlkit 文档
file_path: 保存路径
multiline_threshold: 数组多行格式化阈值-1 表示不格式化
preserve_comments: 是否保留原文件的注释和格式默认 True
若为 True 且文件已存在且 data 不是 tomlkit 文档会先读取原文件再将 data 合并进去
"""
import os
from tomlkit import TOMLDocument
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
with open(file_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
_update_toml_doc(doc, data)
data = doc
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(formatted, f) f.write(output)
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str: def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
"""格式化 TOML 数据并返回字符串""" """格式化 TOML 数据并返回字符串"""
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
return tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output)

View File

@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated from typing import Any, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import ( from src.config.official_configs import (
BotConfig, BotConfig,
@ -51,40 +51,6 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"]) router = APIRouter(prefix="/config", tags=["config"])
# ===== 辅助函数 =====
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典保留 target 中的注释和格式
source 的值更新到 target 仅更新已存在的键
Args:
target: 目标字典tomlkit 对象包含注释
source: 源字典普通 dict list
"""
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
if key == "version":
continue # 跳过版本号
if key in target:
target_value = target[key]
# 递归处理嵌套字典
if isinstance(value, dict) and isinstance(target_value, dict):
_update_dict_preserve_comments(target_value, value)
else:
# 使用 tomlkit.item 保持类型
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
# ===== 架构获取接口 ===== # ===== 架构获取接口 =====
@ -238,7 +204,7 @@ async def update_bot_config(config_data: ConfigBody):
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行 # 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
save_toml_with_format(config_data, config_path) save_toml_with_format(config_data, config_path)
@ -261,7 +227,7 @@ async def update_model_config(config_data: ConfigBody):
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行 # 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "model_config.toml") config_path = os.path.join(CONFIG_DIR, "model_config.toml")
save_toml_with_format(config_data, config_path) save_toml_with_format(config_data, config_path)
@ -300,7 +266,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
config_data[section_name] = section_data config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict): elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并 # 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data) _update_toml_doc(config_data[section_name], section_data)
else: else:
# 其他类型直接替换 # 其他类型直接替换
config_data[section_name] = section_data config_data[section_name] = section_data
@ -398,7 +364,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
config_data[section_name] = section_data config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict): elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并 # 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data) _update_toml_doc(config_data[section_name], section_data)
else: else:
# 其他类型直接替换 # 其他类型直接替换
config_data[section_name] = section_data config_data[section_name] = section_data

View File

@ -1420,18 +1420,8 @@ async def update_plugin_config(
shutil.copy(config_path, backup_path) shutil.copy(config_path, backup_path)
logger.info(f"已备份配置文件: {backup_path}") logger.info(f"已备份配置文件: {backup_path}")
# 写入新配置(使用 tomlkit 保留注释) # 写入新配置(自动保留注释和格式)
import tomlkit save_toml_with_format(request.config, str(config_path))
# 先读取原配置以保留注释和格式
existing_doc = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
existing_doc = tomlkit.load(f)
# 更新值
for key, value in request.config.items():
existing_doc[key] = value
save_toml_with_format(existing_doc, str(config_path))
logger.info(f"已更新插件配置: {plugin_id}") logger.info(f"已更新插件配置: {plugin_id}")