mirror of https://github.com/Mai-with-u/MaiBot.git
添加 WebSocket 认证模块,支持临时 token 认证机制,增强安全性并解决 Cookie 不可用问题
parent
ea420f9f59
commit
6055b087f0
|
|
@ -17,6 +17,7 @@ from src.config.config import global_config
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
from src.webui.token_manager import get_token_manager
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.chat")
|
logger = get_logger("webui.chat")
|
||||||
|
|
||||||
|
|
@ -398,23 +399,40 @@ async def websocket_chat(
|
||||||
token: 认证 token(可选,也可从 Cookie 获取)
|
token: 认证 token(可选,也可从 Cookie 获取)
|
||||||
|
|
||||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||||
|
|
||||||
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/api/chat/ws?token=xxx
|
||||||
"""
|
"""
|
||||||
# 认证检查
|
is_authenticated = False
|
||||||
auth_token = token
|
|
||||||
if not auth_token:
|
|
||||||
# 尝试从 Cookie 获取 token
|
|
||||||
auth_token = websocket.cookies.get("maibot_session")
|
|
||||||
|
|
||||||
if not auth_token:
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token")
|
if token and verify_ws_token(token):
|
||||||
await websocket.close(code=4001, reason="未提供认证信息")
|
is_authenticated = True
|
||||||
return
|
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
# 验证 token
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
token_manager = get_token_manager()
|
if not is_authenticated:
|
||||||
if not token_manager.verify_token(auth_token):
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
logger.warning("WebSocket 聊天连接被拒绝:token 无效")
|
if cookie_token:
|
||||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 生成会话 ID(每次连接都是新的)
|
# 生成会话 ID(每次连接都是新的)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.token_manager import get_token_manager
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.logs_ws")
|
logger = get_logger("webui.logs_ws")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -78,23 +79,39 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
||||||
"""WebSocket 日志推送端点
|
"""WebSocket 日志推送端点
|
||||||
|
|
||||||
客户端连接后会持续接收服务器端的日志消息
|
客户端连接后会持续接收服务器端的日志消息
|
||||||
需要通过 query 参数传递 token 进行认证,例如:ws://host/ws/logs?token=xxx
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/ws/logs?token=xxx
|
||||||
"""
|
"""
|
||||||
# 认证检查
|
is_authenticated = False
|
||||||
if not token:
|
|
||||||
# 尝试从 Cookie 获取 token
|
|
||||||
token = websocket.cookies.get("maibot_session")
|
|
||||||
|
|
||||||
if not token:
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
logger.warning("WebSocket 连接被拒绝:未提供认证 token")
|
if token and verify_ws_token(token):
|
||||||
await websocket.close(code=4001, reason="未提供认证信息")
|
is_authenticated = True
|
||||||
return
|
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
# 验证 token
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
token_manager = get_token_manager()
|
if not is_authenticated:
|
||||||
if not token_manager.verify_token(token):
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
logger.warning("WebSocket 连接被拒绝:token 无效")
|
if cookie_token:
|
||||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
return
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.token_manager import get_token_manager
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.plugin_progress")
|
logger = get_logger("webui.plugin_progress")
|
||||||
|
|
||||||
|
|
@ -94,24 +95,39 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||||
"""WebSocket 插件加载进度推送端点
|
"""WebSocket 插件加载进度推送端点
|
||||||
|
|
||||||
客户端连接后会立即收到当前进度状态
|
客户端连接后会立即收到当前进度状态
|
||||||
需要通过 query 参数或 Cookie 传递 token 进行认证
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/ws/plugin-progress?token=xxx
|
||||||
"""
|
"""
|
||||||
# 认证检查
|
is_authenticated = False
|
||||||
auth_token = token
|
|
||||||
if not auth_token:
|
|
||||||
# 尝试从 Cookie 获取 token
|
|
||||||
auth_token = websocket.cookies.get("maibot_session")
|
|
||||||
|
|
||||||
if not auth_token:
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token")
|
if token and verify_ws_token(token):
|
||||||
await websocket.close(code=4001, reason="未提供认证信息")
|
is_authenticated = True
|
||||||
return
|
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
# 验证 token
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
token_manager = get_token_manager()
|
if not is_authenticated:
|
||||||
if not token_manager.verify_token(auth_token):
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
logger.warning("插件进度 WebSocket 连接被拒绝:token 无效")
|
if cookie_token:
|
||||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
return
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
|
||||||
from typing import Optional, List, Dict, Any, get_origin
|
from typing import Optional, List, Dict, Any, get_origin
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.toml_utils import save_toml_with_format
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
|
|
@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||||
|
"""
|
||||||
|
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_path: 用户输入的路径(相对路径)
|
||||||
|
base_path: 允许的基础目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的绝对路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 如果检测到路径遍历攻击
|
||||||
|
"""
|
||||||
|
# 规范化基础路径
|
||||||
|
base_resolved = base_path.resolve()
|
||||||
|
|
||||||
|
# 检查用户路径是否包含可疑字符
|
||||||
|
# 禁止: .., 绝对路径开头, 空字节等
|
||||||
|
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||||
|
logger.warning(f"检测到可疑路径: {user_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||||
|
|
||||||
|
# 检查是否为绝对路径(Windows 和 Unix)
|
||||||
|
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||||
|
logger.warning(f"检测到绝对路径: {user_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||||
|
|
||||||
|
# 构建目标路径并解析
|
||||||
|
target_path = (base_path / user_path).resolve()
|
||||||
|
|
||||||
|
# 验证解析后的路径仍在基础目录内
|
||||||
|
try:
|
||||||
|
target_path.relative_to(base_resolved)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||||
|
|
||||||
|
return target_path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_plugin_id(plugin_id: str) -> str:
|
||||||
|
"""
|
||||||
|
验证插件 ID 格式是否安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证通过的插件 ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 如果插件 ID 格式不安全
|
||||||
|
"""
|
||||||
|
# 禁止空字符串
|
||||||
|
if not plugin_id or not plugin_id.strip():
|
||||||
|
logger.warning("非法插件 ID: 空字符串")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||||
|
|
||||||
|
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||||
|
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if pattern in plugin_id:
|
||||||
|
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||||
|
|
||||||
|
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||||
|
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||||
|
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||||
|
|
||||||
|
# 禁止特殊名称
|
||||||
|
if plugin_id in (".", ".."):
|
||||||
|
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||||
|
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
|
||||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
解析版本号字符串
|
解析版本号字符串
|
||||||
|
|
@ -468,17 +548,16 @@ async def fetch_raw_file(
|
||||||
|
|
||||||
支持多镜像源自动切换和错误重试
|
支持多镜像源自动切换和错误重试
|
||||||
|
|
||||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
需要认证才能访问,防止被滥用作为 SSRF 跳板
|
||||||
"""
|
"""
|
||||||
# Token 验证(可选,用于日志记录)
|
# Token 验证(强制)
|
||||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
token_manager = get_token_manager()
|
token_manager = get_token_manager()
|
||||||
is_authenticated = token and token_manager.verify_token(token)
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
# 对于公开仓库的访问,不强制要求认证
|
|
||||||
# 只在日志中记录是否认证
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
|
f"收到获取 Raw 文件请求: "
|
||||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -564,10 +643,10 @@ async def clone_repository(
|
||||||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
# 验证 target_path 的安全性,防止路径遍历攻击
|
||||||
# TODO: 确定实际的插件目录基路径
|
base_plugin_path = Path("./plugins").resolve()
|
||||||
base_plugin_path = Path("./plugins") # 临时路径
|
base_plugin_path.mkdir(exist_ok=True)
|
||||||
target_path = base_plugin_path / request.target_path
|
target_path = validate_safe_path(request.target_path, base_plugin_path)
|
||||||
|
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
|
|
@ -607,13 +686,16 @@ async def install_plugin(
|
||||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始安装
|
# 推送进度:开始安装
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始安装插件: {request.plugin_id}",
|
message=f"开始安装插件: {plugin_id}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 解析仓库 URL
|
# 1. 解析仓库 URL
|
||||||
|
|
@ -634,27 +716,28 @@ async def install_plugin(
|
||||||
progress=10,
|
progress=10,
|
||||||
message=f"解析仓库信息: {owner}/{repo}",
|
message=f"解析仓库信息: {owner}/{repo}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 确定插件安装路径
|
# 2. 确定插件安装路径
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
plugins_dir.mkdir(exist_ok=True)
|
plugins_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
||||||
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
target_path = plugins_dir / folder_name
|
# 使用安全路径验证,防止路径遍历
|
||||||
|
target_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
|
|
||||||
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = plugins_dir / plugin_id
|
||||||
if target_path.exists() or old_format_path.exists():
|
if target_path.exists() or old_format_path.exists():
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error",
|
||||||
progress=0,
|
progress=0,
|
||||||
message="插件已存在",
|
message="插件已存在",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件已安装,请先卸载",
|
error="插件已安装,请先卸载",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="插件已安装")
|
raise HTTPException(status_code=400, detail="插件已安装")
|
||||||
|
|
@ -664,7 +747,7 @@ async def install_plugin(
|
||||||
progress=15,
|
progress=15,
|
||||||
message=f"准备克隆到: {target_path}",
|
message=f"准备克隆到: {target_path}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||||
|
|
@ -693,14 +776,14 @@ async def install_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="克隆仓库失败",
|
message="克隆仓库失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 4. 验证插件完整性
|
# 4. 验证插件完整性
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
manifest_path = target_path / "_manifest.json"
|
manifest_path = target_path / "_manifest.json"
|
||||||
|
|
@ -715,14 +798,14 @@ async def install_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="插件缺少 _manifest.json",
|
message="插件缺少 _manifest.json",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="无效的插件格式",
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 5. 读取并验证 manifest
|
# 5. 读取并验证 manifest
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -739,7 +822,7 @@ async def install_plugin(
|
||||||
|
|
||||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||||
manifest["id"] = request.plugin_id
|
manifest["id"] = plugin_id
|
||||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||||
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
@ -754,7 +837,7 @@ async def install_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
@ -765,13 +848,13 @@ async def install_plugin(
|
||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件安装成功",
|
"message": "插件安装成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"plugin_name": manifest["name"],
|
"plugin_name": manifest["name"],
|
||||||
"version": manifest["version"],
|
"version": manifest["version"],
|
||||||
"path": str(target_path),
|
"path": str(target_path),
|
||||||
|
|
@ -787,7 +870,7 @@ async def install_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="安装失败",
|
message="安装失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -814,22 +897,26 @@ async def uninstall_plugin(
|
||||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始卸载
|
# 推送进度:开始卸载
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=10,
|
progress=10,
|
||||||
message=f"开始卸载插件: {request.plugin_id}",
|
message=f"开始卸载插件: {plugin_id}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否存在(支持新旧两种格式)
|
# 1. 检查插件是否存在(支持新旧两种格式)
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
# 新格式:下划线
|
# 新格式:下划线
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
plugin_path = plugins_dir / folder_name
|
# 使用安全路径验证
|
||||||
|
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
|
|
@ -841,7 +928,7 @@ async def uninstall_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件未安装或已被删除",
|
error="插件未安装或已被删除",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
@ -851,12 +938,12 @@ async def uninstall_plugin(
|
||||||
progress=30,
|
progress=30,
|
||||||
message=f"正在删除插件文件: {plugin_path}",
|
message=f"正在删除插件文件: {plugin_path}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 读取插件信息(用于日志)
|
# 2. 读取插件信息(用于日志)
|
||||||
manifest_path = plugin_path / "_manifest.json"
|
manifest_path = plugin_path / "_manifest.json"
|
||||||
plugin_name = request.plugin_id
|
plugin_name = plugin_id
|
||||||
|
|
||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
|
|
@ -864,7 +951,7 @@ async def uninstall_plugin(
|
||||||
|
|
||||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
plugin_name = manifest.get("name", plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # 如果读取失败,使用插件 ID 作为名称
|
pass # 如果读取失败,使用插件 ID 作为名称
|
||||||
|
|
||||||
|
|
@ -873,7 +960,7 @@ async def uninstall_plugin(
|
||||||
progress=50,
|
progress=50,
|
||||||
message=f"正在删除 {plugin_name}...",
|
message=f"正在删除 {plugin_name}...",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除插件目录
|
# 3. 删除插件目录
|
||||||
|
|
@ -889,7 +976,7 @@ async def uninstall_plugin(
|
||||||
|
|
||||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||||
|
|
||||||
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
|
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||||
|
|
||||||
# 4. 推送成功状态
|
# 4. 推送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
|
|
@ -897,10 +984,10 @@ async def uninstall_plugin(
|
||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功卸载插件: {plugin_name}",
|
message=f"成功卸载插件: {plugin_name}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -912,7 +999,7 @@ async def uninstall_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="权限不足,无法删除插件文件",
|
error="权限不足,无法删除插件文件",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -925,7 +1012,7 @@ async def uninstall_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -952,22 +1039,26 @@ async def update_plugin(
|
||||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始更新
|
# 推送进度:开始更新
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始更新插件: {request.plugin_id}",
|
message=f"开始更新插件: {plugin_id}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否已安装(支持新旧两种格式)
|
# 1. 检查插件是否已安装(支持新旧两种格式)
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
# 新格式:下划线
|
# 新格式:下划线
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
plugin_path = plugins_dir / folder_name
|
# 使用安全路径验证
|
||||||
|
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
|
|
@ -979,7 +1070,7 @@ async def update_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件未安装,请先安装",
|
error="插件未安装,请先安装",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
@ -1003,12 +1094,12 @@ async def update_plugin(
|
||||||
progress=10,
|
progress=10,
|
||||||
message=f"当前版本: {old_version},准备更新...",
|
message=f"当前版本: {old_version},准备更新...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除旧版本
|
# 3. 删除旧版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
@ -1023,7 +1114,7 @@ async def update_plugin(
|
||||||
|
|
||||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||||
|
|
||||||
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
|
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
|
||||||
|
|
||||||
# 4. 解析仓库 URL
|
# 4. 解析仓库 URL
|
||||||
await update_progress(
|
await update_progress(
|
||||||
|
|
@ -1031,7 +1122,7 @@ async def update_plugin(
|
||||||
progress=30,
|
progress=30,
|
||||||
message="正在准备下载新版本...",
|
message="正在准备下载新版本...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
repo_url = request.repository_url.rstrip("/")
|
repo_url = request.repository_url.rstrip("/")
|
||||||
|
|
@ -1069,14 +1160,14 @@ async def update_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="下载新版本失败",
|
message="下载新版本失败",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 6. 验证新版本
|
# 6. 验证新版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
new_manifest_path = plugin_path / "_manifest.json"
|
new_manifest_path = plugin_path / "_manifest.json"
|
||||||
|
|
@ -1096,7 +1187,7 @@ async def update_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="新版本缺少 _manifest.json",
|
message="新版本缺少 _manifest.json",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="无效的插件格式",
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
@ -1107,9 +1198,9 @@ async def update_plugin(
|
||||||
new_manifest = json_module.load(f)
|
new_manifest = json_module.load(f)
|
||||||
|
|
||||||
new_version = new_manifest.get("version", "unknown")
|
new_version = new_manifest.get("version", "unknown")
|
||||||
new_name = new_manifest.get("name", request.plugin_id)
|
new_name = new_manifest.get("name", plugin_id)
|
||||||
|
|
||||||
logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}")
|
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||||
|
|
||||||
# 8. 推送成功状态
|
# 8. 推送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
|
|
@ -1117,13 +1208,13 @@ async def update_plugin(
|
||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件更新成功",
|
"message": "插件更新成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"plugin_name": new_name,
|
"plugin_name": new_name,
|
||||||
"old_version": old_version,
|
"old_version": old_version,
|
||||||
"new_version": new_version,
|
"new_version": new_version,
|
||||||
|
|
@ -1138,7 +1229,7 @@ async def update_plugin(
|
||||||
progress=0,
|
progress=0,
|
||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
@ -1149,7 +1240,7 @@ async def update_plugin(
|
||||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from .plugin_routes import router as plugin_router
|
||||||
from .plugin_progress_ws import get_progress_router
|
from .plugin_progress_ws import get_progress_router
|
||||||
from .routers.system import router as system_router
|
from .routers.system import router as system_router
|
||||||
from .model_routes import router as model_router
|
from .model_routes import router as model_router
|
||||||
|
from .ws_auth import router as ws_auth_router
|
||||||
|
|
||||||
logger = get_logger("webui.api")
|
logger = get_logger("webui.api")
|
||||||
|
|
||||||
|
|
@ -43,6 +44,8 @@ router.include_router(get_progress_router())
|
||||||
router.include_router(system_router)
|
router.include_router(system_router)
|
||||||
# 注册模型列表获取路由
|
# 注册模型列表获取路由
|
||||||
router.include_router(model_router)
|
router.include_router(model_router)
|
||||||
|
# 注册 WebSocket 认证路由
|
||||||
|
router.include_router(ws_auth_router)
|
||||||
|
|
||||||
|
|
||||||
class TokenVerifyRequest(BaseModel):
|
class TokenVerifyRequest(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,8 @@ class WebUIServer:
|
||||||
allow_origins=[
|
allow_origins=[
|
||||||
"http://localhost:5173", # Vite 开发服务器
|
"http://localhost:5173", # Vite 开发服务器
|
||||||
"http://127.0.0.1:5173",
|
"http://127.0.0.1:5173",
|
||||||
|
"http://localhost:7999", # 前端开发服务器备用端口
|
||||||
|
"http://127.0.0.1:7999",
|
||||||
"http://localhost:8001", # 生产环境
|
"http://localhost:8001", # 生产环境
|
||||||
"http://127.0.0.1:8001",
|
"http://127.0.0.1:8001",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""WebSocket 认证模块
|
||||||
|
|
||||||
|
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||||
|
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Cookie, Header, HTTPException
|
||||||
|
from typing import Optional
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
|
logger = get_logger("webui.ws_auth")
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||||
|
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||||
|
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||||
|
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_expired_ws_tokens():
|
||||||
|
"""清理过期的临时 token"""
|
||||||
|
now = time.time()
|
||||||
|
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||||
|
for t in expired:
|
||||||
|
del _ws_temp_tokens[t]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ws_token(session_token: str) -> str:
|
||||||
|
"""生成 WebSocket 临时 token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_token: 原始的 session token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
临时 token 字符串
|
||||||
|
"""
|
||||||
|
_cleanup_expired_ws_tokens()
|
||||||
|
temp_token = secrets.token_urlsafe(32)
|
||||||
|
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||||
|
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||||
|
return temp_token
|
||||||
|
|
||||||
|
|
||||||
|
def verify_ws_token(temp_token: str) -> bool:
|
||||||
|
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp_token: 临时 token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证是否通过
|
||||||
|
"""
|
||||||
|
_cleanup_expired_ws_tokens()
|
||||||
|
if temp_token not in _ws_temp_tokens:
|
||||||
|
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||||
|
if time.time() > expire_time:
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
# 验证原始 session token 仍然有效
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(session_token):
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
# 消费 token(一次性使用)
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ws-token")
|
||||||
|
async def get_ws_token(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取 WebSocket 连接用的临时 token
|
||||||
|
|
||||||
|
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||||
|
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||||
|
临时 token 有效期 60 秒,且只能使用一次。
|
||||||
|
"""
|
||||||
|
# 获取当前 session token
|
||||||
|
session_token = None
|
||||||
|
if maibot_session:
|
||||||
|
session_token = maibot_session
|
||||||
|
elif authorization and authorization.startswith("Bearer "):
|
||||||
|
session_token = authorization.replace("Bearer ", "")
|
||||||
|
|
||||||
|
if not session_token:
|
||||||
|
raise HTTPException(status_code=401, detail="未提供认证信息")
|
||||||
|
|
||||||
|
# 验证 session token
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(session_token):
|
||||||
|
raise HTTPException(status_code=401, detail="认证已过期,请重新登录")
|
||||||
|
|
||||||
|
# 生成临时 WebSocket token
|
||||||
|
ws_token = generate_ws_token(session_token)
|
||||||
|
|
||||||
|
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||||
Loading…
Reference in New Issue