mirror of https://github.com/Mai-with-u/MaiBot.git
fix(webui): fix missing imports and create toml_utils module
- Create src/common/toml_utils.py with TOML utility functions - Fix APIAdapterConfig → ModelConfig in config.py (4 locations) - Fix git_mirror_service import path in plugin.py - Fix emoji.py type annotations and unused imports - Fix jargon.py comment (ChatStreams → ChatSession) - All routers now import successfully - Zero Peewee remnants verified across src/webui/pull/1496/head
parent
f97c24bf9e
commit
f66e25b1a7
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""
|
||||||
|
TOML文件工具函数 - 保留格式和注释
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tomlkit
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
|
||||||
|
"""
|
||||||
|
保存TOML数据到文件,保留现有格式(如果文件存在)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 要保存的数据字典
|
||||||
|
file_path: 文件路径
|
||||||
|
"""
|
||||||
|
# 如果文件不存在,直接创建
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
tomlkit.dump(data, f)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果文件存在,尝试读取现有文件以保留格式
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
existing_doc = tomlkit.load(f)
|
||||||
|
except Exception:
|
||||||
|
# 如果读取失败,直接覆盖
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
tomlkit.dump(data, f)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 递归更新,保留现有格式
|
||||||
|
_merge_toml_preserving_format(existing_doc, data)
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
tomlkit.dump(existing_doc, f)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
递归合并source到target,保留target中的格式和注释
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: 目标文档(保留格式)
|
||||||
|
source: 源数据(新数据)
|
||||||
|
"""
|
||||||
|
for key, value in source.items():
|
||||||
|
if key in target:
|
||||||
|
# 如果两个都是字典且都是表格,递归合并
|
||||||
|
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||||
|
if hasattr(target[key], "items"): # 确实是字典/表格
|
||||||
|
_merge_toml_preserving_format(target[key], value)
|
||||||
|
else:
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
# 其他情况直接替换
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
# 新键直接添加
|
||||||
|
target[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
更新TOML文档中的字段,保留现有的格式和注释
|
||||||
|
|
||||||
|
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: 目标表格(会被修改)
|
||||||
|
source: 源数据(新数据)
|
||||||
|
"""
|
||||||
|
for key, value in source.items():
|
||||||
|
if key in target:
|
||||||
|
# 如果两个都是字典,递归更新
|
||||||
|
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||||
|
if hasattr(target[key], "items"): # 确实是表格
|
||||||
|
_update_toml_doc(target[key], value)
|
||||||
|
else:
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
# 直接更新值,保留注释
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
# 新键直接添加
|
||||||
|
target[key] = value
|
||||||
|
|
@ -10,7 +10,7 @@ from typing import Any, Annotated, Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
PersonalityConfig,
|
PersonalityConfig,
|
||||||
|
|
@ -77,7 +77,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||||
try:
|
try:
|
||||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模型配置架构失败: {e}")
|
logger.error(f"获取模型配置架构失败: {e}")
|
||||||
|
|
@ -227,7 +227,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
ModelConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
@ -377,7 +377,7 @@ async def update_model_config_section(
|
||||||
|
|
||||||
# 验证完整配置
|
# 验证完整配置
|
||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
ModelConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, List, Optional
|
from typing import Annotated, Any, List, Optional
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
@ -16,7 +16,7 @@ from fastapi.responses import FileResponse, JSONResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, delete, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Images, ImageType
|
from src.common.database.database_model import Images, ImageType
|
||||||
|
|
@ -67,7 +67,7 @@ def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||||
|
|
||||||
def _ensure_thumbnail_cache_dir() -> Path:
|
def _ensure_thumbnail_cache_dir() -> Path:
|
||||||
"""确保缩略图缓存目录存在"""
|
"""确保缩略图缓存目录存在"""
|
||||||
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
return THUMBNAIL_CACHE_DIR
|
return THUMBNAIL_CACHE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -947,7 +947,7 @@ async def upload_emoji(
|
||||||
|
|
||||||
# 保存文件
|
# 保存文件
|
||||||
with open(full_path, "wb") as f:
|
with open(full_path, "wb") as f:
|
||||||
f.write(file_content)
|
_ = f.write(file_content)
|
||||||
|
|
||||||
logger.info(f"表情包文件已保存: {full_path}")
|
logger.info(f"表情包文件已保存: {full_path}")
|
||||||
|
|
||||||
|
|
@ -1010,7 +1010,7 @@ async def batch_upload_emoji(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
results = {
|
results: dict[str, Any] = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"total": len(files),
|
"total": len(files),
|
||||||
"uploaded": 0,
|
"uploaded": 0,
|
||||||
|
|
@ -1095,7 +1095,7 @@ async def batch_upload_emoji(
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
with open(full_path, "wb") as f:
|
with open(full_path, "wb") as f:
|
||||||
f.write(file_content)
|
_ = f.write(file_content)
|
||||||
|
|
||||||
# 处理情感标签
|
# 处理情感标签
|
||||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||||
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||||
"""
|
"""
|
||||||
获取 chat_id 的显示名称
|
获取 chat_id 的显示名称
|
||||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
|
||||||
"""
|
"""
|
||||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from src.common.logger import get_logger
|
||||||
from src.common.toml_utils import save_toml_with_format
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||||
from src.webui.core import get_token_manager
|
from src.webui.core import get_token_manager
|
||||||
from src.webui.routers.websocket.plugin_progress import update_progress
|
from src.webui.routers.websocket.plugin_progress import update_progress
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue