diff --git a/src/webui/app.py b/src/webui/app.py new file mode 100644 index 00000000..3731f962 --- /dev/null +++ b/src/webui/app.py @@ -0,0 +1,161 @@ +"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例""" + +import mimetypes +from pathlib import Path +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from src.common.logger import get_logger + +logger = get_logger("webui.app") + + +def create_app( + host: str = "0.0.0.0", + port: int = 8001, + enable_static: bool = True, +) -> FastAPI: + """ + 创建 WebUI FastAPI 应用实例 + + Args: + host: 服务器主机地址 + port: 服务器端口 + enable_static: 是否启用静态文件服务 + """ + app = FastAPI(title="MaiBot WebUI") + + _setup_anti_crawler(app) + _setup_cors(app, port) + _register_api_routes(app) + _setup_robots_txt(app) + + if enable_static: + _setup_static_files(app) + + return app + + +def _setup_cors(app: FastAPI, port: int): + app.add_middleware( + CORSMiddleware, + allow_origins=[ + "http://localhost:5173", + "http://127.0.0.1:5173", + "http://localhost:7999", + "http://127.0.0.1:7999", + f"http://localhost:{port}", + f"http://127.0.0.1:{port}", + ], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=[ + "Content-Type", + "Authorization", + "Accept", + "Origin", + "X-Requested-With", + ], + expose_headers=["Content-Length", "Content-Type"], + ) + logger.debug("✅ CORS 中间件已配置") + + +def _setup_anti_crawler(app: FastAPI): + try: + from src.webui.middleware import AntiCrawlerMiddleware + from src.config.config import global_config + + anti_crawler_mode = global_config.webui.anti_crawler_mode + app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode) + + mode_descriptions = { + "false": "已禁用", + "strict": "严格模式", + "loose": "宽松模式", + "basic": "基础模式", + } + mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式") + logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}") + except Exception as e: + logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True) + + +def _setup_robots_txt(app: FastAPI): + try: + from src.webui.middleware import create_robots_txt_response + + @app.get("/robots.txt", include_in_schema=False) + async def robots_txt(): + return create_robots_txt_response() + + logger.debug("✅ robots.txt 路由已注册") + except Exception as e: + logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True) + + +def _register_api_routes(app: FastAPI): + try: + from src.webui.routers import get_all_routers + + for router in get_all_routers(): + app.include_router(router) + + logger.info("✅ WebUI API 路由已注册") + except Exception as e: + logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True) + + +def _setup_static_files(app: FastAPI): + mimetypes.init() + mimetypes.add_type("application/javascript", ".js") + mimetypes.add_type("application/javascript", ".mjs") + mimetypes.add_type("text/css", ".css") + mimetypes.add_type("application/json", ".json") + + base_dir = Path(__file__).parent.parent.parent + static_path = base_dir / "webui" / "dist" + + if not static_path.exists(): + logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}") + logger.warning("💡 请先构建前端: cd webui && npm run build") + return + + if not (static_path / "index.html").exists(): + logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}") + logger.warning("💡 请确认前端已正确构建") + return + + @app.get("/{full_path:path}", include_in_schema=False) + async def serve_spa(full_path: str): + if not full_path or full_path == "/": + response = FileResponse(static_path / "index.html", media_type="text/html") + response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" + return response + + file_path = static_path / full_path + if file_path.is_file() and file_path.exists(): + media_type = mimetypes.guess_type(str(file_path))[0] + response = FileResponse(file_path, media_type=media_type) + if str(file_path).endswith(".html"): + response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" + return response + + response = FileResponse(static_path / "index.html", media_type="text/html") + response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" + return response + + logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}") + + +def show_access_token(): + """显示 WebUI Access Token(供启动时调用)""" + try: + from src.webui.core import get_token_manager + + token_manager = get_token_manager() + current_token = token_manager.get_token() + logger.info(f"🔑 WebUI Access Token: {current_token}") + logger.info("💡 请使用此 Token 登录 WebUI") + except Exception as e: + logger.error(f"❌ 获取 Access Token 失败: {e}") diff --git a/src/webui/core/__init__.py b/src/webui/core/__init__.py new file mode 100644 index 00000000..3124e897 --- /dev/null +++ b/src/webui/core/__init__.py @@ -0,0 +1,30 @@ +from .security import TokenManager, get_token_manager +from .rate_limiter import ( + RateLimiter, + get_rate_limiter, + check_auth_rate_limit, + check_api_rate_limit, +) +from .auth import ( + COOKIE_NAME, + COOKIE_MAX_AGE, + get_current_token, + set_auth_cookie, + clear_auth_cookie, + verify_auth_token_from_cookie_or_header, +) + +__all__ = [ + "TokenManager", + "get_token_manager", + "RateLimiter", + "get_rate_limiter", + "check_auth_rate_limit", + "check_api_rate_limit", + "COOKIE_NAME", + "COOKIE_MAX_AGE", + "get_current_token", + "set_auth_cookie", + "clear_auth_cookie", + "verify_auth_token_from_cookie_or_header", +] diff --git a/src/webui/auth.py b/src/webui/core/auth.py similarity index 96% rename from src/webui/auth.py rename to src/webui/core/auth.py index db0fc675..ff02b789 100644 --- a/src/webui/auth.py +++ b/src/webui/core/auth.py @@ -7,7 +7,7 @@ from typing import Optional from fastapi import HTTPException, Cookie, Header, Response, Request from src.common.logger import get_logger from src.config.config import global_config -from .token_manager import get_token_manager +from .security import get_token_manager logger = get_logger("webui.auth") @@ -27,7 +27,7 @@ def _is_secure_environment() -> bool: if global_config.webui.secure_cookie: logger.info("配置中启用了 secure_cookie") return True - + # 检查是否是生产环境 if global_config.webui.mode == "production": logger.info("WebUI运行在生产模式,启用 secure cookie") @@ -88,7 +88,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] = """ # 根据环境和实际请求协议决定安全设置 is_secure = _is_secure_environment() - + # 如果提供了 request,检测实际使用的协议 if request: # 检查 X-Forwarded-Proto header(代理/负载均衡器) @@ -100,7 +100,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] = # 检查 request.url.scheme is_https = request.url.scheme == "https" logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}") - + # 如果是 HTTP 连接,强制禁用 secure 标志 if not is_https and is_secure: logger.warning("=" * 80) @@ -110,7 +110,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] = logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头") logger.warning("=" * 80) is_secure = False - + # 设置 Cookie response.set_cookie( key=COOKIE_NAME, @@ -121,8 +121,10 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] = secure=is_secure, # 根据实际协议决定 path="/", # 确保 Cookie 在所有路径下可用 ) - - logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})") + + logger.info( + f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})" + ) logger.debug(f"完整 token 前缀: {token[:20]}...") diff --git a/src/webui/rate_limiter.py b/src/webui/core/rate_limiter.py similarity index 100% rename from src/webui/rate_limiter.py rename to src/webui/core/rate_limiter.py diff --git a/src/webui/token_manager.py b/src/webui/core/security.py similarity index 98% rename from src/webui/token_manager.py rename to src/webui/core/security.py index bd1e5fbb..7e5e6891 100644 --- a/src/webui/token_manager.py +++ b/src/webui/core/security.py @@ -24,8 +24,8 @@ class TokenManager: config_path: 配置文件路径,默认为项目根目录的 data/webui.json """ if config_path is None: - # 获取项目根目录 (src/webui -> src -> 根目录) - project_root = Path(__file__).parent.parent.parent + # 获取项目根目录 (src/webui/core -> src/webui -> src -> 根目录) + project_root = Path(__file__).parent.parent.parent.parent config_path = project_root / "data" / "webui.json" self.config_path = config_path diff --git a/src/webui/dependencies.py b/src/webui/dependencies.py new file mode 100644 index 00000000..a7395522 --- /dev/null +++ b/src/webui/dependencies.py @@ -0,0 +1,87 @@ +from typing import Optional +from fastapi import Depends, Cookie, Header, Request, HTTPException +from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit + + +async def require_auth( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> str: + """ + FastAPI 依赖:要求有效认证 + + 用于保护需要认证的路由,自动从 Cookie 或 Header 获取并验证 token + + Returns: + 验证通过的 token + + Raises: + HTTPException 401: 认证失败 + """ + return get_current_token(request, maibot_session, authorization) + + +async def require_auth_with_rate_limit( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), + _rate_limit: None = Depends(check_auth_rate_limit), +) -> str: + """ + FastAPI 依赖:要求有效认证 + 频率限制 + + 组合了认证检查和频率限制,适用于敏感操作 + + Returns: + 验证通过的 token + + Raises: + HTTPException 401: 认证失败 + HTTPException 429: 请求过于频繁 + """ + return get_current_token(request, maibot_session, authorization) + + +def get_optional_token( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> Optional[str]: + """ + FastAPI 依赖:可选获取 token(不验证) + + 用于某些需要知道是否有 token 但不强制验证的场景 + + Returns: + token 字符串或 None + """ + if maibot_session: + return maibot_session + if authorization and authorization.startswith("Bearer "): + return authorization.replace("Bearer ", "") + return None + + +async def verify_token_optional( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """ + FastAPI 依赖:可选验证 token + + 返回 token 是否有效,不抛出异常 + + Returns: + True 如果 token 有效,否则 False + """ + token = None + if maibot_session: + token = maibot_session + elif authorization and authorization.startswith("Bearer "): + token = authorization.replace("Bearer ", "") + + if not token: + return False + + token_manager = get_token_manager() + return token_manager.verify_token(token) diff --git a/src/webui/middleware/__init__.py b/src/webui/middleware/__init__.py new file mode 100644 index 00000000..275b1daa --- /dev/null +++ b/src/webui/middleware/__init__.py @@ -0,0 +1,17 @@ +from .anti_crawler import ( + AntiCrawlerMiddleware, + create_robots_txt_response, + ANTI_CRAWLER_MODE, + ALLOWED_IPS, + TRUSTED_PROXIES, + TRUST_XFF, +) + +__all__ = [ + "AntiCrawlerMiddleware", + "create_robots_txt_response", + "ANTI_CRAWLER_MODE", + "ALLOWED_IPS", + "TRUSTED_PROXIES", + "TRUST_XFF", +] diff --git a/src/webui/anti_crawler.py b/src/webui/middleware/anti_crawler.py similarity index 100% rename from src/webui/anti_crawler.py rename to src/webui/middleware/anti_crawler.py diff --git a/src/webui/routers/__init__.py b/src/webui/routers/__init__.py new file mode 100644 index 00000000..a306bbe2 --- /dev/null +++ b/src/webui/routers/__init__.py @@ -0,0 +1,35 @@ +"""WebUI 路由聚合模块 - 提供统一的路由注册接口""" + +from fastapi import APIRouter + + +def get_api_router() -> APIRouter: + """获取主 API 路由器(包含所有子路由)""" + from src.webui.routes import router as main_router + + return main_router + + +def get_all_routers() -> list[APIRouter]: + """获取所有需要独立注册的路由器列表""" + from src.webui.routes import router as main_router + from src.webui.routers.websocket.logs import router as logs_router + from src.webui.routers.knowledge import router as knowledge_router + from src.webui.routers.chat import router as chat_router + from src.webui.api.planner import router as planner_router + from src.webui.api.replier import router as replier_router + + return [ + main_router, + logs_router, + knowledge_router, + chat_router, + planner_router, + replier_router, + ] + + +__all__ = [ + "get_api_router", + "get_all_routers", +] diff --git a/src/webui/annual_report_routes.py b/src/webui/routers/annual_report.py similarity index 99% rename from src/webui/annual_report_routes.py rename to src/webui/routers/annual_report.py index ff3ec00f..68e1f4b9 100644 --- a/src/webui/annual_report_routes.py +++ b/src/webui/routers/annual_report.py @@ -18,7 +18,7 @@ from src.common.database.database_model import ( ActionRecords, Jargon, ) -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui.annual_report") diff --git a/src/webui/chat_routes.py b/src/webui/routers/chat.py similarity index 99% rename from src/webui/chat_routes.py rename to src/webui/routers/chat.py index 6535b9e9..c7f847ea 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/routers/chat.py @@ -15,9 +15,8 @@ from src.common.logger import get_logger from src.common.database.database_model import Messages, PersonInfo from src.config.config import global_config from src.chat.message_receive.bot import chat_bot -from src.webui.auth import verify_auth_token_from_cookie_or_header -from src.webui.token_manager import get_token_manager -from src.webui.ws_auth import verify_ws_token +from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager +from src.webui.routers.websocket.auth import verify_ws_token logger = get_logger("webui.chat") diff --git a/src/webui/config_routes.py b/src/webui/routers/config.py similarity index 99% rename from src/webui/config_routes.py rename to src/webui/routers/config.py index 6a028927..db176cbe 100644 --- a/src/webui/config_routes.py +++ b/src/webui/routers/config.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header from typing import Any, Annotated, Optional from src.common.logger import get_logger -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header from src.common.toml_utils import save_toml_with_format, _update_toml_doc from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.official_configs import ( diff --git a/src/webui/emoji_routes.py b/src/webui/routers/emoji.py similarity index 99% rename from src/webui/emoji_routes.py rename to src/webui/routers/emoji.py index 90b2d60b..ea09f68e 100644 --- a/src/webui/emoji_routes.py +++ b/src/webui/routers/emoji.py @@ -6,8 +6,7 @@ from pydantic import BaseModel from typing import Optional, List, Annotated from src.common.logger import get_logger from src.common.database.database_model import Emoji -from .token_manager import get_token_manager -from .auth import verify_auth_token_from_cookie_or_header +from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header import time import os import hashlib diff --git a/src/webui/expression_routes.py b/src/webui/routers/expression.py similarity index 90% rename from src/webui/expression_routes.py rename to src/webui/routers/expression.py index d3586947..1e78d982 100644 --- a/src/webui/expression_routes.py +++ b/src/webui/routers/expression.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from typing import Optional, List, Dict from src.common.logger import get_logger from src.common.database.database_model import Expression, ChatStreams -from .auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header import time logger = get_logger("webui.expression") @@ -224,10 +224,7 @@ async def get_expression_list( # 搜索过滤 if search: - query = query.where( - (Expression.situation.contains(search)) - | (Expression.style.contains(search)) - ) + query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search))) # 聊天ID过滤 if chat_id: @@ -363,21 +360,21 @@ async def update_expression( if request.require_unchecked and expression.checked: raise HTTPException( status_code=409, - detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表" + detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表", ) # 只更新提供的字段 update_data = request.model_dump(exclude_unset=True) - + # 移除 require_unchecked,它不是数据库字段 - update_data.pop('require_unchecked', None) + update_data.pop("require_unchecked", None) if not update_data: raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") # 如果更新了 checked 或 rejected,标记为用户修改 - if 'checked' in update_data or 'rejected' in update_data: - update_data['modified_by'] = 'user' + if "checked" in update_data or "rejected" in update_data: + update_data["modified_by"] = "user" # 更新最后活跃时间 update_data["last_active_time"] = time.time() @@ -542,8 +539,10 @@ async def get_expression_stats( # ============ 审核相关接口 ============ + class ReviewStatsResponse(BaseModel): """审核统计响应""" + total: int unchecked: int passed: int @@ -553,10 +552,7 @@ class ReviewStatsResponse(BaseModel): @router.get("/review/stats", response_model=ReviewStatsResponse) -async def get_review_stats( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None) -): +async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取审核统计数据 @@ -568,14 +564,10 @@ async def get_review_stats( total = Expression.select().count() unchecked = Expression.select().where(Expression.checked == False).count() - passed = Expression.select().where( - (Expression.checked == True) & (Expression.rejected == False) - ).count() - rejected = Expression.select().where( - (Expression.checked == True) & (Expression.rejected == True) - ).count() - ai_checked = Expression.select().where(Expression.modified_by == 'ai').count() - user_checked = Expression.select().where(Expression.modified_by == 'user').count() + passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count() + rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count() + ai_checked = Expression.select().where(Expression.modified_by == "ai").count() + user_checked = Expression.select().where(Expression.modified_by == "user").count() return ReviewStatsResponse( total=total, @@ -583,7 +575,7 @@ async def get_review_stats( passed=passed, rejected=rejected, ai_checked=ai_checked, - user_checked=user_checked + user_checked=user_checked, ) except HTTPException: @@ -595,6 +587,7 @@ async def get_review_stats( class ReviewListResponse(BaseModel): """审核列表响应""" + success: bool total: int page: int @@ -641,9 +634,7 @@ async def get_review_list( # 搜索过滤 if search: - query = query.where( - (Expression.situation.contains(search)) | (Expression.style.contains(search)) - ) + query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search))) # 聊天ID过滤 if chat_id: @@ -651,10 +642,8 @@ async def get_review_list( # 排序:创建时间倒序 from peewee import Case - query = query.order_by( - Case(None, [(Expression.create_date.is_null(), 1)], 0), - Expression.create_date.desc() - ) + + query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc()) total = query.count() offset = (page - 1) * page_size @@ -665,7 +654,7 @@ async def get_review_list( total=total, page=page, page_size=page_size, - data=[expression_to_response(expr) for expr in expressions] + data=[expression_to_response(expr) for expr in expressions], ) except HTTPException: @@ -677,6 +666,7 @@ async def get_review_list( class BatchReviewItem(BaseModel): """批量审核项""" + id: int rejected: bool require_unchecked: bool = True # 默认要求未检查状态 @@ -684,11 +674,13 @@ class BatchReviewItem(BaseModel): class BatchReviewRequest(BaseModel): """批量审核请求""" + items: List[BatchReviewItem] class BatchReviewResultItem(BaseModel): """批量审核结果项""" + id: int success: bool message: str @@ -696,6 +688,7 @@ class BatchReviewResultItem(BaseModel): class BatchReviewResponse(BaseModel): """批量审核响应""" + success: bool total: int succeeded: int @@ -733,54 +726,44 @@ async def batch_review_expressions( expression = Expression.get_or_none(Expression.id == item.id) if not expression: - results.append(BatchReviewResultItem( - id=item.id, - success=False, - message=f"未找到 ID 为 {item.id} 的表达方式" - )) + results.append( + BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式") + ) failed += 1 continue # 冲突检测 if item.require_unchecked and expression.checked: - results.append(BatchReviewResultItem( - id=item.id, - success=False, - message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查" - )) + results.append( + BatchReviewResultItem( + id=item.id, + success=False, + message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查", + ) + ) failed += 1 continue # 更新状态 expression.checked = True expression.rejected = item.rejected - expression.modified_by = 'user' + expression.modified_by = "user" expression.last_active_time = time.time() expression.save() - results.append(BatchReviewResultItem( - id=item.id, - success=True, - message="通过" if not item.rejected else "拒绝" - )) + results.append( + BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝") + ) succeeded += 1 except Exception as e: - results.append(BatchReviewResultItem( - id=item.id, - success=False, - message=str(e) - )) + results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e))) failed += 1 logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}") return BatchReviewResponse( - success=True, - total=len(request.items), - succeeded=succeeded, - failed=failed, - results=results + success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results ) except HTTPException: diff --git a/src/webui/jargon_routes.py b/src/webui/routers/jargon.py similarity index 100% rename from src/webui/jargon_routes.py rename to src/webui/routers/jargon.py diff --git a/src/webui/knowledge_routes.py b/src/webui/routers/knowledge.py similarity index 99% rename from src/webui/knowledge_routes.py rename to src/webui/routers/knowledge.py index fb540105..5959e0ac 100644 --- a/src/webui/knowledge_routes.py +++ b/src/webui/routers/knowledge.py @@ -4,7 +4,7 @@ from typing import List, Optional from fastapi import APIRouter, Query, Depends, Cookie, Header from pydantic import BaseModel import logging -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header from src.config.config import global_config logger = logging.getLogger(__name__) diff --git a/src/webui/logs_routes.py b/src/webui/routers/logs.py similarity index 100% rename from src/webui/logs_routes.py rename to src/webui/routers/logs.py diff --git a/src/webui/model_routes.py b/src/webui/routers/model.py similarity index 99% rename from src/webui/model_routes.py rename to src/webui/routers/model.py index a84241b9..b5ca4128 100644 --- a/src/webui/model_routes.py +++ b/src/webui/routers/model.py @@ -12,7 +12,7 @@ import tomlkit from src.common.logger import get_logger from src.config.config import CONFIG_DIR -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui") diff --git a/src/webui/person_routes.py b/src/webui/routers/person.py similarity index 99% rename from src/webui/person_routes.py rename to src/webui/routers/person.py index 9881d44e..1368c2a4 100644 --- a/src/webui/person_routes.py +++ b/src/webui/routers/person.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from typing import Optional, List, Dict from src.common.logger import get_logger from src.common.database.database_model import PersonInfo -from .auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header import json import time diff --git a/src/webui/plugin_routes.py b/src/webui/routers/plugin.py similarity index 99% rename from src/webui/plugin_routes.py rename to src/webui/routers/plugin.py index e85e1263..3ddcca34 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/routers/plugin.py @@ -7,9 +7,9 @@ from src.common.logger import get_logger from src.common.toml_utils import save_toml_with_format from src.config.config import MMC_VERSION from src.plugin_system.base.config_types import ConfigField -from .git_mirror_service import get_git_mirror_service, set_update_progress_callback -from .token_manager import get_token_manager -from .plugin_progress_ws import update_progress +from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback +from src.webui.core import get_token_manager +from src.webui.routers.websocket.plugin_progress import update_progress logger = get_logger("webui.plugin_routes") @@ -1370,21 +1370,19 @@ async def get_installed_plugins( seen_ids = {} # 记录 ID -> 路径的映射 unique_plugins = [] duplicates = [] - + for plugin in installed_plugins: plugin_id = plugin["id"] plugin_path = plugin["path"] - + if plugin_id not in seen_ids: seen_ids[plugin_id] = plugin_path unique_plugins.append(plugin) else: duplicates.append(plugin) first_path = seen_ids[plugin_id] - logger.warning( - f"重复插件 {plugin_id}: 保留 {first_path}, 跳过 {plugin_path}" - ) - + logger.warning(f"重复插件 {plugin_id}: 保留 {first_path}, 跳过 {plugin_path}") + if duplicates: logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重") @@ -1420,34 +1418,35 @@ async def get_local_plugin_readme( try: plugins_dir = Path("plugins") - + # 查找插件目录 plugin_path = None for folder in plugins_dir.iterdir(): if not folder.is_dir(): continue - + manifest_path = folder / "_manifest.json" if manifest_path.exists(): try: import json as json_module + with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) - + # 检查是否匹配 plugin_id if manifest.get("id") == plugin_id: plugin_path = folder break except Exception: continue - + if not plugin_path: return {"success": False, "error": "插件未安装"} - + # 查找 README 文件(支持多种命名) readme_files = ["README.md", "readme.md", "Readme.md", "README.MD"] readme_content = None - + for readme_name in readme_files: readme_path = plugin_path / readme_name if readme_path.exists(): @@ -1459,12 +1458,12 @@ async def get_local_plugin_readme( except Exception as e: logger.warning(f"读取 {readme_path} 失败: {e}") continue - + if readme_content: return {"success": True, "data": readme_content} else: return {"success": False, "error": "本地未找到 README 文件"} - + except Exception as e: logger.error(f"获取本地 README 失败: {e}", exc_info=True) return {"success": False, "error": str(e)} @@ -1756,10 +1755,10 @@ async def update_plugin_config_raw( # 验证 TOML 格式 import tomlkit - + if not isinstance(request.config, str): raise HTTPException(status_code=400, detail="配置必须是字符串格式的 TOML 内容") - + try: tomlkit.loads(request.config) except Exception as e: diff --git a/src/webui/statistics_routes.py b/src/webui/routers/statistics.py similarity index 99% rename from src/webui/statistics_routes.py rename to src/webui/routers/statistics.py index e5628538..40770bd6 100644 --- a/src/webui/statistics_routes.py +++ b/src/webui/routers/statistics.py @@ -8,7 +8,7 @@ from peewee import fn from src.common.logger import get_logger from src.common.database.database_model import LLMUsage, OnlineTime, Messages -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui.statistics") diff --git a/src/webui/routers/system.py b/src/webui/routers/system.py index b1d3729a..ac6ab324 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header from pydantic import BaseModel from src.config.config import MMC_VERSION from src.common.logger import get_logger -from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.core import verify_auth_token_from_cookie_or_header router = APIRouter(prefix="/system", tags=["system"]) logger = get_logger("webui_system") diff --git a/src/webui/routers/websocket/__init__.py b/src/webui/routers/websocket/__init__.py new file mode 100644 index 00000000..0acec62f --- /dev/null +++ b/src/webui/routers/websocket/__init__.py @@ -0,0 +1,9 @@ +from .logs import router as logs_router +from .plugin_progress import get_progress_router +from .auth import router as ws_auth_router + +__all__ = [ + "logs_router", + "get_progress_router", + "ws_auth_router", +] diff --git a/src/webui/ws_auth.py b/src/webui/routers/websocket/auth.py similarity index 98% rename from src/webui/ws_auth.py rename to src/webui/routers/websocket/auth.py index e6bb00e7..74246759 100644 --- a/src/webui/ws_auth.py +++ b/src/webui/routers/websocket/auth.py @@ -9,7 +9,7 @@ from typing import Optional import secrets import time from src.common.logger import get_logger -from src.webui.token_manager import get_token_manager +from src.webui.core import get_token_manager logger = get_logger("webui.ws_auth") router = APIRouter() diff --git a/src/webui/logs_ws.py b/src/webui/routers/websocket/logs.py similarity index 98% rename from src/webui/logs_ws.py rename to src/webui/routers/websocket/logs.py index 5ae92189..1d43f306 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/routers/websocket/logs.py @@ -5,8 +5,8 @@ from typing import Set, Optional import json from pathlib import Path from src.common.logger import get_logger -from src.webui.token_manager import get_token_manager -from src.webui.ws_auth import verify_ws_token +from src.webui.core import get_token_manager +from src.webui.routers.websocket.auth import verify_ws_token logger = get_logger("webui.logs_ws") router = APIRouter() diff --git a/src/webui/plugin_progress_ws.py b/src/webui/routers/websocket/plugin_progress.py similarity index 98% rename from src/webui/plugin_progress_ws.py rename to src/webui/routers/websocket/plugin_progress.py index 8d0a18c6..82ead6a5 100644 --- a/src/webui/plugin_progress_ws.py +++ b/src/webui/routers/websocket/plugin_progress.py @@ -5,8 +5,8 @@ from typing import Set, Dict, Any, Optional import json import asyncio from src.common.logger import get_logger -from src.webui.token_manager import get_token_manager -from src.webui.ws_auth import verify_ws_token +from src.webui.core import get_token_manager +from src.webui.routers.websocket.auth import verify_ws_token logger = get_logger("webui.plugin_progress") diff --git a/src/webui/routes.py b/src/webui/routes.py index 0479dc51..da45cb06 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -4,21 +4,25 @@ from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, from pydantic import BaseModel, Field from typing import Optional from src.common.logger import get_logger -from .token_manager import get_token_manager -from .auth import set_auth_cookie, clear_auth_cookie -from .rate_limiter import get_rate_limiter, check_auth_rate_limit -from .config_routes import router as config_router -from .statistics_routes import router as statistics_router -from .person_routes import router as person_router -from .expression_routes import router as expression_router -from .jargon_routes import router as jargon_router -from .emoji_routes import router as emoji_router -from .plugin_routes import router as plugin_router -from .plugin_progress_ws import get_progress_router -from .routers.system import router as system_router -from .model_routes import router as model_router -from .ws_auth import router as ws_auth_router -from .annual_report_routes import router as annual_report_router +from src.webui.core import ( + get_token_manager, + set_auth_cookie, + clear_auth_cookie, + get_rate_limiter, + check_auth_rate_limit, +) +from src.webui.routers.config import router as config_router +from src.webui.routers.statistics import router as statistics_router +from src.webui.routers.person import router as person_router +from src.webui.routers.expression import router as expression_router +from src.webui.routers.jargon import router as jargon_router +from src.webui.routers.emoji import router as emoji_router +from src.webui.routers.plugin import router as plugin_router +from src.webui.routers.websocket.plugin_progress import get_progress_router +from src.webui.routers.system import router as system_router +from src.webui.routers.model import router as model_router +from src.webui.routers.websocket.auth import router as ws_auth_router +from src.webui.routers.annual_report import router as annual_report_router logger = get_logger("webui.api") @@ -198,9 +202,11 @@ async def check_auth_status( """ try: token = None - + # 记录请求信息用于调试 - logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}") + logger.debug( + f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}" + ) # 优先从 Cookie 获取 if maibot_session: @@ -218,7 +224,7 @@ async def check_auth_status( token_manager = get_token_manager() is_valid = token_manager.verify_token(token) logger.debug(f"Token 验证结果: {is_valid}") - + if is_valid: return {"authenticated": True} else: diff --git a/src/webui/schemas/__init__.py b/src/webui/schemas/__init__.py new file mode 100644 index 00000000..c9d12a41 --- /dev/null +++ b/src/webui/schemas/__init__.py @@ -0,0 +1,109 @@ +"""WebUI Schemas - Pydantic models for API requests and responses.""" + +# Auth schemas +from .auth import ( + TokenVerifyRequest, + TokenVerifyResponse, + TokenUpdateRequest, + TokenUpdateResponse, + TokenRegenerateResponse, + FirstSetupStatusResponse, + CompleteSetupResponse, + ResetSetupResponse, +) + +# Statistics schemas +from .statistics import ( + StatisticsSummary, + ModelStatistics, + TimeSeriesData, + DashboardData, +) + +# Emoji schemas +from .emoji import ( + EmojiResponse, + EmojiListResponse, + EmojiDetailResponse, + EmojiUpdateRequest, + EmojiUpdateResponse, + EmojiDeleteResponse, + BatchDeleteRequest, + BatchDeleteResponse, + EmojiUploadResponse, + ThumbnailCacheStatsResponse, + ThumbnailCleanupResponse, + ThumbnailPreheatResponse, +) + +# Chat schemas +from .chat import ( + VirtualIdentityConfig, + ChatHistoryMessage, +) + +# Plugin schemas +from .plugin import ( + FetchRawFileRequest, + FetchRawFileResponse, + CloneRepositoryRequest, + CloneRepositoryResponse, + MirrorConfigResponse, + AvailableMirrorsResponse, + AddMirrorRequest, + UpdateMirrorRequest, + GitStatusResponse, + InstallPluginRequest, + VersionResponse, + UninstallPluginRequest, + UpdatePluginRequest, + UpdatePluginConfigRequest, +) + +__all__ = [ + # Auth + "TokenVerifyRequest", + "TokenVerifyResponse", + "TokenUpdateRequest", + "TokenUpdateResponse", + "TokenRegenerateResponse", + "FirstSetupStatusResponse", + "CompleteSetupResponse", + "ResetSetupResponse", + # Statistics + "StatisticsSummary", + "ModelStatistics", + "TimeSeriesData", + "DashboardData", + # Emoji + "EmojiResponse", + "EmojiListResponse", + "EmojiDetailResponse", + "EmojiUpdateRequest", + "EmojiUpdateResponse", + "EmojiDeleteResponse", + "BatchDeleteRequest", + "BatchDeleteResponse", + "EmojiUploadResponse", + "ThumbnailCacheStatsResponse", + "ThumbnailCleanupResponse", + "ThumbnailPreheatResponse", + # Chat + "VirtualIdentityConfig", + "ChatHistoryMessage", + # Plugin + "FetchRawFileRequest", + "FetchRawFileResponse", + "CloneRepositoryRequest", + "CloneRepositoryResponse", + "MirrorConfigResponse", + "AvailableMirrorsResponse", + "AddMirrorRequest", + "UpdateMirrorRequest", + "GitStatusResponse", + "InstallPluginRequest", + "VersionResponse", + "UninstallPluginRequest", + "UpdatePluginRequest", + "UpdatePluginConfigRequest", +] diff --git a/src/webui/schemas/auth.py b/src/webui/schemas/auth.py new file mode 100644 index 00000000..51e8d662 --- /dev/null +++ b/src/webui/schemas/auth.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field + + +class TokenVerifyRequest(BaseModel): + token: str = Field(..., description="访问令牌") + + +class TokenVerifyResponse(BaseModel): + valid: bool = Field(..., description="Token 是否有效") + message: str = Field(..., description="验证结果消息") + is_first_setup: bool = Field(False, description="是否为首次设置") + + +class TokenUpdateRequest(BaseModel): + new_token: str = Field(..., description="新的访问令牌", min_length=10) + + +class TokenUpdateResponse(BaseModel): + success: bool = Field(..., description="是否更新成功") + message: str = Field(..., description="更新结果消息") + + +class TokenRegenerateResponse(BaseModel): + success: bool = Field(..., description="是否生成成功") + token: str = Field(..., description="新生成的令牌") + message: str = Field(..., description="生成结果消息") + + +class FirstSetupStatusResponse(BaseModel): + is_first_setup: bool = Field(..., description="是否为首次配置") + message: str = Field(..., description="状态消息") + + +class CompleteSetupResponse(BaseModel): + success: bool = Field(..., description="是否成功") + message: str = Field(..., description="结果消息") + + +class ResetSetupResponse(BaseModel): + success: bool = Field(..., description="是否成功") + message: str = Field(..., description="结果消息") diff --git a/src/webui/schemas/chat.py b/src/webui/schemas/chat.py new file mode 100644 index 00000000..786792be --- /dev/null +++ b/src/webui/schemas/chat.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from typing import Optional + + +class VirtualIdentityConfig(BaseModel): + """虚拟身份配置""" + + enabled: bool = False + platform: Optional[str] = None + person_id: Optional[str] = None + user_id: Optional[str] = None + user_nickname: Optional[str] = None + group_id: Optional[str] = None + group_name: Optional[str] = None + + +class ChatHistoryMessage(BaseModel): + """聊天历史消息""" + + id: str + type: str # 'user' | 'bot' | 'system' + content: str + timestamp: float + sender_name: str + sender_id: Optional[str] = None + is_bot: bool = False diff --git a/src/webui/schemas/emoji.py b/src/webui/schemas/emoji.py new file mode 100644 index 00000000..571eccd6 --- /dev/null +++ b/src/webui/schemas/emoji.py @@ -0,0 +1,115 @@ +from pydantic import BaseModel +from typing import Optional, List + + +class EmojiResponse(BaseModel): + """表情包响应""" + + id: int + full_path: str + format: str + emoji_hash: str + description: str + query_count: int + is_registered: bool + is_banned: bool + emotion: Optional[str] + record_time: float + register_time: Optional[float] + usage_count: int + last_used_time: Optional[float] + + +class EmojiListResponse(BaseModel): + """表情包列表响应""" + + success: bool + total: int + page: int + page_size: int + data: List[EmojiResponse] + + +class EmojiDetailResponse(BaseModel): + """表情包详情响应""" + + success: bool + data: EmojiResponse + + +class EmojiUpdateRequest(BaseModel): + """表情包更新请求""" + + description: Optional[str] = None + is_registered: Optional[bool] = None + is_banned: Optional[bool] = None + emotion: Optional[str] = None + + +class EmojiUpdateResponse(BaseModel): + """表情包更新响应""" + + success: bool + message: str + data: Optional[EmojiResponse] = None + + +class EmojiDeleteResponse(BaseModel): + """表情包删除响应""" + + success: bool + message: str + + +class BatchDeleteRequest(BaseModel): + """批量删除请求""" + + emoji_ids: List[int] + + +class BatchDeleteResponse(BaseModel): + """批量删除响应""" + + success: bool + message: str + deleted_count: int + failed_count: int + failed_ids: List[int] = [] + + +class EmojiUploadResponse(BaseModel): + """表情包上传响应""" + + success: bool + message: str + data: Optional[EmojiResponse] = None + + +class ThumbnailCacheStatsResponse(BaseModel): + """缩略图缓存统计响应""" + + success: bool + cache_dir: str + total_count: int + total_size_mb: float + emoji_count: int + coverage_percent: float + + +class ThumbnailCleanupResponse(BaseModel): + """缩略图清理响应""" + + success: bool + message: str + cleaned_count: int + kept_count: int + + +class ThumbnailPreheatResponse(BaseModel): + """缩略图预热响应""" + + success: bool + message: str + generated_count: int + skipped_count: int + failed_count: int diff --git a/src/webui/schemas/plugin.py b/src/webui/schemas/plugin.py new file mode 100644 index 00000000..1a75a38e --- /dev/null +++ b/src/webui/schemas/plugin.py @@ -0,0 +1,135 @@ +from pydantic import BaseModel, Field +from typing import Optional, List, Dict, Any + + +class FetchRawFileRequest(BaseModel): + """获取 Raw 文件请求""" + + owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") + repo: str = Field(..., description="仓库名称", example="plugin-repo") + branch: str = Field(..., description="分支名称", example="main") + file_path: str = Field(..., description="文件路径", example="plugin_details.json") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + custom_url: Optional[str] = Field(None, description="自定义完整 URL") + + +class FetchRawFileResponse(BaseModel): + """获取 Raw 文件响应""" + + success: bool = Field(..., description="是否成功") + data: Optional[str] = Field(None, description="文件内容") + error: Optional[str] = Field(None, description="错误信息") + mirror_used: Optional[str] = Field(None, description="使用的镜像源") + attempts: int = Field(..., description="尝试次数") + url: Optional[str] = Field(None, description="实际请求的 URL") + + +class CloneRepositoryRequest(BaseModel): + """克隆仓库请求""" + + owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") + repo: str = Field(..., description="仓库名称", example="plugin-repo") + target_path: str = Field(..., description="目标路径(相对于插件目录)") + branch: Optional[str] = Field(None, description="分支名称", example="main") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + custom_url: Optional[str] = Field(None, description="自定义克隆 URL") + depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1) + + +class CloneRepositoryResponse(BaseModel): + """克隆仓库响应""" + + success: bool = Field(..., description="是否成功") + path: Optional[str] = Field(None, description="克隆路径") + error: Optional[str] = Field(None, description="错误信息") + mirror_used: Optional[str] = Field(None, description="使用的镜像源") + attempts: int = Field(..., description="尝试次数") + url: Optional[str] = Field(None, description="实际克隆的 URL") + message: Optional[str] = Field(None, description="附加信息") + + +class MirrorConfigResponse(BaseModel): + """镜像源配置响应""" + + id: str = Field(..., description="镜像源 ID") + name: str = Field(..., description="镜像源名称") + raw_prefix: str = Field(..., description="Raw 文件前缀") + clone_prefix: str = Field(..., description="克隆前缀") + enabled: bool = Field(..., description="是否启用") + priority: int = Field(..., description="优先级(数字越小优先级越高)") + + +class AvailableMirrorsResponse(BaseModel): + """可用镜像源列表响应""" + + mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表") + default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)") + + +class AddMirrorRequest(BaseModel): + """添加镜像源请求""" + + id: str = Field(..., description="镜像源 ID", example="custom-mirror") + name: str = Field(..., description="镜像源名称", example="自定义镜像源") + raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw") + clone_prefix: str = Field(..., description="克隆前缀", example="https://example.com/clone") + enabled: bool = Field(True, description="是否启用") + priority: Optional[int] = Field(None, description="优先级") + + +class UpdateMirrorRequest(BaseModel): + """更新镜像源请求""" + + name: Optional[str] = Field(None, description="镜像源名称") + raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀") + clone_prefix: Optional[str] = Field(None, description="克隆前缀") + enabled: Optional[bool] = Field(None, description="是否启用") + priority: Optional[int] = Field(None, description="优先级") + + +class GitStatusResponse(BaseModel): + """Git 安装状态响应""" + + installed: bool = Field(..., description="是否已安装 Git") + version: Optional[str] = Field(None, description="Git 版本号") + path: Optional[str] = Field(None, description="Git 可执行文件路径") + error: Optional[str] = Field(None, description="错误信息") + + +class InstallPluginRequest(BaseModel): + """安装插件请求""" + + plugin_id: str = Field(..., description="插件 ID") + repository_url: str = Field(..., description="插件仓库 URL") + branch: Optional[str] = Field("main", description="分支名称") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + + +class VersionResponse(BaseModel): + """麦麦版本响应""" + + version: str = Field(..., description="麦麦版本号") + version_major: int = Field(..., description="主版本号") + version_minor: int = Field(..., description="次版本号") + version_patch: int = Field(..., description="补丁版本号") + + +class UninstallPluginRequest(BaseModel): + """卸载插件请求""" + + plugin_id: str = Field(..., description="插件 ID") + + +class UpdatePluginRequest(BaseModel): + """更新插件请求""" + + plugin_id: str = Field(..., description="插件 ID") + repository_url: str = Field(..., description="插件仓库 URL") + branch: Optional[str] = Field("main", description="分支名称") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + + +class UpdatePluginConfigRequest(BaseModel): + """更新插件配置请求""" + + config: Dict[str, Any] = Field(..., description="配置数据") diff --git a/src/webui/schemas/statistics.py b/src/webui/schemas/statistics.py new file mode 100644 index 00000000..278d7251 --- /dev/null +++ b/src/webui/schemas/statistics.py @@ -0,0 +1,45 @@ +from pydantic import BaseModel, Field +from typing import Dict, Any, List + + +class StatisticsSummary(BaseModel): + """统计数据摘要""" + + total_requests: int = Field(0, description="总请求数") + total_cost: float = Field(0.0, description="总花费") + total_tokens: int = Field(0, description="总token数") + online_time: float = Field(0.0, description="在线时间(秒)") + total_messages: int = Field(0, description="总消息数") + total_replies: int = Field(0, description="总回复数") + avg_response_time: float = Field(0.0, description="平均响应时间") + cost_per_hour: float = Field(0.0, description="每小时花费") + tokens_per_hour: float = Field(0.0, description="每小时token数") + + +class ModelStatistics(BaseModel): + """模型统计""" + + model_name: str + request_count: int + total_cost: float + total_tokens: int + avg_response_time: float + + +class TimeSeriesData(BaseModel): + """时间序列数据""" + + timestamp: str + requests: int = 0 + cost: float = 0.0 + tokens: int = 0 + + +class DashboardData(BaseModel): + """仪表盘数据""" + + summary: StatisticsSummary + model_stats: List[ModelStatistics] + hourly_data: List[TimeSeriesData] + daily_data: List[TimeSeriesData] + recent_activity: List[Dict[str, Any]] diff --git a/src/webui/services/__init__.py b/src/webui/services/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/webui/services/__init__.py @@ -0,0 +1 @@ + diff --git a/src/webui/utils/__init__.py b/src/webui/utils/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/webui/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index fca2cee1..7e2afbb5 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -1,13 +1,9 @@ """独立的 WebUI 服务器 - 运行在 0.0.0.0:8001""" import asyncio -import mimetypes -from pathlib import Path -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse from uvicorn import Config, Server as UvicornServer from src.common.logger import get_logger +from src.webui.app import create_app, show_access_token logger = get_logger("webui_server") @@ -18,174 +14,10 @@ class WebUIServer: def __init__(self, host: str = "0.0.0.0", port: int = 8001): self.host = host self.port = port - self.app = FastAPI(title="MaiBot WebUI") + self.app = create_app(host=host, port=port, enable_static=True) self._server = None - # 配置防爬虫中间件(需要在CORS之前注册) - self._setup_anti_crawler() - - # 配置 CORS(支持开发环境跨域请求) - self._setup_cors() - - # 显示 Access Token - self._show_access_token() - - # 重要:先注册 API 路由,再设置静态文件 - self._register_api_routes() - self._setup_static_files() - - # 注册robots.txt路由 - self._setup_robots_txt() - - def _setup_cors(self): - """配置 CORS 中间件""" - # 开发环境需要允许前端开发服务器的跨域请求 - self.app.add_middleware( - CORSMiddleware, - allow_origins=[ - "http://localhost:5173", # Vite 开发服务器 - "http://127.0.0.1:5173", - "http://localhost:7999", # 前端开发服务器备用端口 - "http://127.0.0.1:7999", - "http://localhost:8001", # 生产环境 - "http://127.0.0.1:8001", - ], - allow_credentials=True, # 允许携带 Cookie - allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法 - allow_headers=[ - "Content-Type", - "Authorization", - "Accept", - "Origin", - "X-Requested-With", - ], # 明确指定允许的头 - expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头 - ) - logger.debug("✅ CORS 中间件已配置") - - def _show_access_token(self): - """显示 WebUI Access Token""" - try: - from src.webui.token_manager import get_token_manager - - token_manager = get_token_manager() - current_token = token_manager.get_token() - logger.info(f"🔑 WebUI Access Token: {current_token}") - logger.info("💡 请使用此 Token 登录 WebUI") - except Exception as e: - logger.error(f"❌ 获取 Access Token 失败: {e}") - - def _setup_static_files(self): - """设置静态文件服务""" - # 确保正确的 MIME 类型映射 - mimetypes.init() - mimetypes.add_type("application/javascript", ".js") - mimetypes.add_type("application/javascript", ".mjs") - mimetypes.add_type("text/css", ".css") - mimetypes.add_type("application/json", ".json") - - base_dir = Path(__file__).parent.parent.parent - static_path = base_dir / "webui" / "dist" - - if not static_path.exists(): - logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}") - logger.warning("💡 请先构建前端: cd webui && npm run build") - return - - if not (static_path / "index.html").exists(): - logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}") - logger.warning("💡 请确认前端已正确构建") - return - - # 处理 SPA 路由 - 注意:这个路由优先级最低 - @self.app.get("/{full_path:path}", include_in_schema=False) - async def serve_spa(full_path: str): - """服务单页应用 - 只处理非 API 请求""" - # 如果是根路径,直接返回 index.html - if not full_path or full_path == "/": - response = FileResponse(static_path / "index.html", media_type="text/html") - response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" - return response - - # 检查是否是静态文件 - file_path = static_path / full_path - if file_path.is_file() and file_path.exists(): - # 自动检测 MIME 类型 - media_type = mimetypes.guess_type(str(file_path))[0] - response = FileResponse(file_path, media_type=media_type) - # HTML 文件添加防索引头 - if str(file_path).endswith(".html"): - response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" - return response - - # 其他路径返回 index.html(SPA 路由) - response = FileResponse(static_path / "index.html", media_type="text/html") - response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" - return response - - logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}") - - def _setup_anti_crawler(self): - """配置防爬虫中间件""" - try: - from src.webui.anti_crawler import AntiCrawlerMiddleware - from src.config.config import global_config - - # 从配置读取防爬虫模式 - anti_crawler_mode = global_config.webui.anti_crawler_mode - - # 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行 - # 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行 - self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode) - - mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"} - mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式") - logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}") - except Exception as e: - logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True) - - def _setup_robots_txt(self): - """设置robots.txt路由""" - try: - from src.webui.anti_crawler import create_robots_txt_response - - @self.app.get("/robots.txt", include_in_schema=False) - async def robots_txt(): - """返回robots.txt,禁止所有爬虫""" - return create_robots_txt_response() - - logger.debug("✅ robots.txt 路由已注册") - except Exception as e: - logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True) - - def _register_api_routes(self): - """注册所有 WebUI API 路由""" - try: - # 导入所有 WebUI 路由 - from src.webui.routes import router as webui_router - from src.webui.logs_ws import router as logs_router - from src.webui.knowledge_routes import router as knowledge_router - - # 导入本地聊天室路由 - from src.webui.chat_routes import router as chat_router - - # 导入规划器监控路由 - from src.webui.api.planner import router as planner_router - - # 导入回复器监控路由 - from src.webui.api.replier import router as replier_router - - # 注册路由 - self.app.include_router(webui_router) - self.app.include_router(logs_router) - self.app.include_router(knowledge_router) - self.app.include_router(chat_router) - self.app.include_router(planner_router) - self.app.include_router(replier_router) - - logger.info("✅ WebUI API 路由已注册") - except Exception as e: - logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True) + show_access_token() async def start(self): """启动服务器""" @@ -209,9 +41,9 @@ class WebUIServer: self._server = UvicornServer(config=config) logger.info("🌐 WebUI 服务器启动中...") - + # 根据地址类型显示正确的访问地址 - if ':' in self.host: + if ":" in self.host: # IPv6 地址需要用方括号包裹 logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}") if self.host == "::": @@ -245,7 +77,7 @@ class WebUIServer: import socket # 判断使用 IPv4 还是 IPv6 - if ':' in self.host: + if ":" in self.host: # IPv6 地址 family = socket.AF_INET6 test_host = self.host if self.host != "::" else "::1" @@ -289,6 +121,7 @@ def get_webui_server() -> WebUIServer: if _webui_server is None: # 从环境变量读取 import os + host = os.getenv("WEBUI_HOST", "127.0.0.1") port = int(os.getenv("WEBUI_PORT", "8001")) _webui_server = WebUIServer(host=host, port=port)