From ea420f9f5953e7d44c28f4c19b19c11744b67521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 14 Dec 2025 19:39:56 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=AE=A4=E8=AF=81=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E5=92=8C=E8=AF=B7=E6=B1=82=E9=A2=91=E7=8E=87=E9=99=90?= =?UTF-8?q?=E5=88=B6=E6=A8=A1=E5=9D=97=EF=BC=8C=E5=A2=9E=E5=BC=BA=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=80=A7=E5=92=8C=E9=98=B2=E6=AD=A2API=E6=BB=A5?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/webui/auth.py | 40 ++++- src/webui/chat_routes.py | 40 ++++- src/webui/config_routes.py | 43 ++++-- src/webui/knowledge_routes.py | 16 +- src/webui/logs_ws.py | 27 +++- src/webui/model_routes.py | 15 +- src/webui/plugin_progress_ws.py | 28 +++- src/webui/rate_limiter.py | 264 ++++++++++++++++++++++++++++++++ src/webui/routers/system.py | 18 ++- src/webui/routes.py | 43 +++++- src/webui/statistics_routes.py | 19 ++- src/webui/webui_server.py | 11 +- 12 files changed, 509 insertions(+), 55 deletions(-) create mode 100644 src/webui/rate_limiter.py diff --git a/src/webui/auth.py b/src/webui/auth.py index 804cef55..c5989387 100644 --- a/src/webui/auth.py +++ b/src/webui/auth.py @@ -3,6 +3,7 @@ WebUI 认证模块 提供统一的认证依赖,支持 Cookie 和 Header 两种方式 """ +import os from typing import Optional from fastapi import HTTPException, Cookie, Header, Response, Request from src.common.logger import get_logger @@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session" COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天 +def _is_secure_environment() -> bool: + """ + 检测是否应该启用安全 Cookie(HTTPS) + + Returns: + bool: 如果应该使用 secure cookie 则返回 True + """ + # 检查环境变量 + if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"): + return True + if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"): + return False + + # 检查是否是生产环境 + env = os.environ.get("WEBUI_MODE", "").lower() + if env in ("production", "prod"): + return True + + # 默认:开发环境不启用(因为通常是 HTTP) + return False + + def get_current_token( request: Request, maibot_session: Optional[str] = Cookie(None), @@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None: response: FastAPI Response 对象 token: 要设置的 token """ + # 根据环境决定安全设置 + is_secure = _is_secure_environment() + response.set_cookie( key=COOKIE_NAME, value=token, max_age=COOKIE_MAX_AGE, - httponly=True, # 防止 JS 读取 - samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理) - secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True + httponly=True, # 防止 JS 读取,阻止 XSS 窃取 + samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF + secure=is_secure, # 生产环境强制 HTTPS path="/", # 确保 Cookie 在所有路径下可用 ) - logger.debug(f"已设置认证 Cookie: {token[:8]}...") + logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})") def clear_auth_cookie(response: Response) -> None: @@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None: Args: response: FastAPI Response 对象 """ + # 保持与 set_auth_cookie 相同的安全设置 + is_secure = _is_secure_environment() + response.delete_cookie( key=COOKIE_NAME, httponly=True, - samesite="lax", + samesite="strict" if is_secure else "lax", + secure=is_secure, path="/", ) logger.debug("已清除认证 Cookie") diff --git a/src/webui/chat_routes.py b/src/webui/chat_routes.py index 5e492cb2..6dfbca5b 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/chat_routes.py @@ -8,18 +8,28 @@ import time import uuid from typing import Dict, Any, Optional, List -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header from pydantic import BaseModel from src.common.logger import get_logger from src.common.database.database_model import Messages, PersonInfo from src.config.config import global_config from src.chat.message_receive.bot import chat_bot +from src.webui.auth import verify_auth_token_from_cookie_or_header +from src.webui.token_manager import get_token_manager logger = get_logger("webui.chat") router = APIRouter(prefix="/api/chat", tags=["LocalChat"]) + +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + # WebUI 聊天的虚拟群组 ID WEBUI_CHAT_GROUP_ID = "webui_local_chat" WEBUI_CHAT_PLATFORM = "webui" @@ -256,6 +266,7 @@ async def get_chat_history( limit: int = Query(default=50, ge=1, le=200), user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤 group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史 + _auth: bool = Depends(require_auth), ): """获取聊天历史记录 @@ -272,7 +283,7 @@ async def get_chat_history( @router.get("/platforms") -async def get_available_platforms(): +async def get_available_platforms(_auth: bool = Depends(require_auth)): """获取可用平台列表 从 PersonInfo 表中获取所有已知的平台 @@ -303,6 +314,7 @@ async def get_persons_by_platform( platform: str = Query(..., description="平台名称"), search: Optional[str] = Query(default=None, description="搜索关键词"), limit: int = Query(default=50, ge=1, le=200), + _auth: bool = Depends(require_auth), ): """获取指定平台的用户列表 @@ -350,7 +362,7 @@ async def get_persons_by_platform( @router.delete("/history") -async def clear_chat_history(group_id: Optional[str] = Query(default=None)): +async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)): """清空聊天历史记录 Args: @@ -372,6 +384,7 @@ async def websocket_chat( person_id: Optional[str] = Query(default=None), group_name: Optional[str] = Query(default=None), group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id + token: Optional[str] = Query(default=None), # 认证 token ): """WebSocket 聊天端点 @@ -382,9 +395,28 @@ async def websocket_chat( person_id: 虚拟身份模式的用户 person_id(可选) group_name: 虚拟身份模式的群名(可选) group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化) + token: 认证 token(可选,也可从 Cookie 获取) 虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置 """ + # 认证检查 + auth_token = token + if not auth_token: + # 尝试从 Cookie 获取 token + auth_token = websocket.cookies.get("maibot_session") + + if not auth_token: + logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token") + await websocket.close(code=4001, reason="未提供认证信息") + return + + # 验证 token + token_manager = get_token_manager() + if not token_manager.verify_token(auth_token): + logger.warning("WebSocket 聊天连接被拒绝:token 无效") + await websocket.close(code=4003, reason="Token 无效或已过期") + return + # 生成会话 ID(每次连接都是新的) session_id = str(uuid.uuid4()) @@ -712,7 +744,7 @@ async def websocket_chat( @router.get("/info") -async def get_chat_info(): +async def get_chat_info(_auth: bool = Depends(require_auth)): """获取聊天室信息""" return { "bot_name": global_config.bot.nickname, diff --git a/src/webui/config_routes.py b/src/webui/config_routes.py index 438683b6..58557aa7 100644 --- a/src/webui/config_routes.py +++ b/src/webui/config_routes.py @@ -4,10 +4,11 @@ import os import tomlkit -from fastapi import APIRouter, HTTPException, Body -from typing import Any, Annotated +from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header +from typing import Any, Annotated, Optional from src.common.logger import get_logger +from src.webui.auth import verify_auth_token_from_cookie_or_header from src.common.toml_utils import save_toml_with_format, _update_toml_doc from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.official_configs import ( @@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()] router = APIRouter(prefix="/config", tags=["config"]) +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + + # ===== 架构获取接口 ===== @router.get("/schema/bot") -async def get_bot_config_schema(): +async def get_bot_config_schema(_auth: bool = Depends(require_auth)): """获取麦麦主程序配置架构""" try: # Config 类包含所有子配置 @@ -65,7 +74,7 @@ async def get_bot_config_schema(): @router.get("/schema/model") -async def get_model_config_schema(): +async def get_model_config_schema(_auth: bool = Depends(require_auth)): """获取模型配置架构(包含提供商和模型任务配置)""" try: schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig) @@ -79,7 +88,7 @@ async def get_model_config_schema(): @router.get("/schema/section/{section_name}") -async def get_config_section_schema(section_name: str): +async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)): """ 获取指定配置节的架构 @@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str): @router.get("/bot") -async def get_bot_config(): +async def get_bot_config(_auth: bool = Depends(require_auth)): """获取麦麦主程序配置""" try: config_path = os.path.join(CONFIG_DIR, "bot_config.toml") @@ -168,7 +177,7 @@ async def get_bot_config(): @router.get("/model") -async def get_model_config(): +async def get_model_config(_auth: bool = Depends(require_auth)): """获取模型配置(包含提供商和模型任务配置)""" try: config_path = os.path.join(CONFIG_DIR, "model_config.toml") @@ -190,7 +199,7 @@ async def get_model_config(): @router.post("/bot") -async def update_bot_config(config_data: ConfigBody): +async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)): """更新麦麦主程序配置""" try: # 验证配置数据 @@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody): @router.post("/model") -async def update_model_config(config_data: ConfigBody): +async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)): """更新模型配置""" try: # 验证配置数据 @@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody): @router.post("/bot/section/{section_name}") -async def update_bot_config_section(section_name: str, section_data: SectionBody): +async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)): """更新麦麦主程序配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody @router.get("/bot/raw") -async def get_bot_config_raw(): +async def get_bot_config_raw(_auth: bool = Depends(require_auth)): """获取麦麦主程序配置的原始 TOML 内容""" try: config_path = os.path.join(CONFIG_DIR, "bot_config.toml") @@ -307,7 +316,7 @@ async def get_bot_config_raw(): @router.post("/bot/raw") -async def update_bot_config_raw(raw_content: RawContentBody): +async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)): """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)""" try: # 验证 TOML 格式 @@ -337,7 +346,7 @@ async def update_bot_config_raw(raw_content: RawContentBody): @router.post("/model/section/{section_name}") -async def update_model_config_section(section_name: str, section_data: SectionBody): +async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)): """更新模型配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -430,7 +439,7 @@ def _to_relative_path(path: str) -> str: @router.get("/adapter-config/path") -async def get_adapter_config_path(): +async def get_adapter_config_path(_auth: bool = Depends(require_auth)): """获取保存的适配器配置文件路径""" try: # 从 data/webui.json 读取路径偏好 @@ -469,7 +478,7 @@ async def get_adapter_config_path(): @router.post("/adapter-config/path") -async def save_adapter_config_path(data: PathBody): +async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)): """保存适配器配置文件路径偏好""" try: path = data.get("path") @@ -512,7 +521,7 @@ async def save_adapter_config_path(data: PathBody): @router.get("/adapter-config") -async def get_adapter_config(path: str): +async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)): """从指定路径读取适配器配置文件""" try: if not path: @@ -544,7 +553,7 @@ async def get_adapter_config(path: str): @router.post("/adapter-config") -async def save_adapter_config(data: PathBody): +async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)): """保存适配器配置到指定路径""" try: path = data.get("path") diff --git a/src/webui/knowledge_routes.py b/src/webui/knowledge_routes.py index af4594b6..87b2e7b5 100644 --- a/src/webui/knowledge_routes.py +++ b/src/webui/knowledge_routes.py @@ -1,15 +1,24 @@ """知识库图谱可视化 API 路由""" from typing import List, Optional -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Depends, Cookie, Header from pydantic import BaseModel import logging +from src.webui.auth import verify_auth_token_from_cookie_or_header logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"]) +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + + class KnowledgeNode(BaseModel): """知识节点""" @@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph: async def get_knowledge_graph( limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"), node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"), + _auth: bool = Depends(require_auth), ): """获取知识图谱(限制节点数量) @@ -199,7 +209,7 @@ async def get_knowledge_graph( @router.get("/stats", response_model=KnowledgeStats) -async def get_knowledge_stats(): +async def get_knowledge_stats(_auth: bool = Depends(require_auth)): """获取知识库统计信息 Returns: @@ -248,7 +258,7 @@ async def get_knowledge_stats(): @router.get("/search", response_model=List[KnowledgeNode]) -async def search_knowledge_node(query: str = Query(..., min_length=1)): +async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)): """搜索知识节点 Args: diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index e0e0a9a1..836191ee 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -1,10 +1,11 @@ """WebSocket 日志推送模块""" -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from typing import Set +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query +from typing import Set, Optional import json from pathlib import Path from src.common.logger import get_logger +from src.webui.token_manager import get_token_manager logger = get_logger("webui.logs_ws") router = APIRouter() @@ -73,14 +74,32 @@ def load_recent_logs(limit: int = 100) -> list[dict]: @router.websocket("/ws/logs") -async def websocket_logs(websocket: WebSocket): +async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)): """WebSocket 日志推送端点 客户端连接后会持续接收服务器端的日志消息 + 需要通过 query 参数传递 token 进行认证,例如:ws://host/ws/logs?token=xxx """ + # 认证检查 + if not token: + # 尝试从 Cookie 获取 token + token = websocket.cookies.get("maibot_session") + + if not token: + logger.warning("WebSocket 连接被拒绝:未提供认证 token") + await websocket.close(code=4001, reason="未提供认证信息") + return + + # 验证 token + token_manager = get_token_manager() + if not token_manager.verify_token(token): + logger.warning("WebSocket 连接被拒绝:token 无效") + await websocket.close(code=4003, reason="Token 无效或已过期") + return + await websocket.accept() active_connections.add(websocket) - logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}") + logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") # 连接建立后,立即发送历史日志 try: diff --git a/src/webui/model_routes.py b/src/webui/model_routes.py index 7d8310ee..a84241b9 100644 --- a/src/webui/model_routes.py +++ b/src/webui/model_routes.py @@ -6,18 +6,27 @@ import os import httpx -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header from typing import Optional import tomlkit from src.common.logger import get_logger from src.config.config import CONFIG_DIR +from src.webui.auth import verify_auth_token_from_cookie_or_header logger = get_logger("webui") router = APIRouter(prefix="/models", tags=["models"]) +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + + # 模型获取器配置 MODEL_FETCHER_CONFIG = { # OpenAI 兼容格式的提供商 @@ -184,6 +193,7 @@ async def get_provider_models( provider_name: str = Query(..., description="提供商名称"), parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), + _auth: bool = Depends(require_auth), ): """ 获取指定提供商的可用模型列表 @@ -228,6 +238,7 @@ async def get_models_by_url( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), + _auth: bool = Depends(require_auth), ): """ 通过 URL 直接获取模型列表(用于自定义提供商) @@ -251,6 +262,7 @@ async def get_models_by_url( async def test_provider_connection( base_url: str = Query(..., description="提供商的基础 URL"), api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"), + _auth: bool = Depends(require_auth), ): """ 测试提供商连接状态 @@ -337,6 +349,7 @@ async def test_provider_connection( @router.post("/test-connection-by-name") async def test_provider_connection_by_name( provider_name: str = Query(..., description="提供商名称"), + _auth: bool = Depends(require_auth), ): """ 通过提供商名称测试连接(从配置文件读取信息) diff --git a/src/webui/plugin_progress_ws.py b/src/webui/plugin_progress_ws.py index 7e0fb647..3d334ca9 100644 --- a/src/webui/plugin_progress_ws.py +++ b/src/webui/plugin_progress_ws.py @@ -1,10 +1,11 @@ """WebSocket 插件加载进度推送模块""" -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from typing import Set, Dict, Any +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query +from typing import Set, Dict, Any, Optional import json import asyncio from src.common.logger import get_logger +from src.webui.token_manager import get_token_manager logger = get_logger("webui.plugin_progress") @@ -89,14 +90,33 @@ async def update_progress( @router.websocket("/ws/plugin-progress") -async def websocket_plugin_progress(websocket: WebSocket): +async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)): """WebSocket 插件加载进度推送端点 客户端连接后会立即收到当前进度状态 + 需要通过 query 参数或 Cookie 传递 token 进行认证 """ + # 认证检查 + auth_token = token + if not auth_token: + # 尝试从 Cookie 获取 token + auth_token = websocket.cookies.get("maibot_session") + + if not auth_token: + logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token") + await websocket.close(code=4001, reason="未提供认证信息") + return + + # 验证 token + token_manager = get_token_manager() + if not token_manager.verify_token(auth_token): + logger.warning("插件进度 WebSocket 连接被拒绝:token 无效") + await websocket.close(code=4003, reason="Token 无效或已过期") + return + await websocket.accept() active_connections.add(websocket) - logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}") + logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") try: # 发送当前进度状态 diff --git a/src/webui/rate_limiter.py b/src/webui/rate_limiter.py new file mode 100644 index 00000000..675e1c02 --- /dev/null +++ b/src/webui/rate_limiter.py @@ -0,0 +1,264 @@ +""" +WebUI 请求频率限制模块 +防止暴力破解和 API 滥用 +""" + +import time +from collections import defaultdict +from typing import Dict, Tuple, Optional +from fastapi import Request, HTTPException +from src.common.logger import get_logger + +logger = get_logger("webui.rate_limiter") + + +class RateLimiter: + """ + 简单的内存请求频率限制器 + + 使用滑动窗口算法实现 + """ + + def __init__(self): + # 存储格式: {key: [(timestamp, count), ...]} + self._requests: Dict[str, list] = defaultdict(list) + # 被封禁的 IP: {ip: unblock_timestamp} + self._blocked: Dict[str, float] = {} + + def _get_client_ip(self, request: Request) -> str: + """获取客户端 IP 地址""" + # 检查代理头 + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + # 取第一个 IP(最原始的客户端) + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + # 直接连接的客户端 + if request.client: + return request.client.host + + return "unknown" + + def _cleanup_old_requests(self, key: str, window_seconds: int): + """清理过期的请求记录""" + now = time.time() + cutoff = now - window_seconds + self._requests[key] = [ + (ts, count) for ts, count in self._requests[key] + if ts > cutoff + ] + + def _cleanup_expired_blocks(self): + """清理过期的封禁""" + now = time.time() + expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time] + for ip in expired: + del self._blocked[ip] + logger.info(f"🔓 IP {ip} 封禁已解除") + + def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]: + """ + 检查 IP 是否被封禁 + + Returns: + (是否被封禁, 剩余封禁秒数) + """ + self._cleanup_expired_blocks() + ip = self._get_client_ip(request) + + if ip in self._blocked: + remaining = int(self._blocked[ip] - time.time()) + return True, max(0, remaining) + + return False, None + + def check_rate_limit( + self, + request: Request, + max_requests: int, + window_seconds: int, + key_suffix: str = "" + ) -> Tuple[bool, int]: + """ + 检查请求是否超过频率限制 + + Args: + request: FastAPI Request 对象 + max_requests: 窗口期内允许的最大请求数 + window_seconds: 窗口时间(秒) + key_suffix: 键后缀,用于区分不同的限制规则 + + Returns: + (是否允许, 剩余请求数) + """ + ip = self._get_client_ip(request) + key = f"{ip}:{key_suffix}" if key_suffix else ip + + # 清理过期记录 + self._cleanup_old_requests(key, window_seconds) + + # 计算当前窗口内的请求数 + current_count = sum(count for _, count in self._requests[key]) + + if current_count >= max_requests: + return False, 0 + + # 记录新请求 + now = time.time() + self._requests[key].append((now, 1)) + + remaining = max_requests - current_count - 1 + return True, remaining + + def block_ip(self, request: Request, duration_seconds: int): + """ + 封禁 IP + + Args: + request: FastAPI Request 对象 + duration_seconds: 封禁时长(秒) + """ + ip = self._get_client_ip(request) + self._blocked[ip] = time.time() + duration_seconds + logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒") + + def record_failed_attempt( + self, + request: Request, + max_failures: int = 5, + window_seconds: int = 300, + block_duration: int = 600 + ) -> Tuple[bool, int]: + """ + 记录失败尝试(如登录失败) + + 如果在窗口期内失败次数过多,自动封禁 IP + + Args: + request: FastAPI Request 对象 + max_failures: 允许的最大失败次数 + window_seconds: 统计窗口(秒) + block_duration: 封禁时长(秒) + + Returns: + (是否被封禁, 剩余尝试次数) + """ + ip = self._get_client_ip(request) + key = f"{ip}:auth_failures" + + # 清理过期记录 + self._cleanup_old_requests(key, window_seconds) + + # 计算当前失败次数 + current_failures = sum(count for _, count in self._requests[key]) + + # 记录本次失败 + now = time.time() + self._requests[key].append((now, 1)) + current_failures += 1 + + remaining = max_failures - current_failures + + # 检查是否需要封禁 + if current_failures >= max_failures: + self.block_ip(request, block_duration) + logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁") + return True, 0 + + if current_failures >= max_failures - 2: + logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次") + + return False, max(0, remaining) + + def reset_failures(self, request: Request): + """ + 重置失败计数(认证成功后调用) + """ + ip = self._get_client_ip(request) + key = f"{ip}:auth_failures" + if key in self._requests: + del self._requests[key] + + +# 全局单例 +_rate_limiter: Optional[RateLimiter] = None + + +def get_rate_limiter() -> RateLimiter: + """获取 RateLimiter 单例""" + global _rate_limiter + if _rate_limiter is None: + _rate_limiter = RateLimiter() + return _rate_limiter + + +async def check_auth_rate_limit(request: Request): + """ + 认证接口的频率限制依赖 + + 规则: + - 每个 IP 每分钟最多 10 次认证请求 + - 连续失败 5 次后封禁 10 分钟 + """ + limiter = get_rate_limiter() + + # 检查是否被封禁 + blocked, remaining_block = limiter.is_blocked(request) + if blocked: + raise HTTPException( + status_code=429, + detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", + headers={"Retry-After": str(remaining_block)} + ) + + # 检查频率限制 + allowed, remaining = limiter.check_rate_limit( + request, + max_requests=10, # 每分钟 10 次 + window_seconds=60, + key_suffix="auth" + ) + + if not allowed: + raise HTTPException( + status_code=429, + detail="认证请求过于频繁,请稍后重试", + headers={"Retry-After": "60"} + ) + + +async def check_api_rate_limit(request: Request): + """ + 普通 API 的频率限制依赖 + + 规则:每个 IP 每分钟最多 100 次请求 + """ + limiter = get_rate_limiter() + + # 检查是否被封禁 + blocked, remaining_block = limiter.is_blocked(request) + if blocked: + raise HTTPException( + status_code=429, + detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", + headers={"Retry-After": str(remaining_block)} + ) + + # 检查频率限制 + allowed, _ = limiter.check_rate_limit( + request, + max_requests=100, # 每分钟 100 次 + window_seconds=60, + key_suffix="api" + ) + + if not allowed: + raise HTTPException( + status_code=429, + detail="请求过于频繁,请稍后重试", + headers={"Retry-After": "60"} + ) diff --git a/src/webui/routers/system.py b/src/webui/routers/system.py index d6932896..b1d3729a 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -7,10 +7,12 @@ import os import time from datetime import datetime -from fastapi import APIRouter, HTTPException +from typing import Optional +from fastapi import APIRouter, HTTPException, Depends, Cookie, Header from pydantic import BaseModel from src.config.config import MMC_VERSION from src.common.logger import get_logger +from src.webui.auth import verify_auth_token_from_cookie_or_header router = APIRouter(prefix="/system", tags=["system"]) logger = get_logger("webui_system") @@ -19,6 +21,14 @@ logger = get_logger("webui_system") _start_time = time.time() +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + + class RestartResponse(BaseModel): """重启响应""" @@ -36,7 +46,7 @@ class StatusResponse(BaseModel): @router.post("/restart", response_model=RestartResponse) -async def restart_maibot(): +async def restart_maibot(_auth: bool = Depends(require_auth)): """ 重启麦麦主程序 @@ -67,7 +77,7 @@ async def restart_maibot(): @router.get("/status", response_model=StatusResponse) -async def get_maibot_status(): +async def get_maibot_status(_auth: bool = Depends(require_auth)): """ 获取麦麦运行状态 @@ -90,7 +100,7 @@ async def get_maibot_status(): @router.post("/reload-config") -async def reload_config(): +async def reload_config(_auth: bool = Depends(require_auth)): """ 热重载配置(不重启进程) diff --git a/src/webui/routes.py b/src/webui/routes.py index 36ee8b1f..8be6f84f 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -1,11 +1,12 @@ """WebUI API 路由""" -from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie +from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends from pydantic import BaseModel, Field from typing import Optional from src.common.logger import get_logger from .token_manager import get_token_manager from .auth import set_auth_cookie, clear_auth_cookie +from .rate_limiter import get_rate_limiter, check_auth_rate_limit from .config_routes import router as config_router from .statistics_routes import router as statistics_router from .person_routes import router as person_router @@ -107,12 +108,18 @@ async def health_check(): @router.post("/auth/verify", response_model=TokenVerifyResponse) -async def verify_token(request: TokenVerifyRequest, response: Response): +async def verify_token( + request_body: TokenVerifyRequest, + request: Request, + response: Response, + _rate_limit: None = Depends(check_auth_rate_limit), +): """ 验证访问令牌,验证成功后设置 HttpOnly Cookie Args: - request: 包含 token 的验证请求 + request_body: 包含 token 的验证请求 + request: FastAPI Request 对象(用于获取客户端 IP) response: FastAPI Response 对象 Returns: @@ -120,16 +127,40 @@ async def verify_token(request: TokenVerifyRequest, response: Response): """ try: token_manager = get_token_manager() - is_valid = token_manager.verify_token(request.token) + rate_limiter = get_rate_limiter() + + is_valid = token_manager.verify_token(request_body.token) if is_valid: + # 认证成功,重置失败计数 + rate_limiter.reset_failures(request) # 设置 HttpOnly Cookie - set_auth_cookie(response, request.token) + set_auth_cookie(response, request_body.token) # 同时返回首次配置状态,避免额外请求 is_first_setup = token_manager.is_first_setup() return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup) else: - return TokenVerifyResponse(valid=False, message="Token 无效或已过期") + # 记录失败尝试 + blocked, remaining = rate_limiter.record_failed_attempt( + request, + max_failures=5, # 5 次失败 + window_seconds=300, # 5 分钟窗口 + block_duration=600 # 封禁 10 分钟 + ) + + if blocked: + raise HTTPException( + status_code=429, + detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟" + ) + + message = "Token 无效或已过期" + if remaining <= 2: + message += f"(剩余 {remaining} 次尝试机会)" + + return TokenVerifyResponse(valid=False, message=message) + except HTTPException: + raise except Exception as e: logger.error(f"Token 验证失败: {e}") raise HTTPException(status_code=500, detail="Token 验证失败") from e diff --git a/src/webui/statistics_routes.py b/src/webui/statistics_routes.py index b0a3664c..e5628538 100644 --- a/src/webui/statistics_routes.py +++ b/src/webui/statistics_routes.py @@ -1,19 +1,28 @@ """统计数据 API 路由""" -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends, Cookie, Header from pydantic import BaseModel, Field -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional from datetime import datetime, timedelta from peewee import fn from src.common.logger import get_logger from src.common.database.database_model import LLMUsage, OnlineTime, Messages +from src.webui.auth import verify_auth_token_from_cookie_or_header logger = get_logger("webui.statistics") router = APIRouter(prefix="/statistics", tags=["statistics"]) +def require_auth( + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> bool: + """认证依赖:验证用户是否已登录""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + + class StatisticsSummary(BaseModel): """统计数据摘要""" @@ -58,7 +67,7 @@ class DashboardData(BaseModel): @router.get("/dashboard", response_model=DashboardData) -async def get_dashboard_data(hours: int = 24): +async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)): """ 获取仪表盘统计数据 @@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: @router.get("/summary") -async def get_summary(hours: int = 24): +async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)): """ 获取统计摘要 @@ -293,7 +302,7 @@ async def get_summary(hours: int = 24): @router.get("/models") -async def get_model_stats(hours: int = 24): +async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)): """ 获取模型统计 diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 87b47192..4ecd509d 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -44,8 +44,15 @@ class WebUIServer: "http://127.0.0.1:8001", ], allow_credentials=True, # 允许携带 Cookie - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法 + allow_headers=[ + "Content-Type", + "Authorization", + "Accept", + "Origin", + "X-Requested-With", + ], # 明确指定允许的头 + expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头 ) logger.debug("✅ CORS 中间件已配置")