From 74a2f4346afae8af8fee534bde8fe6d217ee637f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 14 Dec 2025 21:39:09 +0800 Subject: [PATCH] Ruff Format --- src/webui/anti_crawler.py | 81 ++++++++++++--------- src/webui/auth.py | 10 +-- src/webui/chat_routes.py | 15 ++-- src/webui/config_routes.py | 7 +- src/webui/logs_ws.py | 12 +-- src/webui/plugin_progress_ws.py | 12 +-- src/webui/plugin_routes.py | 44 +++++------ src/webui/rate_limiter.py | 125 ++++++++++++++------------------ src/webui/routes.py | 19 ++--- src/webui/webui_server.py | 26 +++---- src/webui/ws_auth.py | 22 +++--- 11 files changed, 176 insertions(+), 197 deletions(-) diff --git a/src/webui/anti_crawler.py b/src/webui/anti_crawler.py index c82afa7c..997d854b 100644 --- a/src/webui/anti_crawler.py +++ b/src/webui/anti_crawler.py @@ -126,6 +126,7 @@ SCANNER_SPECIFIC_HEADERS = { # basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP) ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower() + # IP白名单配置(从环境变量读取,逗号分隔) # 支持格式: # - 精确IP:127.0.0.1, 192.168.1.100 @@ -135,10 +136,10 @@ ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower() def _parse_allowed_ips(ip_string: str) -> list: """ 解析IP白名单字符串,支持精确IP、CIDR格式和通配符 - + Args: ip_string: 逗号分隔的IP字符串 - + Returns: IP白名单列表,每个元素可能是: - ipaddress.IPv4Network/IPv6Network对象(CIDR格式) @@ -148,12 +149,12 @@ def _parse_allowed_ips(ip_string: str) -> list: allowed = [] if not ip_string: return allowed - + for ip_entry in ip_string.split(","): ip_entry = ip_entry.strip() # 去除空格 if not ip_entry: continue - + # 检查通配符格式(包含*) if "*" in ip_entry: # 处理通配符 @@ -163,7 +164,7 @@ def _parse_allowed_ips(ip_string: str) -> list: else: logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}") continue - + try: # 尝试解析为CIDR格式(包含/) if "/" in ip_entry: @@ -173,39 +174,39 @@ def _parse_allowed_ips(ip_string: str) -> list: allowed.append(ipaddress.ip_address(ip_entry)) except (ValueError, AttributeError) as e: logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})") - + return allowed def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]: """ 将通配符IP模式转换为正则表达式 - + 支持的格式: - 192.168.*.* 或 192.168.* - 10.*.*.* 或 10.* - *.*.*.* 或 * - + Args: wildcard_pattern: 通配符模式字符串 - + Returns: 正则表达式字符串,如果格式无效则返回None """ # 去除空格 pattern = wildcard_pattern.strip() - + # 处理单个*(匹配所有) if pattern == "*": return r".*" - + # 处理IPv4通配符格式 # 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等 parts = pattern.split(".") - + if len(parts) > 4: return None # IPv4最多4段 - + # 构建正则表达式 regex_parts = [] for part in parts: @@ -221,15 +222,16 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]: return None # 无效的数字 else: return None # 无效的格式 - + # 如果部分少于4段,补充.* while len(regex_parts) < 4: regex_parts.append(r"\d+") - + # 组合成正则表达式 regex = r"^" + r"\.".join(regex_parts) + r"$" return regex + ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", "")) # 信任的代理IP配置(从环境变量读取,逗号分隔) @@ -250,7 +252,7 @@ def _get_mode_config(mode: str) -> dict: 配置字典,包含所有相关参数 """ mode = mode.lower() - + if mode == "false": return { "enabled": False, @@ -320,7 +322,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): self.check_asset_scanner = config["check_asset_scanner"] self.check_rate_limit = config["check_rate_limit"] self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问 - + # 用于存储每个IP的请求时间戳(使用deque提高性能) self.request_times: dict[str, deque] = {} # 上次清理时间 @@ -354,7 +356,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): return False - def _is_asset_scanner_header(self, request: Request) -> bool: """ 检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配) @@ -499,7 +500,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): empty_ips = [] # 找到最久未访问的IP(最旧时间戳) oldest_ip = None - oldest_time = float('inf') + oldest_time = float("inf") # 全量遍历找真正的oldest(超限时性能可接受) for ip, times in self.request_times.items(): @@ -532,7 +533,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): """ if not TRUSTED_PROXIES or ip == "unknown": return False - + # 检查代理列表中的每个条目 for trusted_entry in TRUSTED_PROXIES: # 通配符模式(字符串,正则表达式) @@ -558,7 +559,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): return True except (ValueError, AttributeError): continue - + return False def _get_client_ip(self, request: Request) -> str: @@ -635,7 +636,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): """ if not ALLOWED_IPS or ip == "unknown": return False - + # 检查白名单中的每个条目 for allowed_entry in ALLOWED_IPS: # 通配符模式(字符串,正则表达式) @@ -664,7 +665,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): except (ValueError, AttributeError): # IP格式无效,跳过 continue - + return False async def dispatch(self, request: Request, call_next): @@ -689,16 +690,31 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): # 允许访问静态资源(CSS、JS、图片等) # 注意:.json 已移除,避免 API 路径绕过防护 # 静态资源只在特定前缀下放行(/static/、/assets/、/dist/) - static_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"} + static_extensions = { + ".css", + ".js", + ".png", + ".jpg", + ".jpeg", + ".gif", + ".svg", + ".ico", + ".woff", + ".woff2", + ".ttf", + ".eot", + } static_prefixes = {"/static/", "/assets/", "/dist/"} - + # 检查是否是静态资源路径(特定前缀下的静态文件) path = request.url.path - is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(path.endswith(ext) for ext in static_extensions) - + is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any( + path.endswith(ext) for ext in static_extensions + ) + # 也允许根路径下的静态文件(如 /favicon.ico) is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions) - + if is_static_path or is_root_static: return await call_next(request) @@ -729,9 +745,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): # 检测爬虫 User-Agent if self.check_user_agent and self._is_crawler_user_agent(user_agent): - logger.warning( - f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}" - ) + logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}") # 根据配置决定是否阻止 if self.block_on_detect: return PlainTextResponse( @@ -741,9 +755,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): # 检查请求频率限制 if self.check_rate_limit and self._check_rate_limit(client_ip): - logger.warning( - f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}" - ) + logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}") return PlainTextResponse( "Too Many Requests: Rate limit exceeded", status_code=429, @@ -770,4 +782,3 @@ Disallow: / media_type="text/plain", headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时 ) - diff --git a/src/webui/auth.py b/src/webui/auth.py index c5989387..86621714 100644 --- a/src/webui/auth.py +++ b/src/webui/auth.py @@ -19,7 +19,7 @@ COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天 def _is_secure_environment() -> bool: """ 检测是否应该启用安全 Cookie(HTTPS) - + Returns: bool: 如果应该使用 secure cookie 则返回 True """ @@ -28,12 +28,12 @@ def _is_secure_environment() -> bool: return True if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"): return False - + # 检查是否是生产环境 env = os.environ.get("WEBUI_MODE", "").lower() if env in ("production", "prod"): return True - + # 默认:开发环境不启用(因为通常是 HTTP) return False @@ -87,7 +87,7 @@ def set_auth_cookie(response: Response, token: str) -> None: """ # 根据环境决定安全设置 is_secure = _is_secure_environment() - + response.set_cookie( key=COOKIE_NAME, value=token, @@ -109,7 +109,7 @@ def clear_auth_cookie(response: Response) -> None: """ # 保持与 set_auth_cookie 相同的安全设置 is_secure = _is_secure_environment() - + response.delete_cookie( key=COOKIE_NAME, httponly=True, diff --git a/src/webui/chat_routes.py b/src/webui/chat_routes.py index 4c889915..6535b9e9 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/chat_routes.py @@ -31,6 +31,7 @@ def require_auth( """认证依赖:验证用户是否已登录""" return verify_auth_token_from_cookie_or_header(maibot_session, authorization) + # WebUI 聊天的虚拟群组 ID WEBUI_CHAT_GROUP_ID = "webui_local_chat" WEBUI_CHAT_PLATFORM = "webui" @@ -399,21 +400,21 @@ async def websocket_chat( token: 认证 token(可选,也可从 Cookie 获取) 虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置 - + 支持三种认证方式(按优先级): 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) 2. Cookie 中的 maibot_session 3. 直接使用 session token(兼容) - + 示例:ws://host/api/chat/ws?token=xxx """ is_authenticated = False - + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) if token and verify_ws_token(token): is_authenticated = True logger.debug("聊天 WebSocket 使用临时 token 认证成功") - + # 方式 2: 尝试从 Cookie 获取 session token if not is_authenticated: cookie_token = websocket.cookies.get("maibot_session") @@ -422,19 +423,19 @@ async def websocket_chat( if token_manager.verify_token(cookie_token): is_authenticated = True logger.debug("聊天 WebSocket 使用 Cookie 认证成功") - + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) if not is_authenticated and token: token_manager = get_token_manager() if token_manager.verify_token(token): is_authenticated = True logger.debug("聊天 WebSocket 使用 session token 认证成功") - + if not is_authenticated: logger.warning("聊天 WebSocket 连接被拒绝:认证失败") await websocket.close(code=4001, reason="认证失败,请重新登录") return - + # 生成会话 ID(每次连接都是新的) session_id = str(uuid.uuid4()) diff --git a/src/webui/config_routes.py b/src/webui/config_routes.py index 58557aa7..6a028927 100644 --- a/src/webui/config_routes.py +++ b/src/webui/config_routes.py @@ -346,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen @router.post("/model/section/{section_name}") -async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)): +async def update_model_config_section( + section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth) +): """更新模型配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -383,8 +385,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo provider_names = {p.get("name") for p in section_data if isinstance(p, dict)} models = config_data.get("models", []) orphaned_models = [ - m.get("name") for m in models - if isinstance(m, dict) and m.get("api_provider") not in provider_names + m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names ] if orphaned_models: error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。" diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index 382a09a2..5ae92189 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -83,16 +83,16 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) 2. Cookie 中的 maibot_session 3. 直接使用 session token(兼容) - + 示例:ws://host/ws/logs?token=xxx """ is_authenticated = False - + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) if token and verify_ws_token(token): is_authenticated = True logger.debug("WebSocket 使用临时 token 认证成功") - + # 方式 2: 尝试从 Cookie 获取 session token if not is_authenticated: cookie_token = websocket.cookies.get("maibot_session") @@ -101,19 +101,19 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None if token_manager.verify_token(cookie_token): is_authenticated = True logger.debug("WebSocket 使用 Cookie 认证成功") - + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) if not is_authenticated and token: token_manager = get_token_manager() if token_manager.verify_token(token): is_authenticated = True logger.debug("WebSocket 使用 session token 认证成功") - + if not is_authenticated: logger.warning("WebSocket 连接被拒绝:认证失败") await websocket.close(code=4001, reason="认证失败,请重新登录") return - + await websocket.accept() active_connections.add(websocket) logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") diff --git a/src/webui/plugin_progress_ws.py b/src/webui/plugin_progress_ws.py index 2f576b6a..8d0a18c6 100644 --- a/src/webui/plugin_progress_ws.py +++ b/src/webui/plugin_progress_ws.py @@ -99,16 +99,16 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) 2. Cookie 中的 maibot_session 3. 直接使用 session token(兼容) - + 示例:ws://host/ws/plugin-progress?token=xxx """ is_authenticated = False - + # 方式 1: 尝试验证临时 WebSocket token(推荐方式) if token and verify_ws_token(token): is_authenticated = True logger.debug("插件进度 WebSocket 使用临时 token 认证成功") - + # 方式 2: 尝试从 Cookie 获取 session token if not is_authenticated: cookie_token = websocket.cookies.get("maibot_session") @@ -117,19 +117,19 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = if token_manager.verify_token(cookie_token): is_authenticated = True logger.debug("插件进度 WebSocket 使用 Cookie 认证成功") - + # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) if not is_authenticated and token: token_manager = get_token_manager() if token_manager.verify_token(token): is_authenticated = True logger.debug("插件进度 WebSocket 使用 session token 认证成功") - + if not is_authenticated: logger.warning("插件进度 WebSocket 连接被拒绝:认证失败") await websocket.close(code=4001, reason="认证失败,请重新登录") return - + await websocket.accept() active_connections.add(websocket) logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index fd5f7462..1d11a20c 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -3,7 +3,6 @@ from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any, get_origin from pathlib import Path import json -import re from src.common.logger import get_logger from src.common.toml_utils import save_toml_with_format from src.config.config import MMC_VERSION @@ -38,54 +37,54 @@ def get_token_from_cookie_or_header( def validate_safe_path(user_path: str, base_path: Path) -> Path: """ 验证用户提供的路径是否安全,防止路径遍历攻击 - + Args: user_path: 用户输入的路径(相对路径) base_path: 允许的基础目录 - + Returns: 安全的绝对路径 - + Raises: HTTPException: 如果检测到路径遍历攻击 """ # 规范化基础路径 base_resolved = base_path.resolve() - + # 检查用户路径是否包含可疑字符 # 禁止: .., 绝对路径开头, 空字节等 if any(pattern in user_path for pattern in ["..", "\x00"]): logger.warning(f"检测到可疑路径: {user_path}") raise HTTPException(status_code=400, detail="路径包含非法字符") - + # 检查是否为绝对路径(Windows 和 Unix) if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"): logger.warning(f"检测到绝对路径: {user_path}") raise HTTPException(status_code=400, detail="不允许使用绝对路径") - + # 构建目标路径并解析 target_path = (base_path / user_path).resolve() - + # 验证解析后的路径仍在基础目录内 try: target_path.relative_to(base_resolved) except ValueError as e: logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}") raise HTTPException(status_code=400, detail="路径超出允许范围") from e - + return target_path def validate_plugin_id(plugin_id: str) -> str: """ 验证插件 ID 格式是否安全 - + Args: plugin_id: 插件 ID (支持 author.name 格式,允许中文) - + Returns: 验证通过的插件 ID - + Raises: HTTPException: 如果插件 ID 格式不安全 """ @@ -93,24 +92,24 @@ def validate_plugin_id(plugin_id: str) -> str: if not plugin_id or not plugin_id.strip(): logger.warning("非法插件 ID: 空字符串") raise HTTPException(status_code=400, detail="插件 ID 不能为空") - + # 禁止危险字符: 路径分隔符、空字节、控制字符等 dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"] for pattern in dangerous_patterns: if pattern in plugin_id: logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)") raise HTTPException(status_code=400, detail="插件 ID 包含非法字符") - + # 禁止以点开头或结尾(防止隐藏文件和路径问题) if plugin_id.startswith(".") or plugin_id.endswith("."): logger.warning(f"非法插件 ID: {plugin_id}") raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾") - + # 禁止特殊名称 if plugin_id in (".", ".."): logger.warning(f"非法插件 ID: {plugin_id}") raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名") - + return plugin_id @@ -556,10 +555,7 @@ async def fetch_raw_file( if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - logger.info( - f"收到获取 Raw 文件请求: " - f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}" - ) + logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}") # 发送开始加载进度 await update_progress( @@ -688,7 +684,7 @@ async def install_plugin( try: # 验证插件 ID 格式安全性 plugin_id = validate_plugin_id(request.plugin_id) - + # 推送进度:开始安装 await update_progress( stage="loading", @@ -899,7 +895,7 @@ async def uninstall_plugin( try: # 验证插件 ID 格式安全性 plugin_id = validate_plugin_id(request.plugin_id) - + # 推送进度:开始卸载 await update_progress( stage="loading", @@ -1041,7 +1037,7 @@ async def update_plugin( try: # 验证插件 ID 格式安全性 plugin_id = validate_plugin_id(request.plugin_id) - + # 推送进度:开始更新 await update_progress( stage="loading", @@ -1494,7 +1490,7 @@ async def get_plugin_config_schema( ui_type = "text" item_type = None item_fields = None - + if isinstance(field_value, bool): ui_type = "switch" elif isinstance(field_value, (int, float)): diff --git a/src/webui/rate_limiter.py b/src/webui/rate_limiter.py index 675e1c02..23cfc0f0 100644 --- a/src/webui/rate_limiter.py +++ b/src/webui/rate_limiter.py @@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter") class RateLimiter: """ 简单的内存请求频率限制器 - + 使用滑动窗口算法实现 """ - + def __init__(self): # 存储格式: {key: [(timestamp, count), ...]} self._requests: Dict[str, list] = defaultdict(list) # 被封禁的 IP: {ip: unblock_timestamp} self._blocked: Dict[str, float] = {} - + def _get_client_ip(self, request: Request) -> str: """获取客户端 IP 地址""" # 检查代理头 @@ -32,26 +32,23 @@ class RateLimiter: if forwarded: # 取第一个 IP(最原始的客户端) return forwarded.split(",")[0].strip() - + real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip - + # 直接连接的客户端 if request.client: return request.client.host - + return "unknown" - + def _cleanup_old_requests(self, key: str, window_seconds: int): """清理过期的请求记录""" now = time.time() cutoff = now - window_seconds - self._requests[key] = [ - (ts, count) for ts, count in self._requests[key] - if ts > cutoff - ] - + self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff] + def _cleanup_expired_blocks(self): """清理过期的封禁""" now = time.time() @@ -59,65 +56,61 @@ class RateLimiter: for ip in expired: del self._blocked[ip] logger.info(f"🔓 IP {ip} 封禁已解除") - + def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]: """ 检查 IP 是否被封禁 - + Returns: (是否被封禁, 剩余封禁秒数) """ self._cleanup_expired_blocks() ip = self._get_client_ip(request) - + if ip in self._blocked: remaining = int(self._blocked[ip] - time.time()) return True, max(0, remaining) - + return False, None - + def check_rate_limit( - self, - request: Request, - max_requests: int, - window_seconds: int, - key_suffix: str = "" + self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = "" ) -> Tuple[bool, int]: """ 检查请求是否超过频率限制 - + Args: request: FastAPI Request 对象 max_requests: 窗口期内允许的最大请求数 window_seconds: 窗口时间(秒) key_suffix: 键后缀,用于区分不同的限制规则 - + Returns: (是否允许, 剩余请求数) """ ip = self._get_client_ip(request) key = f"{ip}:{key_suffix}" if key_suffix else ip - + # 清理过期记录 self._cleanup_old_requests(key, window_seconds) - + # 计算当前窗口内的请求数 current_count = sum(count for _, count in self._requests[key]) - + if current_count >= max_requests: return False, 0 - + # 记录新请求 now = time.time() self._requests[key].append((now, 1)) - + remaining = max_requests - current_count - 1 return True, remaining - + def block_ip(self, request: Request, duration_seconds: int): """ 封禁 IP - + Args: request: FastAPI Request 对象 duration_seconds: 封禁时长(秒) @@ -125,55 +118,51 @@ class RateLimiter: ip = self._get_client_ip(request) self._blocked[ip] = time.time() + duration_seconds logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒") - + def record_failed_attempt( - self, - request: Request, - max_failures: int = 5, - window_seconds: int = 300, - block_duration: int = 600 + self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600 ) -> Tuple[bool, int]: """ 记录失败尝试(如登录失败) - + 如果在窗口期内失败次数过多,自动封禁 IP - + Args: request: FastAPI Request 对象 max_failures: 允许的最大失败次数 window_seconds: 统计窗口(秒) block_duration: 封禁时长(秒) - + Returns: (是否被封禁, 剩余尝试次数) """ ip = self._get_client_ip(request) key = f"{ip}:auth_failures" - + # 清理过期记录 self._cleanup_old_requests(key, window_seconds) - + # 计算当前失败次数 current_failures = sum(count for _, count in self._requests[key]) - + # 记录本次失败 now = time.time() self._requests[key].append((now, 1)) current_failures += 1 - + remaining = max_failures - current_failures - + # 检查是否需要封禁 if current_failures >= max_failures: self.block_ip(request, block_duration) logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁") return True, 0 - + if current_failures >= max_failures - 2: logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次") - + return False, max(0, remaining) - + def reset_failures(self, request: Request): """ 重置失败计数(认证成功后调用) @@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter: async def check_auth_rate_limit(request: Request): """ 认证接口的频率限制依赖 - + 规则: - 每个 IP 每分钟最多 10 次认证请求 - 连续失败 5 次后封禁 10 分钟 """ limiter = get_rate_limiter() - + # 检查是否被封禁 blocked, remaining_block = limiter.is_blocked(request) if blocked: raise HTTPException( status_code=429, detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", - headers={"Retry-After": str(remaining_block)} + headers={"Retry-After": str(remaining_block)}, ) - + # 检查频率限制 allowed, remaining = limiter.check_rate_limit( - request, + request, max_requests=10, # 每分钟 10 次 window_seconds=60, - key_suffix="auth" + key_suffix="auth", ) - + if not allowed: - raise HTTPException( - status_code=429, - detail="认证请求过于频繁,请稍后重试", - headers={"Retry-After": "60"} - ) + raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"}) async def check_api_rate_limit(request: Request): """ 普通 API 的频率限制依赖 - + 规则:每个 IP 每分钟最多 100 次请求 """ limiter = get_rate_limiter() - + # 检查是否被封禁 blocked, remaining_block = limiter.is_blocked(request) if blocked: raise HTTPException( status_code=429, detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", - headers={"Retry-After": str(remaining_block)} + headers={"Retry-After": str(remaining_block)}, ) - + # 检查频率限制 allowed, _ = limiter.check_rate_limit( - request, + request, max_requests=100, # 每分钟 100 次 window_seconds=60, - key_suffix="api" + key_suffix="api", ) - + if not allowed: - raise HTTPException( - status_code=429, - detail="请求过于频繁,请稍后重试", - headers={"Retry-After": "60"} - ) + raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"}) diff --git a/src/webui/routes.py b/src/webui/routes.py index 558b8852..bb92d6cc 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -112,7 +112,7 @@ async def health_check(): @router.post("/auth/verify", response_model=TokenVerifyResponse) async def verify_token( - request_body: TokenVerifyRequest, + request_body: TokenVerifyRequest, request: Request, response: Response, _rate_limit: None = Depends(check_auth_rate_limit), @@ -131,7 +131,7 @@ async def verify_token( try: token_manager = get_token_manager() rate_limiter = get_rate_limiter() - + is_valid = token_manager.verify_token(request_body.token) if is_valid: @@ -146,21 +146,18 @@ async def verify_token( # 记录失败尝试 blocked, remaining = rate_limiter.record_failed_attempt( request, - max_failures=5, # 5 次失败 + max_failures=5, # 5 次失败 window_seconds=300, # 5 分钟窗口 - block_duration=600 # 封禁 10 分钟 + block_duration=600, # 封禁 10 分钟 ) - + if blocked: - raise HTTPException( - status_code=429, - detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟" - ) - + raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟") + message = "Token 无效或已过期" if remaining <= 2: message += f"(剩余 {remaining} 次尝试机会)" - + return TokenVerifyResponse(valid=False, message=message) except HTTPException: raise diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 8162642f..267787c4 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -34,7 +34,7 @@ class WebUIServer: # 重要:先注册 API 路由,再设置静态文件 self._register_api_routes() self._setup_static_files() - + # 注册robots.txt路由 self._setup_robots_txt() @@ -115,7 +115,7 @@ class WebUIServer: media_type = mimetypes.guess_type(str(file_path))[0] response = FileResponse(file_path, media_type=media_type) # HTML 文件添加防索引头 - if str(file_path).endswith('.html'): + if str(file_path).endswith(".html"): response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" return response @@ -130,23 +130,15 @@ class WebUIServer: """配置防爬虫中间件""" try: from src.webui.anti_crawler import AntiCrawlerMiddleware - + # 从环境变量读取防爬虫模式(false/strict/loose/basic) anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower() - + # 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行 # 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行 - self.app.add_middleware( - AntiCrawlerMiddleware, - mode=anti_crawler_mode - ) - - mode_descriptions = { - "false": "已禁用", - "strict": "严格模式", - "loose": "宽松模式", - "basic": "基础模式" - } + self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode) + + mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"} mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式") logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}") except Exception as e: @@ -156,12 +148,12 @@ class WebUIServer: """设置robots.txt路由""" try: from src.webui.anti_crawler import create_robots_txt_response - + @self.app.get("/robots.txt", include_in_schema=False) async def robots_txt(): """返回robots.txt,禁止所有爬虫""" return create_robots_txt_response() - + logger.debug("✅ robots.txt 路由已注册") except Exception as e: logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True) diff --git a/src/webui/ws_auth.py b/src/webui/ws_auth.py index 7d07b5c2..e6bb00e7 100644 --- a/src/webui/ws_auth.py +++ b/src/webui/ws_auth.py @@ -4,7 +4,7 @@ 临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。 """ -from fastapi import APIRouter, Cookie, Header, HTTPException +from fastapi import APIRouter, Cookie, Header from typing import Optional import secrets import time @@ -30,10 +30,10 @@ def _cleanup_expired_ws_tokens(): def generate_ws_token(session_token: str) -> str: """生成 WebSocket 临时 token - + Args: session_token: 原始的 session token - + Returns: 临时 token 字符串 """ @@ -46,10 +46,10 @@ def generate_ws_token(session_token: str) -> str: def verify_ws_token(temp_token: str) -> bool: """验证并消费 WebSocket 临时 token(一次性使用) - + Args: temp_token: 临时 token - + Returns: 验证是否通过 """ @@ -81,11 +81,11 @@ async def get_ws_token( ): """ 获取 WebSocket 连接用的临时 token - + 此端点验证当前会话的 Cookie 或 Authorization header, 然后返回一个临时 token 用于 WebSocket 握手认证。 临时 token 有效期 60 秒,且只能使用一次。 - + 注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。 """ # 获取当前 session token @@ -94,21 +94,21 @@ async def get_ws_token( session_token = maibot_session elif authorization and authorization.startswith("Bearer "): session_token = authorization.replace("Bearer ", "") - + if not session_token: # 返回 200 但 success=False,避免前端因 401 刷新页面 # 这在登录页面是正常情况,不应该触发错误处理 logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)") return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0} - + # 验证 session token token_manager = get_token_manager() if not token_manager.verify_token(session_token): # 同样返回 200 但 success=False,避免前端刷新 logger.debug("ws-token 请求:认证已过期") return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0} - + # 生成临时 WebSocket token ws_token = generate_ws_token(session_token) - + return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}