mirror of https://github.com/Mai-with-u/MaiBot.git
fix(webui): fix missing imports and create toml_utils module
- Create src/common/toml_utils.py with TOML utility functions - Fix APIAdapterConfig → ModelConfig in config.py (4 locations) - Fix git_mirror_service import path in plugin.py - Fix emoji.py type annotations and unused imports - Fix jargon.py comment (ChatStreams → ChatSession) - All routers now import successfully - Zero Peewee remnants verified across src/webui/pull/1496/head
parent
f97c24bf9e
commit
f66e25b1a7
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
TOML文件工具函数 - 保留格式和注释
|
||||
"""
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from typing import Any
|
||||
|
||||
|
||||
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
|
||||
"""
|
||||
保存TOML数据到文件,保留现有格式(如果文件存在)
|
||||
|
||||
Args:
|
||||
data: 要保存的数据字典
|
||||
file_path: 文件路径
|
||||
"""
|
||||
# 如果文件不存在,直接创建
|
||||
if not os.path.exists(file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 如果文件存在,尝试读取现有文件以保留格式
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
existing_doc = tomlkit.load(f)
|
||||
except Exception:
|
||||
# 如果读取失败,直接覆盖
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 递归更新,保留现有格式
|
||||
_merge_toml_preserving_format(existing_doc, data)
|
||||
|
||||
# 保存
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(existing_doc, f)
|
||||
|
||||
|
||||
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
递归合并source到target,保留target中的格式和注释
|
||||
|
||||
Args:
|
||||
target: 目标文档(保留格式)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典且都是表格,递归合并
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是字典/表格
|
||||
_merge_toml_preserving_format(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 其他情况直接替换
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
更新TOML文档中的字段,保留现有的格式和注释
|
||||
|
||||
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
|
||||
|
||||
Args:
|
||||
target: 目标表格(会被修改)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典,递归更新
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是表格
|
||||
_update_toml_doc(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 直接更新值,保留注释
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
|
|
@ -10,7 +10,7 @@ from typing import Any, Annotated, Optional
|
|||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
|
|
@ -77,7 +77,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
|||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
|
|
@ -227,7 +227,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
|
|||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
ModelConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
|
|
@ -377,7 +377,7 @@ async def update_model_config_section(
|
|||
|
||||
# 验证完整配置
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
ModelConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List, Optional
|
||||
from typing import Annotated, Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
|
|
@ -16,7 +16,7 @@ from fastapi.responses import FileResponse, JSONResponse
|
|||
from pydantic import BaseModel
|
||||
from PIL import Image
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, delete, select
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
|
|
@ -67,7 +67,7 @@ def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
|||
|
||||
def _ensure_thumbnail_cache_dir() -> Path:
|
||||
"""确保缩略图缓存目录存在"""
|
||||
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return THUMBNAIL_CACHE_DIR
|
||||
|
||||
|
||||
|
|
@ -947,7 +947,7 @@ async def upload_emoji(
|
|||
|
||||
# 保存文件
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
_ = f.write(file_content)
|
||||
|
||||
logger.info(f"表情包文件已保存: {full_path}")
|
||||
|
||||
|
|
@ -1010,7 +1010,7 @@ async def batch_upload_emoji(
|
|||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
results = {
|
||||
results: dict[str, Any] = {
|
||||
"success": True,
|
||||
"total": len(files),
|
||||
"uploaded": 0,
|
||||
|
|
@ -1095,7 +1095,7 @@ async def batch_upload_emoji(
|
|||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
_ = f.write(file_content)
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
|||
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from src.common.logger import get_logger
|
|||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.plugin_progress import update_progress
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue