diff --git a/src/webui/chat_routes.py b/src/webui/chat_routes.py index 6dfbca5b..4c889915 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/chat_routes.py @@ -17,6 +17,7 @@ from src.config.config import global_config from src.chat.message_receive.bot import chat_bot from src.webui.auth import verify_auth_token_from_cookie_or_header from src.webui.token_manager import get_token_manager +from src.webui.ws_auth import verify_ws_token logger = get_logger("webui.chat") @@ -398,23 +399,40 @@ async def websocket_chat( token: 认证 token(可选,也可从 Cookie 获取) 虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置 + + 支持三种认证方式(按优先级): + 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) + 2. Cookie 中的 maibot_session + 3. 直接使用 session token(兼容) + + 示例:ws://host/api/chat/ws?token=xxx """ - # 认证检查 - auth_token = token - if not auth_token: - # 尝试从 Cookie 获取 token - auth_token = websocket.cookies.get("maibot_session") + is_authenticated = False - if not auth_token: - logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token") - await websocket.close(code=4001, reason="未提供认证信息") - return + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) + if token and verify_ws_token(token): + is_authenticated = True + logger.debug("聊天 WebSocket 使用临时 token 认证成功") - # 验证 token - token_manager = get_token_manager() - if not token_manager.verify_token(auth_token): - logger.warning("WebSocket 聊天连接被拒绝:token 无效") - await websocket.close(code=4003, reason="Token 无效或已过期") + # 方式 2: 尝试从 Cookie 获取 session token + if not is_authenticated: + cookie_token = websocket.cookies.get("maibot_session") + if cookie_token: + token_manager = get_token_manager() + if token_manager.verify_token(cookie_token): + is_authenticated = True + logger.debug("聊天 WebSocket 使用 Cookie 认证成功") + + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) + if not is_authenticated and token: + token_manager = get_token_manager() + if token_manager.verify_token(token): + is_authenticated = True + logger.debug("聊天 WebSocket 使用 session token 认证成功") + + if not is_authenticated: + logger.warning("聊天 WebSocket 连接被拒绝:认证失败") + await websocket.close(code=4001, reason="认证失败,请重新登录") return # 生成会话 ID(每次连接都是新的) diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index 836191ee..382a09a2 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -6,6 +6,7 @@ import json from pathlib import Path from src.common.logger import get_logger from src.webui.token_manager import get_token_manager +from src.webui.ws_auth import verify_ws_token logger = get_logger("webui.logs_ws") router = APIRouter() @@ -78,23 +79,39 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None """WebSocket 日志推送端点 客户端连接后会持续接收服务器端的日志消息 - 需要通过 query 参数传递 token 进行认证,例如:ws://host/ws/logs?token=xxx + 支持三种认证方式(按优先级): + 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) + 2. Cookie 中的 maibot_session + 3. 直接使用 session token(兼容) + + 示例:ws://host/ws/logs?token=xxx """ - # 认证检查 - if not token: - # 尝试从 Cookie 获取 token - token = websocket.cookies.get("maibot_session") + is_authenticated = False - if not token: - logger.warning("WebSocket 连接被拒绝:未提供认证 token") - await websocket.close(code=4001, reason="未提供认证信息") - return + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) + if token and verify_ws_token(token): + is_authenticated = True + logger.debug("WebSocket 使用临时 token 认证成功") - # 验证 token - token_manager = get_token_manager() - if not token_manager.verify_token(token): - logger.warning("WebSocket 连接被拒绝:token 无效") - await websocket.close(code=4003, reason="Token 无效或已过期") + # 方式 2: 尝试从 Cookie 获取 session token + if not is_authenticated: + cookie_token = websocket.cookies.get("maibot_session") + if cookie_token: + token_manager = get_token_manager() + if token_manager.verify_token(cookie_token): + is_authenticated = True + logger.debug("WebSocket 使用 Cookie 认证成功") + + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) + if not is_authenticated and token: + token_manager = get_token_manager() + if token_manager.verify_token(token): + is_authenticated = True + logger.debug("WebSocket 使用 session token 认证成功") + + if not is_authenticated: + logger.warning("WebSocket 连接被拒绝:认证失败") + await websocket.close(code=4001, reason="认证失败,请重新登录") return await websocket.accept() diff --git a/src/webui/plugin_progress_ws.py b/src/webui/plugin_progress_ws.py index 3d334ca9..2f576b6a 100644 --- a/src/webui/plugin_progress_ws.py +++ b/src/webui/plugin_progress_ws.py @@ -6,6 +6,7 @@ import json import asyncio from src.common.logger import get_logger from src.webui.token_manager import get_token_manager +from src.webui.ws_auth import verify_ws_token logger = get_logger("webui.plugin_progress") @@ -94,24 +95,39 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = """WebSocket 插件加载进度推送端点 客户端连接后会立即收到当前进度状态 - 需要通过 query 参数或 Cookie 传递 token 进行认证 + 支持三种认证方式(按优先级): + 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) + 2. Cookie 中的 maibot_session + 3. 直接使用 session token(兼容) + + 示例:ws://host/ws/plugin-progress?token=xxx """ - # 认证检查 - auth_token = token - if not auth_token: - # 尝试从 Cookie 获取 token - auth_token = websocket.cookies.get("maibot_session") + is_authenticated = False - if not auth_token: - logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token") - await websocket.close(code=4001, reason="未提供认证信息") - return + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) + if token and verify_ws_token(token): + is_authenticated = True + logger.debug("插件进度 WebSocket 使用临时 token 认证成功") - # 验证 token - token_manager = get_token_manager() - if not token_manager.verify_token(auth_token): - logger.warning("插件进度 WebSocket 连接被拒绝:token 无效") - await websocket.close(code=4003, reason="Token 无效或已过期") + # 方式 2: 尝试从 Cookie 获取 session token + if not is_authenticated: + cookie_token = websocket.cookies.get("maibot_session") + if cookie_token: + token_manager = get_token_manager() + if token_manager.verify_token(cookie_token): + is_authenticated = True + logger.debug("插件进度 WebSocket 使用 Cookie 认证成功") + + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) + if not is_authenticated and token: + token_manager = get_token_manager() + if token_manager.verify_token(token): + is_authenticated = True + logger.debug("插件进度 WebSocket 使用 session token 认证成功") + + if not is_authenticated: + logger.warning("插件进度 WebSocket 连接被拒绝:认证失败") + await websocket.close(code=4001, reason="认证失败,请重新登录") return await websocket.accept() diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index 2e757aac..fd5f7462 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any, get_origin from pathlib import Path import json +import re from src.common.logger import get_logger from src.common.toml_utils import save_toml_with_format from src.config.config import MMC_VERSION @@ -34,6 +35,85 @@ def get_token_from_cookie_or_header( return None +def validate_safe_path(user_path: str, base_path: Path) -> Path: + """ + 验证用户提供的路径是否安全,防止路径遍历攻击 + + Args: + user_path: 用户输入的路径(相对路径) + base_path: 允许的基础目录 + + Returns: + 安全的绝对路径 + + Raises: + HTTPException: 如果检测到路径遍历攻击 + """ + # 规范化基础路径 + base_resolved = base_path.resolve() + + # 检查用户路径是否包含可疑字符 + # 禁止: .., 绝对路径开头, 空字节等 + if any(pattern in user_path for pattern in ["..", "\x00"]): + logger.warning(f"检测到可疑路径: {user_path}") + raise HTTPException(status_code=400, detail="路径包含非法字符") + + # 检查是否为绝对路径(Windows 和 Unix) + if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"): + logger.warning(f"检测到绝对路径: {user_path}") + raise HTTPException(status_code=400, detail="不允许使用绝对路径") + + # 构建目标路径并解析 + target_path = (base_path / user_path).resolve() + + # 验证解析后的路径仍在基础目录内 + try: + target_path.relative_to(base_resolved) + except ValueError as e: + logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}") + raise HTTPException(status_code=400, detail="路径超出允许范围") from e + + return target_path + + +def validate_plugin_id(plugin_id: str) -> str: + """ + 验证插件 ID 格式是否安全 + + Args: + plugin_id: 插件 ID (支持 author.name 格式,允许中文) + + Returns: + 验证通过的插件 ID + + Raises: + HTTPException: 如果插件 ID 格式不安全 + """ + # 禁止空字符串 + if not plugin_id or not plugin_id.strip(): + logger.warning("非法插件 ID: 空字符串") + raise HTTPException(status_code=400, detail="插件 ID 不能为空") + + # 禁止危险字符: 路径分隔符、空字节、控制字符等 + dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"] + for pattern in dangerous_patterns: + if pattern in plugin_id: + logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)") + raise HTTPException(status_code=400, detail="插件 ID 包含非法字符") + + # 禁止以点开头或结尾(防止隐藏文件和路径问题) + if plugin_id.startswith(".") or plugin_id.endswith("."): + logger.warning(f"非法插件 ID: {plugin_id}") + raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾") + + # 禁止特殊名称 + if plugin_id in (".", ".."): + logger.warning(f"非法插件 ID: {plugin_id}") + raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名") + + return plugin_id + + def parse_version(version_str: str) -> tuple[int, int, int]: """ 解析版本号字符串 @@ -468,17 +548,16 @@ async def fetch_raw_file( 支持多镜像源自动切换和错误重试 - 注意:此接口可公开访问,用于获取插件仓库等公开资源 + 需要认证才能访问,防止被滥用作为 SSRF 跳板 """ - # Token 验证(可选,用于日志记录) + # Token 验证(强制) token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() - is_authenticated = token and token_manager.verify_token(token) + if not token or not token_manager.verify_token(token): + raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - # 对于公开仓库的访问,不强制要求认证 - # 只在日志中记录是否认证 logger.info( - f"收到获取 Raw 文件请求 (认证: {is_authenticated}): " + f"收到获取 Raw 文件请求: " f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}" ) @@ -564,10 +643,10 @@ async def clone_repository( logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}") try: - # TODO: 验证 target_path 的安全性,防止路径遍历攻击 - # TODO: 确定实际的插件目录基路径 - base_plugin_path = Path("./plugins") # 临时路径 - target_path = base_plugin_path / request.target_path + # 验证 target_path 的安全性,防止路径遍历攻击 + base_plugin_path = Path("./plugins").resolve() + base_plugin_path.mkdir(exist_ok=True) + target_path = validate_safe_path(request.target_path, base_plugin_path) service = get_git_mirror_service() result = await service.clone_repository( @@ -607,13 +686,16 @@ async def install_plugin( logger.info(f"收到安装插件请求: {request.plugin_id}") try: + # 验证插件 ID 格式安全性 + plugin_id = validate_plugin_id(request.plugin_id) + # 推送进度:开始安装 await update_progress( stage="loading", progress=5, - message=f"开始安装插件: {request.plugin_id}", + message=f"开始安装插件: {plugin_id}", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 1. 解析仓库 URL @@ -634,27 +716,28 @@ async def install_plugin( progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 2. 确定插件安装路径 - plugins_dir = Path("plugins") + plugins_dir = Path("plugins").resolve() plugins_dir.mkdir(exist_ok=True) # 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题) # 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin - folder_name = request.plugin_id.replace(".", "_") - target_path = plugins_dir / folder_name + folder_name = plugin_id.replace(".", "_") + # 使用安全路径验证,防止路径遍历 + target_path = validate_safe_path(folder_name, plugins_dir) # 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点) - old_format_path = plugins_dir / request.plugin_id + old_format_path = plugins_dir / plugin_id if target_path.exists() or old_format_path.exists(): await update_progress( stage="error", progress=0, message="插件已存在", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="插件已安装,请先卸载", ) raise HTTPException(status_code=400, detail="插件已安装") @@ -664,7 +747,7 @@ async def install_plugin( progress=15, message=f"准备克隆到: {target_path}", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 3. 克隆仓库(这里会自动推送 20%-80% 的进度) @@ -693,14 +776,14 @@ async def install_plugin( progress=0, message="克隆仓库失败", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=error_msg, ) raise HTTPException(status_code=500, detail=error_msg) # 4. 验证插件完整性 await update_progress( - stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id + stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id ) manifest_path = target_path / "_manifest.json" @@ -715,14 +798,14 @@ async def install_plugin( progress=0, message="插件缺少 _manifest.json", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="无效的插件格式", ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") # 5. 读取并验证 manifest await update_progress( - stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id + stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id ) try: @@ -739,7 +822,7 @@ async def install_plugin( # 将插件 ID 写入 manifest(用于后续准确识别) # 这样即使文件夹名称改变,也能通过 manifest 准确识别插件 - manifest["id"] = request.plugin_id + manifest["id"] = plugin_id with open(manifest_path, "w", encoding="utf-8") as f: json_module.dump(manifest, f, ensure_ascii=False, indent=2) @@ -754,7 +837,7 @@ async def install_plugin( progress=0, message="_manifest.json 无效", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=str(e), ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e @@ -765,13 +848,13 @@ async def install_plugin( progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) return { "success": True, "message": "插件安装成功", - "plugin_id": request.plugin_id, + "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path), @@ -787,7 +870,7 @@ async def install_plugin( progress=0, message="安装失败", operation="install", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=str(e), ) @@ -814,22 +897,26 @@ async def uninstall_plugin( logger.info(f"收到卸载插件请求: {request.plugin_id}") try: + # 验证插件 ID 格式安全性 + plugin_id = validate_plugin_id(request.plugin_id) + # 推送进度:开始卸载 await update_progress( stage="loading", progress=10, - message=f"开始卸载插件: {request.plugin_id}", + message=f"开始卸载插件: {plugin_id}", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 1. 检查插件是否存在(支持新旧两种格式) - plugins_dir = Path("plugins") + plugins_dir = Path("plugins").resolve() # 新格式:下划线 - folder_name = request.plugin_id.replace(".", "_") - plugin_path = plugins_dir / folder_name + folder_name = plugin_id.replace(".", "_") + # 使用安全路径验证 + plugin_path = validate_safe_path(folder_name, plugins_dir) # 旧格式:点 - old_format_path = plugins_dir / request.plugin_id + old_format_path = validate_safe_path(plugin_id, plugins_dir) # 优先使用新格式,如果不存在则尝试旧格式 if not plugin_path.exists(): @@ -841,7 +928,7 @@ async def uninstall_plugin( progress=0, message="插件不存在", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="插件未安装或已被删除", ) raise HTTPException(status_code=404, detail="插件未安装") @@ -851,12 +938,12 @@ async def uninstall_plugin( progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 2. 读取插件信息(用于日志) manifest_path = plugin_path / "_manifest.json" - plugin_name = request.plugin_id + plugin_name = plugin_id if manifest_path.exists(): try: @@ -864,7 +951,7 @@ async def uninstall_plugin( with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) - plugin_name = manifest.get("name", request.plugin_id) + plugin_name = manifest.get("name", plugin_id) except Exception: pass # 如果读取失败,使用插件 ID 作为名称 @@ -873,7 +960,7 @@ async def uninstall_plugin( progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 3. 删除插件目录 @@ -889,7 +976,7 @@ async def uninstall_plugin( shutil.rmtree(plugin_path, onerror=remove_readonly) - logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})") + logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})") # 4. 推送成功状态 await update_progress( @@ -897,10 +984,10 @@ async def uninstall_plugin( progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) - return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name} + return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name} except HTTPException: raise @@ -912,7 +999,7 @@ async def uninstall_plugin( progress=0, message="卸载失败", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="权限不足,无法删除插件文件", ) @@ -925,7 +1012,7 @@ async def uninstall_plugin( progress=0, message="卸载失败", operation="uninstall", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=str(e), ) @@ -952,22 +1039,26 @@ async def update_plugin( logger.info(f"收到更新插件请求: {request.plugin_id}") try: + # 验证插件 ID 格式安全性 + plugin_id = validate_plugin_id(request.plugin_id) + # 推送进度:开始更新 await update_progress( stage="loading", progress=5, - message=f"开始更新插件: {request.plugin_id}", + message=f"开始更新插件: {plugin_id}", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 1. 检查插件是否已安装(支持新旧两种格式) - plugins_dir = Path("plugins") + plugins_dir = Path("plugins").resolve() # 新格式:下划线 - folder_name = request.plugin_id.replace(".", "_") - plugin_path = plugins_dir / folder_name + folder_name = plugin_id.replace(".", "_") + # 使用安全路径验证 + plugin_path = validate_safe_path(folder_name, plugins_dir) # 旧格式:点 - old_format_path = plugins_dir / request.plugin_id + old_format_path = validate_safe_path(plugin_id, plugins_dir) # 优先使用新格式,如果不存在则尝试旧格式 if not plugin_path.exists(): @@ -979,7 +1070,7 @@ async def update_plugin( progress=0, message="插件不存在", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="插件未安装,请先安装", ) raise HTTPException(status_code=404, detail="插件未安装") @@ -1003,12 +1094,12 @@ async def update_plugin( progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) # 3. 删除旧版本 await update_progress( - stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id + stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id ) import shutil @@ -1023,7 +1114,7 @@ async def update_plugin( shutil.rmtree(plugin_path, onerror=remove_readonly) - logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}") + logger.info(f"已删除旧版本: {plugin_id} v{old_version}") # 4. 解析仓库 URL await update_progress( @@ -1031,7 +1122,7 @@ async def update_plugin( progress=30, message="正在准备下载新版本...", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) repo_url = request.repository_url.rstrip("/") @@ -1069,14 +1160,14 @@ async def update_plugin( progress=0, message="下载新版本失败", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=error_msg, ) raise HTTPException(status_code=500, detail=error_msg) # 6. 验证新版本 await update_progress( - stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id + stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id ) new_manifest_path = plugin_path / "_manifest.json" @@ -1096,7 +1187,7 @@ async def update_plugin( progress=0, message="新版本缺少 _manifest.json", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error="无效的插件格式", ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") @@ -1107,9 +1198,9 @@ async def update_plugin( new_manifest = json_module.load(f) new_version = new_manifest.get("version", "unknown") - new_name = new_manifest.get("name", request.plugin_id) + new_name = new_manifest.get("name", plugin_id) - logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}") + logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}") # 8. 推送成功状态 await update_progress( @@ -1117,13 +1208,13 @@ async def update_plugin( progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, ) return { "success": True, "message": "插件更新成功", - "plugin_id": request.plugin_id, + "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version, @@ -1138,7 +1229,7 @@ async def update_plugin( progress=0, message="_manifest.json 无效", operation="update", - plugin_id=request.plugin_id, + plugin_id=plugin_id, error=str(e), ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e @@ -1149,7 +1240,7 @@ async def update_plugin( logger.error(f"更新插件失败: {e}", exc_info=True) await update_progress( - stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e) + stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e) ) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e diff --git a/src/webui/routes.py b/src/webui/routes.py index 8be6f84f..558b8852 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -17,6 +17,7 @@ from .plugin_routes import router as plugin_router from .plugin_progress_ws import get_progress_router from .routers.system import router as system_router from .model_routes import router as model_router +from .ws_auth import router as ws_auth_router logger = get_logger("webui.api") @@ -43,6 +44,8 @@ router.include_router(get_progress_router()) router.include_router(system_router) # 注册模型列表获取路由 router.include_router(model_router) +# 注册 WebSocket 认证路由 +router.include_router(ws_auth_router) class TokenVerifyRequest(BaseModel): diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 4ecd509d..928824e0 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -40,6 +40,8 @@ class WebUIServer: allow_origins=[ "http://localhost:5173", # Vite 开发服务器 "http://127.0.0.1:5173", + "http://localhost:7999", # 前端开发服务器备用端口 + "http://127.0.0.1:7999", "http://localhost:8001", # 生产环境 "http://127.0.0.1:8001", ], diff --git a/src/webui/ws_auth.py b/src/webui/ws_auth.py new file mode 100644 index 00000000..d6c4bd33 --- /dev/null +++ b/src/webui/ws_auth.py @@ -0,0 +1,107 @@ +"""WebSocket 认证模块 + +提供所有 WebSocket 端点统一使用的临时 token 认证机制。 +临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。 +""" + +from fastapi import APIRouter, Cookie, Header, HTTPException +from typing import Optional +import secrets +import time +from src.common.logger import get_logger +from src.webui.token_manager import get_token_manager + +logger = get_logger("webui.ws_auth") +router = APIRouter() + +# WebSocket 临时 token 存储 {token: (expire_time, session_token)} +# 临时 token 有效期 60 秒,仅用于 WebSocket 握手 +_ws_temp_tokens: dict[str, tuple[float, str]] = {} +_WS_TOKEN_EXPIRE_SECONDS = 60 + + +def _cleanup_expired_ws_tokens(): + """清理过期的临时 token""" + now = time.time() + expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp] + for t in expired: + del _ws_temp_tokens[t] + + +def generate_ws_token(session_token: str) -> str: + """生成 WebSocket 临时 token + + Args: + session_token: 原始的 session token + + Returns: + 临时 token 字符串 + """ + _cleanup_expired_ws_tokens() + temp_token = secrets.token_urlsafe(32) + _ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token) + logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s") + return temp_token + + +def verify_ws_token(temp_token: str) -> bool: + """验证并消费 WebSocket 临时 token(一次性使用) + + Args: + temp_token: 临时 token + + Returns: + 验证是否通过 + """ + _cleanup_expired_ws_tokens() + if temp_token not in _ws_temp_tokens: + logger.warning(f"WS token 不存在: {temp_token[:8]}...") + return False + expire_time, session_token = _ws_temp_tokens[temp_token] + if time.time() > expire_time: + del _ws_temp_tokens[temp_token] + logger.warning(f"WS token 已过期: {temp_token[:8]}...") + return False + # 验证原始 session token 仍然有效 + token_manager = get_token_manager() + if not token_manager.verify_token(session_token): + del _ws_temp_tokens[temp_token] + logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...") + return False + # 消费 token(一次性使用) + del _ws_temp_tokens[temp_token] + logger.debug(f"WS token 验证成功: {temp_token[:8]}...") + return True + + +@router.get("/ws-token") +async def get_ws_token( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): + """ + 获取 WebSocket 连接用的临时 token + + 此端点验证当前会话的 Cookie 或 Authorization header, + 然后返回一个临时 token 用于 WebSocket 握手认证。 + 临时 token 有效期 60 秒,且只能使用一次。 + """ + # 获取当前 session token + session_token = None + if maibot_session: + session_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + session_token = authorization.replace("Bearer ", "") + + if not session_token: + raise HTTPException(status_code=401, detail="未提供认证信息") + + # 验证 session token + token_manager = get_token_manager() + if not token_manager.verify_token(session_token): + raise HTTPException(status_code=401, detail="认证已过期,请重新登录") + + # 生成临时 WebSocket token + ws_token = generate_ws_token(session_token) + + return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}