mirror of https://github.com/Mai-with-u/MaiBot.git
添加 WebSocket 认证模块,支持临时 token 认证机制,增强安全性并解决 Cookie 不可用问题
parent
ea420f9f59
commit
6055b087f0
|
|
@ -17,6 +17,7 @@ from src.config.config import global_config
|
|||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
|
|
@ -398,23 +399,40 @@ async def websocket_chat(
|
|||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
# 认证检查
|
||||
auth_token = token
|
||||
if not auth_token:
|
||||
# 尝试从 Cookie 获取 token
|
||||
auth_token = websocket.cookies.get("maibot_session")
|
||||
is_authenticated = False
|
||||
|
||||
if not auth_token:
|
||||
logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token")
|
||||
await websocket.close(code=4001, reason="未提供认证信息")
|
||||
return
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(auth_token):
|
||||
logger.warning("WebSocket 聊天连接被拒绝:token 无效")
|
||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import json
|
|||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
|
|
@ -78,23 +79,39 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
|||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
需要通过 query 参数传递 token 进行认证,例如:ws://host/ws/logs?token=xxx
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
# 认证检查
|
||||
if not token:
|
||||
# 尝试从 Cookie 获取 token
|
||||
token = websocket.cookies.get("maibot_session")
|
||||
is_authenticated = False
|
||||
|
||||
if not token:
|
||||
logger.warning("WebSocket 连接被拒绝:未提供认证 token")
|
||||
await websocket.close(code=4001, reason="未提供认证信息")
|
||||
return
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
logger.warning("WebSocket 连接被拒绝:token 无效")
|
||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import json
|
|||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
|
|
@ -94,24 +95,39 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
|||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
需要通过 query 参数或 Cookie 传递 token 进行认证
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
# 认证检查
|
||||
auth_token = token
|
||||
if not auth_token:
|
||||
# 尝试从 Cookie 获取 token
|
||||
auth_token = websocket.cookies.get("maibot_session")
|
||||
is_authenticated = False
|
||||
|
||||
if not auth_token:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token")
|
||||
await websocket.close(code=4001, reason="未提供认证信息")
|
||||
return
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(auth_token):
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:token 无效")
|
||||
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
|
|||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
|
|
@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
|
|||
return None
|
||||
|
||||
|
||||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
"""
|
||||
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||
|
||||
Args:
|
||||
user_path: 用户输入的路径(相对路径)
|
||||
base_path: 允许的基础目录
|
||||
|
||||
Returns:
|
||||
安全的绝对路径
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果检测到路径遍历攻击
|
||||
"""
|
||||
# 规范化基础路径
|
||||
base_resolved = base_path.resolve()
|
||||
|
||||
# 检查用户路径是否包含可疑字符
|
||||
# 禁止: .., 绝对路径开头, 空字节等
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
# 检查是否为绝对路径(Windows 和 Unix)
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
# 构建目标路径并解析
|
||||
target_path = (base_path / user_path).resolve()
|
||||
|
||||
# 验证解析后的路径仍在基础目录内
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
"""
|
||||
验证插件 ID 格式是否安全
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||
|
||||
Returns:
|
||||
验证通过的插件 ID
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果插件 ID 格式不安全
|
||||
"""
|
||||
# 禁止空字符串
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
# 禁止特殊名称
|
||||
if plugin_id in (".", ".."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
解析版本号字符串
|
||||
|
|
@ -468,17 +548,16 @@ async def fetch_raw_file(
|
|||
|
||||
支持多镜像源自动切换和错误重试
|
||||
|
||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
||||
需要认证才能访问,防止被滥用作为 SSRF 跳板
|
||||
"""
|
||||
# Token 验证(可选,用于日志记录)
|
||||
# Token 验证(强制)
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_authenticated = token and token_manager.verify_token(token)
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
# 对于公开仓库的访问,不强制要求认证
|
||||
# 只在日志中记录是否认证
|
||||
logger.info(
|
||||
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
|
||||
f"收到获取 Raw 文件请求: "
|
||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
||||
)
|
||||
|
||||
|
|
@ -564,10 +643,10 @@ async def clone_repository(
|
|||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||
|
||||
try:
|
||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||
# TODO: 确定实际的插件目录基路径
|
||||
base_plugin_path = Path("./plugins") # 临时路径
|
||||
target_path = base_plugin_path / request.target_path
|
||||
# 验证 target_path 的安全性,防止路径遍历攻击
|
||||
base_plugin_path = Path("./plugins").resolve()
|
||||
base_plugin_path.mkdir(exist_ok=True)
|
||||
target_path = validate_safe_path(request.target_path, base_plugin_path)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
result = await service.clone_repository(
|
||||
|
|
@ -607,13 +686,16 @@ async def install_plugin(
|
|||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始安装
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始安装插件: {request.plugin_id}",
|
||||
message=f"开始安装插件: {plugin_id}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 解析仓库 URL
|
||||
|
|
@ -634,27 +716,28 @@ async def install_plugin(
|
|||
progress=10,
|
||||
message=f"解析仓库信息: {owner}/{repo}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 确定插件安装路径
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
||||
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
target_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证,防止路径遍历
|
||||
target_path = validate_safe_path(folder_name, plugins_dir)
|
||||
|
||||
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = plugins_dir / plugin_id
|
||||
if target_path.exists() or old_format_path.exists():
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件已存在",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件已安装,请先卸载",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
|
|
@ -664,7 +747,7 @@ async def install_plugin(
|
|||
progress=15,
|
||||
message=f"准备克隆到: {target_path}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||
|
|
@ -693,14 +776,14 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="克隆仓库失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 4. 验证插件完整性
|
||||
await update_progress(
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
|
|
@ -715,14 +798,14 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="插件缺少 _manifest.json",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
# 5. 读取并验证 manifest
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -739,7 +822,7 @@ async def install_plugin(
|
|||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
|
@ -754,7 +837,7 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
|
@ -765,13 +848,13 @@ async def install_plugin(
|
|||
progress=100,
|
||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件安装成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": manifest["name"],
|
||||
"version": manifest["version"],
|
||||
"path": str(target_path),
|
||||
|
|
@ -787,7 +870,7 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="安装失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
|
@ -814,22 +897,26 @@ async def uninstall_plugin(
|
|||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始卸载
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"开始卸载插件: {request.plugin_id}",
|
||||
message=f"开始卸载插件: {plugin_id}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否存在(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
|
|
@ -841,7 +928,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装或已被删除",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
|
@ -851,12 +938,12 @@ async def uninstall_plugin(
|
|||
progress=30,
|
||||
message=f"正在删除插件文件: {plugin_path}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 读取插件信息(用于日志)
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
plugin_name = request.plugin_id
|
||||
plugin_name = plugin_id
|
||||
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
|
|
@ -864,7 +951,7 @@ async def uninstall_plugin(
|
|||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
plugin_name = manifest.get("name", request.plugin_id)
|
||||
plugin_name = manifest.get("name", plugin_id)
|
||||
except Exception:
|
||||
pass # 如果读取失败,使用插件 ID 作为名称
|
||||
|
||||
|
|
@ -873,7 +960,7 @@ async def uninstall_plugin(
|
|||
progress=50,
|
||||
message=f"正在删除 {plugin_name}...",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除插件目录
|
||||
|
|
@ -889,7 +976,7 @@ async def uninstall_plugin(
|
|||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
|
||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||
|
||||
# 4. 推送成功状态
|
||||
await update_progress(
|
||||
|
|
@ -897,10 +984,10 @@ async def uninstall_plugin(
|
|||
progress=100,
|
||||
message=f"成功卸载插件: {plugin_name}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -912,7 +999,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="权限不足,无法删除插件文件",
|
||||
)
|
||||
|
||||
|
|
@ -925,7 +1012,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
|
@ -952,22 +1039,26 @@ async def update_plugin(
|
|||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始更新
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始更新插件: {request.plugin_id}",
|
||||
message=f"开始更新插件: {plugin_id}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否已安装(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
|
|
@ -979,7 +1070,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装,请先安装",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
|
@ -1003,12 +1094,12 @@ async def update_plugin(
|
|||
progress=10,
|
||||
message=f"当前版本: {old_version},准备更新...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除旧版本
|
||||
await update_progress(
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
import shutil
|
||||
|
|
@ -1023,7 +1114,7 @@ async def update_plugin(
|
|||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
|
||||
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
|
||||
|
||||
# 4. 解析仓库 URL
|
||||
await update_progress(
|
||||
|
|
@ -1031,7 +1122,7 @@ async def update_plugin(
|
|||
progress=30,
|
||||
message="正在准备下载新版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
repo_url = request.repository_url.rstrip("/")
|
||||
|
|
@ -1069,14 +1160,14 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="下载新版本失败",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 6. 验证新版本
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
|
|
@ -1096,7 +1187,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="新版本缺少 _manifest.json",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
|
@ -1107,9 +1198,9 @@ async def update_plugin(
|
|||
new_manifest = json_module.load(f)
|
||||
|
||||
new_version = new_manifest.get("version", "unknown")
|
||||
new_name = new_manifest.get("name", request.plugin_id)
|
||||
new_name = new_manifest.get("name", plugin_id)
|
||||
|
||||
logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}")
|
||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||
|
||||
# 8. 推送成功状态
|
||||
await update_progress(
|
||||
|
|
@ -1117,13 +1208,13 @@ async def update_plugin(
|
|||
progress=100,
|
||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件更新成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": new_name,
|
||||
"old_version": old_version,
|
||||
"new_version": new_version,
|
||||
|
|
@ -1138,7 +1229,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
|
@ -1149,7 +1240,7 @@ async def update_plugin(
|
|||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .plugin_routes import router as plugin_router
|
|||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
from .model_routes import router as model_router
|
||||
from .ws_auth import router as ws_auth_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
|
|
@ -43,6 +44,8 @@ router.include_router(get_progress_router())
|
|||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ class WebUIServer:
|
|||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:7999", # 前端开发服务器备用端口
|
||||
"http://127.0.0.1:7999",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
"""WebSocket 认证模块
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header, HTTPException
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.ws_auth")
|
||||
router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
def _cleanup_expired_ws_tokens():
|
||||
"""清理过期的临时 token"""
|
||||
now = time.time()
|
||||
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||
for t in expired:
|
||||
del _ws_temp_tokens[t]
|
||||
|
||||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
temp_token = secrets.token_urlsafe(32)
|
||||
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||
return temp_token
|
||||
|
||||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
if temp_token not in _ws_temp_tokens:
|
||||
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||
return False
|
||||
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||
if time.time() > expire_time:
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||
return False
|
||||
# 验证原始 session token 仍然有效
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||
return False
|
||||
# 消费 token(一次性使用)
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
raise HTTPException(status_code=401, detail="未提供认证信息")
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
raise HTTPException(status_code=401, detail="认证已过期,请重新登录")
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
Loading…
Reference in New Issue