mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'Mai-with-u:dev' into dev
commit
72786687b9
|
|
@ -3,6 +3,7 @@ WebUI 认证模块
|
|||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
|
|
@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
|
|||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 检查环境变量
|
||||
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
|
||||
return True
|
||||
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
|
||||
return False
|
||||
|
||||
# 检查是否是生产环境
|
||||
env = os.environ.get("WEBUI_MODE", "").lower()
|
||||
if env in ("production", "prod"):
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
|
|
@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
|||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
"""
|
||||
# 根据环境决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取
|
||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
|
||||
secure=is_secure, # 生产环境强制 HTTPS
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
|
|
@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None:
|
|||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
|
|
|||
|
|
@ -8,18 +8,29 @@
|
|||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
|
@ -256,6 +267,7 @@ async def get_chat_history(
|
|||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
|
|
@ -272,7 +284,7 @@ async def get_chat_history(
|
|||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms():
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
|
|
@ -303,6 +315,7 @@ async def get_persons_by_platform(
|
|||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
|
|
@ -350,7 +363,7 @@ async def get_persons_by_platform(
|
|||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
|
|
@ -372,6 +385,7 @@ async def websocket_chat(
|
|||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
|
|
@ -382,9 +396,45 @@ async def websocket_chat(
|
|||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
|
|
@ -712,7 +762,7 @@ async def websocket_chat(
|
|||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info():
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@
|
|||
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
|
|
@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
|
|||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema():
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
|
|
@ -65,7 +74,7 @@ async def get_bot_config_schema():
|
|||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema():
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
|
|
@ -79,7 +88,7 @@ async def get_model_config_schema():
|
|||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str):
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
|
|
@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
|
|||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config():
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
|
|
@ -168,7 +177,7 @@ async def get_bot_config():
|
|||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config():
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
|
|
@ -190,7 +199,7 @@ async def get_model_config():
|
|||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
|
|
@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
|
|||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
|
|
@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
|
|||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
|
|
@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
|||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw():
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
|
|
@ -307,7 +316,7 @@ async def get_bot_config_raw():
|
|||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
|
|
@ -337,7 +346,7 @@ async def update_bot_config_raw(raw_content: RawContentBody):
|
|||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
|
|
@ -430,7 +439,7 @@ def _to_relative_path(path: str) -> str:
|
|||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path():
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
|
|
@ -469,7 +478,7 @@ async def get_adapter_config_path():
|
|||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
|
|
@ -512,7 +521,7 @@ async def save_adapter_config_path(data: PathBody):
|
|||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str):
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
|
|
@ -544,7 +553,7 @@ async def get_adapter_config(path: str):
|
|||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody):
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,24 @@
|
|||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
|
|
@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
|
|
@ -199,7 +209,7 @@ async def get_knowledge_graph(
|
|||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats():
|
||||
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
"""获取知识库统计信息
|
||||
|
||||
Returns:
|
||||
|
|
@ -248,7 +258,7 @@ async def get_knowledge_stats():
|
|||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||
"""搜索知识节点
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
|
|
@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
|||
|
||||
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket):
|
||||
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -6,18 +6,27 @@
|
|||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
|
|
@ -184,6 +193,7 @@ async def get_provider_models(
|
|||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
|
@ -228,6 +238,7 @@ async def get_models_by_url(
|
|||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
|
|
@ -251,6 +262,7 @@ async def get_models_by_url(
|
|||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
|
@ -337,6 +349,7 @@ async def test_provider_connection(
|
|||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set, Dict, Any
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
|
|
@ -89,14 +91,48 @@ async def update_progress(
|
|||
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket):
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
|
|||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
|
|
@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
|
|||
return None
|
||||
|
||||
|
||||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
"""
|
||||
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||
|
||||
Args:
|
||||
user_path: 用户输入的路径(相对路径)
|
||||
base_path: 允许的基础目录
|
||||
|
||||
Returns:
|
||||
安全的绝对路径
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果检测到路径遍历攻击
|
||||
"""
|
||||
# 规范化基础路径
|
||||
base_resolved = base_path.resolve()
|
||||
|
||||
# 检查用户路径是否包含可疑字符
|
||||
# 禁止: .., 绝对路径开头, 空字节等
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
# 检查是否为绝对路径(Windows 和 Unix)
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
# 构建目标路径并解析
|
||||
target_path = (base_path / user_path).resolve()
|
||||
|
||||
# 验证解析后的路径仍在基础目录内
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
"""
|
||||
验证插件 ID 格式是否安全
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||
|
||||
Returns:
|
||||
验证通过的插件 ID
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果插件 ID 格式不安全
|
||||
"""
|
||||
# 禁止空字符串
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
# 禁止特殊名称
|
||||
if plugin_id in (".", ".."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
解析版本号字符串
|
||||
|
|
@ -468,17 +548,16 @@ async def fetch_raw_file(
|
|||
|
||||
支持多镜像源自动切换和错误重试
|
||||
|
||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
||||
需要认证才能访问,防止被滥用作为 SSRF 跳板
|
||||
"""
|
||||
# Token 验证(可选,用于日志记录)
|
||||
# Token 验证(强制)
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_authenticated = token and token_manager.verify_token(token)
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
# 对于公开仓库的访问,不强制要求认证
|
||||
# 只在日志中记录是否认证
|
||||
logger.info(
|
||||
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
|
||||
f"收到获取 Raw 文件请求: "
|
||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
||||
)
|
||||
|
||||
|
|
@ -564,10 +643,10 @@ async def clone_repository(
|
|||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||
|
||||
try:
|
||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||
# TODO: 确定实际的插件目录基路径
|
||||
base_plugin_path = Path("./plugins") # 临时路径
|
||||
target_path = base_plugin_path / request.target_path
|
||||
# 验证 target_path 的安全性,防止路径遍历攻击
|
||||
base_plugin_path = Path("./plugins").resolve()
|
||||
base_plugin_path.mkdir(exist_ok=True)
|
||||
target_path = validate_safe_path(request.target_path, base_plugin_path)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
result = await service.clone_repository(
|
||||
|
|
@ -607,13 +686,16 @@ async def install_plugin(
|
|||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始安装
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始安装插件: {request.plugin_id}",
|
||||
message=f"开始安装插件: {plugin_id}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 解析仓库 URL
|
||||
|
|
@ -634,27 +716,28 @@ async def install_plugin(
|
|||
progress=10,
|
||||
message=f"解析仓库信息: {owner}/{repo}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 确定插件安装路径
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
||||
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
target_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证,防止路径遍历
|
||||
target_path = validate_safe_path(folder_name, plugins_dir)
|
||||
|
||||
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = plugins_dir / plugin_id
|
||||
if target_path.exists() or old_format_path.exists():
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件已存在",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件已安装,请先卸载",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
|
|
@ -664,7 +747,7 @@ async def install_plugin(
|
|||
progress=15,
|
||||
message=f"准备克隆到: {target_path}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||
|
|
@ -693,14 +776,14 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="克隆仓库失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 4. 验证插件完整性
|
||||
await update_progress(
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
|
|
@ -715,14 +798,14 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="插件缺少 _manifest.json",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
# 5. 读取并验证 manifest
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -739,7 +822,7 @@ async def install_plugin(
|
|||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
|
@ -754,7 +837,7 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
|
@ -765,13 +848,13 @@ async def install_plugin(
|
|||
progress=100,
|
||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件安装成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": manifest["name"],
|
||||
"version": manifest["version"],
|
||||
"path": str(target_path),
|
||||
|
|
@ -787,7 +870,7 @@ async def install_plugin(
|
|||
progress=0,
|
||||
message="安装失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
|
@ -814,22 +897,26 @@ async def uninstall_plugin(
|
|||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始卸载
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"开始卸载插件: {request.plugin_id}",
|
||||
message=f"开始卸载插件: {plugin_id}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否存在(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
|
|
@ -841,7 +928,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装或已被删除",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
|
@ -851,12 +938,12 @@ async def uninstall_plugin(
|
|||
progress=30,
|
||||
message=f"正在删除插件文件: {plugin_path}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 2. 读取插件信息(用于日志)
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
plugin_name = request.plugin_id
|
||||
plugin_name = plugin_id
|
||||
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
|
|
@ -864,7 +951,7 @@ async def uninstall_plugin(
|
|||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
plugin_name = manifest.get("name", request.plugin_id)
|
||||
plugin_name = manifest.get("name", plugin_id)
|
||||
except Exception:
|
||||
pass # 如果读取失败,使用插件 ID 作为名称
|
||||
|
||||
|
|
@ -873,7 +960,7 @@ async def uninstall_plugin(
|
|||
progress=50,
|
||||
message=f"正在删除 {plugin_name}...",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除插件目录
|
||||
|
|
@ -889,7 +976,7 @@ async def uninstall_plugin(
|
|||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
|
||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||
|
||||
# 4. 推送成功状态
|
||||
await update_progress(
|
||||
|
|
@ -897,10 +984,10 @@ async def uninstall_plugin(
|
|||
progress=100,
|
||||
message=f"成功卸载插件: {plugin_name}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -912,7 +999,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="权限不足,无法删除插件文件",
|
||||
)
|
||||
|
||||
|
|
@ -925,7 +1012,7 @@ async def uninstall_plugin(
|
|||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
|
@ -952,22 +1039,26 @@ async def update_plugin(
|
|||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||
|
||||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
# 推送进度:开始更新
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=5,
|
||||
message=f"开始更新插件: {request.plugin_id}",
|
||||
message=f"开始更新插件: {plugin_id}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否已安装(支持新旧两种格式)
|
||||
plugins_dir = Path("plugins")
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
# 新格式:下划线
|
||||
folder_name = request.plugin_id.replace(".", "_")
|
||||
plugin_path = plugins_dir / folder_name
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
# 使用安全路径验证
|
||||
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
|
|
@ -979,7 +1070,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装,请先安装",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
|
@ -1003,12 +1094,12 @@ async def update_plugin(
|
|||
progress=10,
|
||||
message=f"当前版本: {old_version},准备更新...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除旧版本
|
||||
await update_progress(
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
import shutil
|
||||
|
|
@ -1023,7 +1114,7 @@ async def update_plugin(
|
|||
|
||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||
|
||||
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
|
||||
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
|
||||
|
||||
# 4. 解析仓库 URL
|
||||
await update_progress(
|
||||
|
|
@ -1031,7 +1122,7 @@ async def update_plugin(
|
|||
progress=30,
|
||||
message="正在准备下载新版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
repo_url = request.repository_url.rstrip("/")
|
||||
|
|
@ -1069,14 +1160,14 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="下载新版本失败",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 6. 验证新版本
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
|
|
@ -1096,7 +1187,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="新版本缺少 _manifest.json",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
|
@ -1107,9 +1198,9 @@ async def update_plugin(
|
|||
new_manifest = json_module.load(f)
|
||||
|
||||
new_version = new_manifest.get("version", "unknown")
|
||||
new_name = new_manifest.get("name", request.plugin_id)
|
||||
new_name = new_manifest.get("name", plugin_id)
|
||||
|
||||
logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}")
|
||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||
|
||||
# 8. 推送成功状态
|
||||
await update_progress(
|
||||
|
|
@ -1117,13 +1208,13 @@ async def update_plugin(
|
|||
progress=100,
|
||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件更新成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": new_name,
|
||||
"old_version": old_version,
|
||||
"new_version": new_version,
|
||||
|
|
@ -1138,7 +1229,7 @@ async def update_plugin(
|
|||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
|
@ -1149,7 +1240,7 @@ async def update_plugin(
|
|||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
|
|
|||
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
WebUI 请求频率限制模块
|
||||
防止暴力破解和 API 滥用
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [
|
||||
(ts, count) for ts, count in self._requests[key]
|
||||
if ts > cutoff
|
||||
]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
return False, None
|
||||
|
||||
def check_rate_limit(
|
||||
self,
|
||||
request: Request,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
def record_failed_attempt(
|
||||
self,
|
||||
request: Request,
|
||||
max_failures: int = 5,
|
||||
window_seconds: int = 300,
|
||||
block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
if key in self._requests:
|
||||
del self._requests[key]
|
||||
|
||||
|
||||
# 全局单例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取 RateLimiter 单例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth"
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="认证请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api"
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
|
|
@ -7,10 +7,12 @@
|
|||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
|
|
@ -19,6 +21,14 @@ logger = get_logger("webui_system")
|
|||
_start_time = time.time()
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
|
|
@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
|
|||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot():
|
||||
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
|
|
@ -67,7 +77,7 @@ async def restart_maibot():
|
|||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status():
|
||||
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
|
|
@ -90,7 +100,7 @@ async def get_maibot_status():
|
|||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
|
|
@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
|
|||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
from .model_routes import router as model_router
|
||||
from .ws_auth import router as ws_auth_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
|
|
@ -42,6 +44,8 @@ router.include_router(get_progress_router())
|
|||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
|
|
@ -107,12 +111,18 @@ async def health_check():
|
|||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
async def verify_token(
|
||||
request_body: TokenVerifyRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
):
|
||||
"""
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
request_body: 包含 token 的验证请求
|
||||
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
|
|
@ -120,16 +130,40 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
|||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
rate_limiter = get_rate_limiter()
|
||||
|
||||
is_valid = token_manager.verify_token(request_body.token)
|
||||
|
||||
if is_valid:
|
||||
# 认证成功,重置失败计数
|
||||
rate_limiter.reset_failures(request)
|
||||
# 设置 HttpOnly Cookie
|
||||
set_auth_cookie(response, request.token)
|
||||
set_auth_cookie(response, request_body.token)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
# 记录失败尝试
|
||||
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||
request,
|
||||
max_failures=5, # 5 次失败
|
||||
window_seconds=300, # 5 分钟窗口
|
||||
block_duration=600 # 封禁 10 分钟
|
||||
)
|
||||
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
|
||||
)
|
||||
|
||||
message = "Token 无效或已过期"
|
||||
if remaining <= 2:
|
||||
message += f"(剩余 {remaining} 次尝试机会)"
|
||||
|
||||
return TokenVerifyResponse(valid=False, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
|
|
|||
|
|
@ -1,19 +1,28 @@
|
|||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
|
|
@ -58,7 +67,7 @@ class DashboardData(BaseModel):
|
|||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
|
|
@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
|||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24):
|
||||
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
|
|
@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
|
|||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24):
|
||||
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
|
|
|
|||
|
|
@ -46,12 +46,21 @@ class WebUIServer:
|
|||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:7999", # 前端开发服务器备用端口
|
||||
"http://127.0.0.1:7999",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"X-Requested-With",
|
||||
], # 明确指定允许的头
|
||||
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
|
|
@ -89,23 +98,46 @@ class WebUIServer:
|
|||
logger.warning("💡 请确认前端已正确构建")
|
||||
return
|
||||
|
||||
# robots.txt - 禁止搜索引擎索引
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回 robots.txt 禁止所有爬虫"""
|
||||
from fastapi.responses import PlainTextResponse
|
||||
content = """User-agent: *
|
||||
Disallow: /
|
||||
|
||||
# MaiBot Dashboard - 私有管理面板,禁止索引
|
||||
"""
|
||||
return PlainTextResponse(
|
||||
content=content,
|
||||
headers={"X-Robots-Tag": "noindex, nofollow, noarchive"}
|
||||
)
|
||||
|
||||
# 处理 SPA 路由 - 注意:这个路由优先级最低
|
||||
@self.app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用 - 只处理非 API 请求"""
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
response = FileResponse(file_path, media_type=media_type)
|
||||
# HTML 文件添加防索引头
|
||||
if str(file_path).endswith('.html'):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 其他路径返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
"""WebSocket 认证模块
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header, HTTPException
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.ws_auth")
|
||||
router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
def _cleanup_expired_ws_tokens():
|
||||
"""清理过期的临时 token"""
|
||||
now = time.time()
|
||||
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||
for t in expired:
|
||||
del _ws_temp_tokens[t]
|
||||
|
||||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
temp_token = secrets.token_urlsafe(32)
|
||||
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||
return temp_token
|
||||
|
||||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
if temp_token not in _ws_temp_tokens:
|
||||
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||
return False
|
||||
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||
if time.time() > expire_time:
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||
return False
|
||||
# 验证原始 session token 仍然有效
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||
return False
|
||||
# 消费 token(一次性使用)
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
raise HTTPException(status_code=401, detail="未提供认证信息")
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
raise HTTPException(status_code=401, detail="认证已过期,请重新登录")
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -4,10 +4,14 @@
|
|||
<meta charset="UTF-8" />
|
||||
<meta name="google" content="notranslate" />
|
||||
<meta http-equiv="content-language" content="zh-CN" />
|
||||
<!-- 防止搜索引擎索引 -->
|
||||
<meta name="robots" content="noindex, nofollow, noarchive, nosnippet" />
|
||||
<meta name="googlebot" content="noindex, nofollow" />
|
||||
<meta name="bingbot" content="noindex, nofollow" />
|
||||
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>MaiBot Dashboard</title>
|
||||
<script type="module" crossorigin src="/assets/index-DcGiKm2P.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-D-S1XZ00.js"></script>
|
||||
<link rel="modulepreload" crossorigin href="/assets/react-vendor-BmxF9s7Q.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/router-Bz250laD.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/utils-BXc2jIuz.js">
|
||||
|
|
|
|||
Loading…
Reference in New Issue