添加 WebSocket 认证模块,支持临时 token 认证机制,增强安全性并解决 Cookie 不可用问题

pull/1438/head^2
墨梓柒 2025-12-14 20:08:49 +08:00
parent ea420f9f59
commit 6055b087f0
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
7 changed files with 361 additions and 107 deletions

View File

@ -17,6 +17,7 @@ from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.chat")
@ -398,23 +399,40 @@ async def websocket_chat(
token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
# 认证检查
auth_token = token
if not auth_token:
# 尝试从 Cookie 获取 token
auth_token = websocket.cookies.get("maibot_session")
is_authenticated = False
if not auth_token:
logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(auth_token):
logger.warning("WebSocket 聊天连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的

View File

@ -6,6 +6,7 @@ import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
@ -78,23 +79,39 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
需要通过 query 参数传递 token 进行认证例如ws://host/ws/logs?token=xxx
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
# 认证检查
if not token:
# 尝试从 Cookie 获取 token
token = websocket.cookies.get("maibot_session")
is_authenticated = False
if not token:
logger.warning("WebSocket 连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
logger.warning("WebSocket 连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()

View File

@ -6,6 +6,7 @@ import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
@ -94,24 +95,39 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
需要通过 query 参数或 Cookie 传递 token 进行认证
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
# 认证检查
auth_token = token
if not auth_token:
# 尝试从 Cookie 获取 token
auth_token = websocket.cookies.get("maibot_session")
is_authenticated = False
if not auth_token:
logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(auth_token):
logger.warning("插件进度 WebSocket 连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path
import json
import re
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
return None
def validate_safe_path(user_path: str, base_path: Path) -> Path:
"""
验证用户提供的路径是否安全防止路径遍历攻击
Args:
user_path: 用户输入的路径相对路径
base_path: 允许的基础目录
Returns:
安全的绝对路径
Raises:
HTTPException: 如果检测到路径遍历攻击
"""
# 规范化基础路径
base_resolved = base_path.resolve()
# 检查用户路径是否包含可疑字符
# 禁止: .., 绝对路径开头, 空字节等
if any(pattern in user_path for pattern in ["..", "\x00"]):
logger.warning(f"检测到可疑路径: {user_path}")
raise HTTPException(status_code=400, detail="路径包含非法字符")
# 检查是否为绝对路径Windows 和 Unix
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
logger.warning(f"检测到绝对路径: {user_path}")
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
# 构建目标路径并解析
target_path = (base_path / user_path).resolve()
# 验证解析后的路径仍在基础目录内
try:
target_path.relative_to(base_resolved)
except ValueError as e:
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
return target_path
def validate_plugin_id(plugin_id: str) -> str:
"""
验证插件 ID 格式是否安全
Args:
plugin_id: 插件 ID (支持 author.name 格式允许中文)
Returns:
验证通过的插件 ID
Raises:
HTTPException: 如果插件 ID 格式不安全
"""
# 禁止空字符串
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
# 禁止危险字符: 路径分隔符、空字节、控制字符等
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
for pattern in dangerous_patterns:
if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
# 禁止特殊名称
if plugin_id in (".", ".."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
"""
解析版本号字符串
@ -468,17 +548,16 @@ async def fetch_raw_file(
支持多镜像源自动切换和错误重试
注意此接口可公开访问用于获取插件仓库等公开资源
需要认证才能访问防止被滥用作为 SSRF 跳板
"""
# Token 验证(可选,用于日志记录
# Token 验证(强制
token = get_token_from_cookie_or_header(maibot_session, authorization)
token_manager = get_token_manager()
is_authenticated = token and token_manager.verify_token(token)
if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
# 对于公开仓库的访问,不强制要求认证
# 只在日志中记录是否认证
logger.info(
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
f"收到获取 Raw 文件请求: "
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
)
@ -564,10 +643,10 @@ async def clone_repository(
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
try:
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
# TODO: 确定实际的插件目录基路径
base_plugin_path = Path("./plugins") # 临时路径
target_path = base_plugin_path / request.target_path
# 验证 target_path 的安全性,防止路径遍历攻击
base_plugin_path = Path("./plugins").resolve()
base_plugin_path.mkdir(exist_ok=True)
target_path = validate_safe_path(request.target_path, base_plugin_path)
service = get_git_mirror_service()
result = await service.clone_repository(
@ -607,13 +686,16 @@ async def install_plugin(
logger.info(f"收到安装插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始安装
await update_progress(
stage="loading",
progress=5,
message=f"开始安装插件: {request.plugin_id}",
message=f"开始安装插件: {plugin_id}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 解析仓库 URL
@ -634,27 +716,28 @@ async def install_plugin(
progress=10,
message=f"解析仓库信息: {owner}/{repo}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 确定插件安装路径
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
plugins_dir.mkdir(exist_ok=True)
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
folder_name = request.plugin_id.replace(".", "_")
target_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证,防止路径遍历
target_path = validate_safe_path(folder_name, plugins_dir)
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
old_format_path = plugins_dir / request.plugin_id
old_format_path = plugins_dir / plugin_id
if target_path.exists() or old_format_path.exists():
await update_progress(
stage="error",
progress=0,
message="插件已存在",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件已安装,请先卸载",
)
raise HTTPException(status_code=400, detail="插件已安装")
@ -664,7 +747,7 @@ async def install_plugin(
progress=15,
message=f"准备克隆到: {target_path}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
@ -693,14 +776,14 @@ async def install_plugin(
progress=0,
message="克隆仓库失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 4. 验证插件完整性
await update_progress(
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
)
manifest_path = target_path / "_manifest.json"
@ -715,14 +798,14 @@ async def install_plugin(
progress=0,
message="插件缺少 _manifest.json",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
# 5. 读取并验证 manifest
await update_progress(
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
)
try:
@ -739,7 +822,7 @@ async def install_plugin(
# 将插件 ID 写入 manifest用于后续准确识别
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
manifest["id"] = request.plugin_id
manifest["id"] = plugin_id
with open(manifest_path, "w", encoding="utf-8") as f:
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
@ -754,7 +837,7 @@ async def install_plugin(
progress=0,
message="_manifest.json 无效",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -765,13 +848,13 @@ async def install_plugin(
progress=100,
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件安装成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": manifest["name"],
"version": manifest["version"],
"path": str(target_path),
@ -787,7 +870,7 @@ async def install_plugin(
progress=0,
message="安装失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@ -814,22 +897,26 @@ async def uninstall_plugin(
logger.info(f"收到卸载插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始卸载
await update_progress(
stage="loading",
progress=10,
message=f"开始卸载插件: {request.plugin_id}",
message=f"开始卸载插件: {plugin_id}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否存在(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@ -841,7 +928,7 @@ async def uninstall_plugin(
progress=0,
message="插件不存在",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装或已被删除",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -851,12 +938,12 @@ async def uninstall_plugin(
progress=30,
message=f"正在删除插件文件: {plugin_path}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 读取插件信息(用于日志)
manifest_path = plugin_path / "_manifest.json"
plugin_name = request.plugin_id
plugin_name = plugin_id
if manifest_path.exists():
try:
@ -864,7 +951,7 @@ async def uninstall_plugin(
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
plugin_name = manifest.get("name", request.plugin_id)
plugin_name = manifest.get("name", plugin_id)
except Exception:
pass # 如果读取失败,使用插件 ID 作为名称
@ -873,7 +960,7 @@ async def uninstall_plugin(
progress=50,
message=f"正在删除 {plugin_name}...",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除插件目录
@ -889,7 +976,7 @@ async def uninstall_plugin(
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
# 4. 推送成功状态
await update_progress(
@ -897,10 +984,10 @@ async def uninstall_plugin(
progress=100,
message=f"成功卸载插件: {plugin_name}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
@ -912,7 +999,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="权限不足,无法删除插件文件",
)
@ -925,7 +1012,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@ -952,22 +1039,26 @@ async def update_plugin(
logger.info(f"收到更新插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始更新
await update_progress(
stage="loading",
progress=5,
message=f"开始更新插件: {request.plugin_id}",
message=f"开始更新插件: {plugin_id}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否已安装(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@ -979,7 +1070,7 @@ async def update_plugin(
progress=0,
message="插件不存在",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装,请先安装",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -1003,12 +1094,12 @@ async def update_plugin(
progress=10,
message=f"当前版本: {old_version},准备更新...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除旧版本
await update_progress(
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
)
import shutil
@ -1023,7 +1114,7 @@ async def update_plugin(
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
# 4. 解析仓库 URL
await update_progress(
@ -1031,7 +1122,7 @@ async def update_plugin(
progress=30,
message="正在准备下载新版本...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
repo_url = request.repository_url.rstrip("/")
@ -1069,14 +1160,14 @@ async def update_plugin(
progress=0,
message="下载新版本失败",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 6. 验证新版本
await update_progress(
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
)
new_manifest_path = plugin_path / "_manifest.json"
@ -1096,7 +1187,7 @@ async def update_plugin(
progress=0,
message="新版本缺少 _manifest.json",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
@ -1107,9 +1198,9 @@ async def update_plugin(
new_manifest = json_module.load(f)
new_version = new_manifest.get("version", "unknown")
new_name = new_manifest.get("name", request.plugin_id)
new_name = new_manifest.get("name", plugin_id)
logger.info(f"成功更新插件: {request.plugin_id} {old_version}{new_version}")
logger.info(f"成功更新插件: {plugin_id} {old_version}{new_version}")
# 8. 推送成功状态
await update_progress(
@ -1117,13 +1208,13 @@ async def update_plugin(
progress=100,
message=f"成功更新 {new_name}: {old_version}{new_version}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件更新成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": new_name,
"old_version": old_version,
"new_version": new_version,
@ -1138,7 +1229,7 @@ async def update_plugin(
progress=0,
message="_manifest.json 无效",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -1149,7 +1240,7 @@ async def update_plugin(
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@ -17,6 +17,7 @@ from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
logger = get_logger("webui.api")
@ -43,6 +44,8 @@ router.include_router(get_progress_router())
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
class TokenVerifyRequest(BaseModel):

View File

@ -40,6 +40,8 @@ class WebUIServer:
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],

View File

@ -0,0 +1,107 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制
临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题
"""
from fastapi import APIRouter, Cookie, Header, HTTPException
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证
临时 token 有效期 60 且只能使用一次
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
raise HTTPException(status_code=401, detail="未提供认证信息")
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
raise HTTPException(status_code=401, detail="认证已过期,请重新登录")
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}