Merge branch 'Mai-with-u:dev' into dev

pull/1439/head
Dawn ARC 2025-12-14 20:26:21 +08:00 committed by GitHub
commit 72786687b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 944 additions and 209 deletions

View File

@ -3,6 +3,7 @@ WebUI 认证模块
提供统一的认证依赖支持 Cookie Header 两种方式
"""
import os
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 检查环境变量
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
return False
# 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"):
return True
# 默认:开发环境不启用(因为通常是 HTTP
return False
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None:
response: FastAPI Response 对象
token: 要设置的 token
"""
# 根据环境决定安全设置
is_secure = _is_secure_environment()
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取
samesite="lax", # 允许同站导航时发送 Cookie兼容开发环境代理
secure=False, # 本地开发不强制 HTTPS生产环境建议设为 True
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
secure=is_secure, # 生产环境强制 HTTPS
path="/", # 确保 Cookie 在所有路径下可用
)
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
def clear_auth_cookie(response: Response) -> None:
@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None:
Args:
response: FastAPI Response 对象
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,
samesite="lax",
samesite="strict" if is_secure else "lax",
secure=is_secure,
path="/",
)
logger.debug("已清除认证 Cookie")

View File

@ -8,18 +8,29 @@
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
@ -256,6 +267,7 @@ async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
@ -272,7 +284,7 @@ async def get_chat_history(
@router.get("/platforms")
async def get_available_platforms():
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
PersonInfo 表中获取所有已知的平台
@ -303,6 +315,7 @@ async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
@ -350,7 +363,7 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
@ -372,6 +385,7 @@ async def websocket_chat(
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
@ -382,9 +396,45 @@ async def websocket_chat(
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名可选
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
@ -712,7 +762,7 @@ async def websocket_chat(
@router.get("/info")
async def get_chat_info():
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,

View File

@ -4,10 +4,11 @@
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
@ -65,7 +74,7 @@ async def get_bot_config_schema():
@router.get("/schema/model")
async def get_model_config_schema():
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
@ -79,7 +88,7 @@ async def get_model_config_schema():
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
@router.get("/bot")
async def get_bot_config():
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -168,7 +177,7 @@ async def get_bot_config():
@router.get("/model")
async def get_model_config():
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@ -190,7 +199,7 @@ async def get_model_config():
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody):
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
@router.post("/model")
async def update_model_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody):
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
@router.get("/bot/raw")
async def get_bot_config_raw():
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -307,7 +316,7 @@ async def get_bot_config_raw():
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody):
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
@ -337,7 +346,7 @@ async def update_bot_config_raw(raw_content: RawContentBody):
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody):
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@ -430,7 +439,7 @@ def _to_relative_path(path: str) -> str:
@router.get("/adapter-config/path")
async def get_adapter_config_path():
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
@ -469,7 +478,7 @@ async def get_adapter_config_path():
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody):
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@ -512,7 +521,7 @@ async def save_adapter_config_path(data: PathBody):
@router.get("/adapter-config")
async def get_adapter_config(path: str):
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
@ -544,7 +553,7 @@ async def get_adapter_config(path: str):
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody):
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")

View File

@ -1,15 +1,24 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
@ -199,7 +209,7 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
@ -248,7 +258,7 @@ async def get_knowledge_stats():
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:

View File

@ -1,10 +1,12 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:

View File

@ -6,18 +6,27 @@
import os
import httpx
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
@ -184,6 +193,7 @@ async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
@ -228,6 +238,7 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表用于自定义提供商
@ -251,6 +262,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
@ -337,6 +349,7 @@ async def test_provider_connection(
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接从配置文件读取信息

View File

@ -1,10 +1,12 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
@ -89,14 +91,48 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path
import json
import re
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
@ -34,6 +35,85 @@ def get_token_from_cookie_or_header(
return None
def validate_safe_path(user_path: str, base_path: Path) -> Path:
"""
验证用户提供的路径是否安全防止路径遍历攻击
Args:
user_path: 用户输入的路径相对路径
base_path: 允许的基础目录
Returns:
安全的绝对路径
Raises:
HTTPException: 如果检测到路径遍历攻击
"""
# 规范化基础路径
base_resolved = base_path.resolve()
# 检查用户路径是否包含可疑字符
# 禁止: .., 绝对路径开头, 空字节等
if any(pattern in user_path for pattern in ["..", "\x00"]):
logger.warning(f"检测到可疑路径: {user_path}")
raise HTTPException(status_code=400, detail="路径包含非法字符")
# 检查是否为绝对路径Windows 和 Unix
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
logger.warning(f"检测到绝对路径: {user_path}")
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
# 构建目标路径并解析
target_path = (base_path / user_path).resolve()
# 验证解析后的路径仍在基础目录内
try:
target_path.relative_to(base_resolved)
except ValueError as e:
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
return target_path
def validate_plugin_id(plugin_id: str) -> str:
"""
验证插件 ID 格式是否安全
Args:
plugin_id: 插件 ID (支持 author.name 格式允许中文)
Returns:
验证通过的插件 ID
Raises:
HTTPException: 如果插件 ID 格式不安全
"""
# 禁止空字符串
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
# 禁止危险字符: 路径分隔符、空字节、控制字符等
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
for pattern in dangerous_patterns:
if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
# 禁止特殊名称
if plugin_id in (".", ".."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
"""
解析版本号字符串
@ -468,17 +548,16 @@ async def fetch_raw_file(
支持多镜像源自动切换和错误重试
注意此接口可公开访问用于获取插件仓库等公开资源
需要认证才能访问防止被滥用作为 SSRF 跳板
"""
# Token 验证(可选,用于日志记录
# Token 验证(强制
token = get_token_from_cookie_or_header(maibot_session, authorization)
token_manager = get_token_manager()
is_authenticated = token and token_manager.verify_token(token)
if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
# 对于公开仓库的访问,不强制要求认证
# 只在日志中记录是否认证
logger.info(
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
f"收到获取 Raw 文件请求: "
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
)
@ -564,10 +643,10 @@ async def clone_repository(
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
try:
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
# TODO: 确定实际的插件目录基路径
base_plugin_path = Path("./plugins") # 临时路径
target_path = base_plugin_path / request.target_path
# 验证 target_path 的安全性,防止路径遍历攻击
base_plugin_path = Path("./plugins").resolve()
base_plugin_path.mkdir(exist_ok=True)
target_path = validate_safe_path(request.target_path, base_plugin_path)
service = get_git_mirror_service()
result = await service.clone_repository(
@ -607,13 +686,16 @@ async def install_plugin(
logger.info(f"收到安装插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始安装
await update_progress(
stage="loading",
progress=5,
message=f"开始安装插件: {request.plugin_id}",
message=f"开始安装插件: {plugin_id}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 解析仓库 URL
@ -634,27 +716,28 @@ async def install_plugin(
progress=10,
message=f"解析仓库信息: {owner}/{repo}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 确定插件安装路径
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
plugins_dir.mkdir(exist_ok=True)
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
folder_name = request.plugin_id.replace(".", "_")
target_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证,防止路径遍历
target_path = validate_safe_path(folder_name, plugins_dir)
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
old_format_path = plugins_dir / request.plugin_id
old_format_path = plugins_dir / plugin_id
if target_path.exists() or old_format_path.exists():
await update_progress(
stage="error",
progress=0,
message="插件已存在",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件已安装,请先卸载",
)
raise HTTPException(status_code=400, detail="插件已安装")
@ -664,7 +747,7 @@ async def install_plugin(
progress=15,
message=f"准备克隆到: {target_path}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
@ -693,14 +776,14 @@ async def install_plugin(
progress=0,
message="克隆仓库失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 4. 验证插件完整性
await update_progress(
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
)
manifest_path = target_path / "_manifest.json"
@ -715,14 +798,14 @@ async def install_plugin(
progress=0,
message="插件缺少 _manifest.json",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
# 5. 读取并验证 manifest
await update_progress(
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
)
try:
@ -739,7 +822,7 @@ async def install_plugin(
# 将插件 ID 写入 manifest用于后续准确识别
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
manifest["id"] = request.plugin_id
manifest["id"] = plugin_id
with open(manifest_path, "w", encoding="utf-8") as f:
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
@ -754,7 +837,7 @@ async def install_plugin(
progress=0,
message="_manifest.json 无效",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -765,13 +848,13 @@ async def install_plugin(
progress=100,
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件安装成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": manifest["name"],
"version": manifest["version"],
"path": str(target_path),
@ -787,7 +870,7 @@ async def install_plugin(
progress=0,
message="安装失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@ -814,22 +897,26 @@ async def uninstall_plugin(
logger.info(f"收到卸载插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始卸载
await update_progress(
stage="loading",
progress=10,
message=f"开始卸载插件: {request.plugin_id}",
message=f"开始卸载插件: {plugin_id}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否存在(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@ -841,7 +928,7 @@ async def uninstall_plugin(
progress=0,
message="插件不存在",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装或已被删除",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -851,12 +938,12 @@ async def uninstall_plugin(
progress=30,
message=f"正在删除插件文件: {plugin_path}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 读取插件信息(用于日志)
manifest_path = plugin_path / "_manifest.json"
plugin_name = request.plugin_id
plugin_name = plugin_id
if manifest_path.exists():
try:
@ -864,7 +951,7 @@ async def uninstall_plugin(
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
plugin_name = manifest.get("name", request.plugin_id)
plugin_name = manifest.get("name", plugin_id)
except Exception:
pass # 如果读取失败,使用插件 ID 作为名称
@ -873,7 +960,7 @@ async def uninstall_plugin(
progress=50,
message=f"正在删除 {plugin_name}...",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除插件目录
@ -889,7 +976,7 @@ async def uninstall_plugin(
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
# 4. 推送成功状态
await update_progress(
@ -897,10 +984,10 @@ async def uninstall_plugin(
progress=100,
message=f"成功卸载插件: {plugin_name}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
@ -912,7 +999,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="权限不足,无法删除插件文件",
)
@ -925,7 +1012,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@ -952,22 +1039,26 @@ async def update_plugin(
logger.info(f"收到更新插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始更新
await update_progress(
stage="loading",
progress=5,
message=f"开始更新插件: {request.plugin_id}",
message=f"开始更新插件: {plugin_id}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否已安装(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@ -979,7 +1070,7 @@ async def update_plugin(
progress=0,
message="插件不存在",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装,请先安装",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -1003,12 +1094,12 @@ async def update_plugin(
progress=10,
message=f"当前版本: {old_version},准备更新...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除旧版本
await update_progress(
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
)
import shutil
@ -1023,7 +1114,7 @@ async def update_plugin(
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
# 4. 解析仓库 URL
await update_progress(
@ -1031,7 +1122,7 @@ async def update_plugin(
progress=30,
message="正在准备下载新版本...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
repo_url = request.repository_url.rstrip("/")
@ -1069,14 +1160,14 @@ async def update_plugin(
progress=0,
message="下载新版本失败",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 6. 验证新版本
await update_progress(
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
)
new_manifest_path = plugin_path / "_manifest.json"
@ -1096,7 +1187,7 @@ async def update_plugin(
progress=0,
message="新版本缺少 _manifest.json",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
@ -1107,9 +1198,9 @@ async def update_plugin(
new_manifest = json_module.load(f)
new_version = new_manifest.get("version", "unknown")
new_name = new_manifest.get("name", request.plugin_id)
new_name = new_manifest.get("name", plugin_id)
logger.info(f"成功更新插件: {request.plugin_id} {old_version}{new_version}")
logger.info(f"成功更新插件: {plugin_id} {old_version}{new_version}")
# 8. 推送成功状态
await update_progress(
@ -1117,13 +1208,13 @@ async def update_plugin(
progress=100,
message=f"成功更新 {new_name}: {old_version}{new_version}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件更新成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": new_name,
"old_version": old_version,
"new_version": new_version,
@ -1138,7 +1229,7 @@ async def update_plugin(
progress=0,
message="_manifest.json 无效",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -1149,7 +1240,7 @@ async def update_plugin(
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@ -0,0 +1,264 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [
(ts, count) for ts, count in self._requests[key]
if ts > cutoff
]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self,
request: Request,
max_requests: int,
window_seconds: int,
key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间
key_suffix: 键后缀用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长
"""
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self,
request: Request,
max_failures: int = 5,
window_seconds: int = 300,
block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试如登录失败
如果在窗口期内失败次数过多自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口
block_duration: 封禁时长
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数认证成功后调用
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
if key in self._requests:
del self._requests[key]
# 全局单例
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""获取 RateLimiter 单例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth"
)
if not allowed:
raise HTTPException(
status_code=429,
detail="认证请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api"
)
if not allowed:
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)

View File

@ -7,10 +7,12 @@
import os
import time
from datetime import datetime
from fastapi import APIRouter, HTTPException
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
@ -19,6 +21,14 @@ logger = get_logger("webui_system")
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot():
async def restart_maibot(_auth: bool = Depends(require_auth)):
"""
重启麦麦主程序
@ -67,7 +77,7 @@ async def restart_maibot():
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
async def get_maibot_status(_auth: bool = Depends(require_auth)):
"""
获取麦麦运行状态
@ -90,7 +100,7 @@ async def get_maibot_status():
@router.post("/reload-config")
async def reload_config():
async def reload_config(_auth: bool = Depends(require_auth)):
"""
热重载配置不重启进程

View File

@ -1,11 +1,12 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
logger = get_logger("webui.api")
@ -42,6 +44,8 @@ router.include_router(get_progress_router())
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
class TokenVerifyRequest(BaseModel):
@ -107,12 +111,18 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(request: TokenVerifyRequest, response: Response):
async def verify_token(
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
):
"""
验证访问令牌验证成功后设置 HttpOnly Cookie
Args:
request: 包含 token 的验证请求
request_body: 包含 token 的验证请求
request: FastAPI Request 对象用于获取客户端 IP
response: FastAPI Response 对象
Returns:
@ -120,16 +130,40 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
"""
try:
token_manager = get_token_manager()
is_valid = token_manager.verify_token(request.token)
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie
set_auth_cookie(response, request.token)
set_auth_cookie(response, request_body.token)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
else:
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600 # 封禁 10 分钟
)
if blocked:
raise HTTPException(
status_code=429,
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
)
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e

View File

@ -1,19 +1,28 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
@ -58,7 +67,7 @@ class DashboardData(BaseModel):
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24):
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
@router.get("/summary")
async def get_summary(hours: int = 24):
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
@router.get("/models")
async def get_model_stats(hours: int = 24):
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计

View File

@ -46,12 +46,21 @@ class WebUIServer:
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")
@ -89,23 +98,46 @@ class WebUIServer:
logger.warning("💡 请确认前端已正确构建")
return
# robots.txt - 禁止搜索引擎索引
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回 robots.txt 禁止所有爬虫"""
from fastapi.responses import PlainTextResponse
content = """User-agent: *
Disallow: /
# MaiBot Dashboard - 私有管理面板,禁止索引
"""
return PlainTextResponse(
content=content,
headers={"X-Robots-Tag": "noindex, nofollow, noarchive"}
)
# 处理 SPA 路由 - 注意:这个路由优先级最低
@self.app.get("/{full_path:path}", include_in_schema=False)
async def serve_spa(full_path: str):
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith('.html'):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 其他路径返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")

View File

@ -0,0 +1,107 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制
临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题
"""
from fastapi import APIRouter, Cookie, Header, HTTPException
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证
临时 token 有效期 60 且只能使用一次
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
raise HTTPException(status_code=401, detail="未提供认证信息")
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
raise HTTPException(status_code=401, detail="认证已过期,请重新登录")
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,10 +4,14 @@
<meta charset="UTF-8" />
<meta name="google" content="notranslate" />
<meta http-equiv="content-language" content="zh-CN" />
<!-- 防止搜索引擎索引 -->
<meta name="robots" content="noindex, nofollow, noarchive, nosnippet" />
<meta name="googlebot" content="noindex, nofollow" />
<meta name="bingbot" content="noindex, nofollow" />
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>MaiBot Dashboard</title>
<script type="module" crossorigin src="/assets/index-DcGiKm2P.js"></script>
<script type="module" crossorigin src="/assets/index-D-S1XZ00.js"></script>
<link rel="modulepreload" crossorigin href="/assets/react-vendor-BmxF9s7Q.js">
<link rel="modulepreload" crossorigin href="/assets/router-Bz250laD.js">
<link rel="modulepreload" crossorigin href="/assets/utils-BXc2jIuz.js">