Ruff Format

pull/1443/head
墨梓柒 2025-12-14 21:39:09 +08:00
parent ee4cb3dc67
commit 74a2f4346a
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
11 changed files with 176 additions and 197 deletions

View File

@ -126,6 +126,7 @@ SCANNER_SPECIFIC_HEADERS = {
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# IP白名单配置从环境变量读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
@ -135,10 +136,10 @@ ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
def _parse_allowed_ips(ip_string: str) -> list:
"""
解析IP白名单字符串支持精确IPCIDR格式和通配符
Args:
ip_string: 逗号分隔的IP字符串
Returns:
IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式
@ -148,12 +149,12 @@ def _parse_allowed_ips(ip_string: str) -> list:
allowed = []
if not ip_string:
return allowed
for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 检查通配符格式(包含*
if "*" in ip_entry:
# 处理通配符
@ -163,7 +164,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue
try:
# 尝试解析为CIDR格式包含/
if "/" in ip_entry:
@ -173,39 +174,39 @@ def _parse_allowed_ips(ip_string: str) -> list:
allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
"""
将通配符IP模式转换为正则表达式
支持的格式
- 192.168.*.* 192.168.*
- 10.*.*.* 10.*
- *.*.*.* *
Args:
wildcard_pattern: 通配符模式字符串
Returns:
正则表达式字符串如果格式无效则返回None
"""
# 去除空格
pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有)
if pattern == "*":
return r".*"
# 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".")
if len(parts) > 4:
return None # IPv4最多4段
# 构建正则表达式
regex_parts = []
for part in parts:
@ -221,15 +222,16 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
return None # 无效的数字
else:
return None # 无效的格式
# If fewer than 4 segments were given, pad the remaining ones with \d+ (one numeric segment each), not .*
while len(regex_parts) < 4:
regex_parts.append(r"\d+")
# 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
# 信任的代理IP配置从环境变量读取逗号分隔
@ -250,7 +252,7 @@ def _get_mode_config(mode: str) -> dict:
配置字典包含所有相关参数
"""
mode = mode.lower()
if mode == "false":
return {
"enabled": False,
@ -320,7 +322,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
self.check_asset_scanner = config["check_asset_scanner"]
self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {}
# 上次清理时间
@ -354,7 +356,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return False
def _is_asset_scanner_header(self, request: Request) -> bool:
"""
检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配
@ -499,7 +500,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
empty_ips = []
# 找到最久未访问的IP最旧时间戳
oldest_ip = None
oldest_time = float('inf')
oldest_time = float("inf")
# 全量遍历找真正的oldest超限时性能可接受
for ip, times in self.request_times.items():
@ -532,7 +533,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""
if not TRUSTED_PROXIES or ip == "unknown":
return False
# 检查代理列表中的每个条目
for trusted_entry in TRUSTED_PROXIES:
# 通配符模式(字符串,正则表达式)
@ -558,7 +559,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return True
except (ValueError, AttributeError):
continue
return False
def _get_client_ip(self, request: Request) -> str:
@ -635,7 +636,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""
if not ALLOWED_IPS or ip == "unknown":
return False
# 检查白名单中的每个条目
for allowed_entry in ALLOWED_IPS:
# 通配符模式(字符串,正则表达式)
@ -664,7 +665,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
except (ValueError, AttributeError):
# IP格式无效跳过
continue
return False
async def dispatch(self, request: Request, call_next):
@ -689,16 +690,31 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 允许访问静态资源CSS、JS、图片等
# 注意:.json 已移除,避免 API 路径绕过防护
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/
static_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"}
static_extensions = {
".css",
".js",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
".woff",
".woff2",
".ttf",
".eot",
}
static_prefixes = {"/static/", "/assets/", "/dist/"}
# 检查是否是静态资源路径(特定前缀下的静态文件)
path = request.url.path
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(path.endswith(ext) for ext in static_extensions)
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
path.endswith(ext) for ext in static_extensions
)
# 也允许根路径下的静态文件(如 /favicon.ico
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
if is_static_path or is_root_static:
return await call_next(request)
@ -729,9 +745,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 检测爬虫 User-Agent
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
logger.warning(
f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
)
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
@ -741,9 +755,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 检查请求频率限制
if self.check_rate_limit and self._check_rate_limit(client_ip):
logger.warning(
f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
)
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
return PlainTextResponse(
"Too Many Requests: Rate limit exceeded",
status_code=429,
@ -770,4 +782,3 @@ Disallow: /
media_type="text/plain",
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
)

View File

@ -19,7 +19,7 @@ COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
@ -28,12 +28,12 @@ def _is_secure_environment() -> bool:
return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
return False
# 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"):
return True
# 默认:开发环境不启用(因为通常是 HTTP
return False
@ -87,7 +87,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
"""
# 根据环境决定安全设置
is_secure = _is_secure_environment()
response.set_cookie(
key=COOKIE_NAME,
value=token,
@ -109,7 +109,7 @@ def clear_auth_cookie(response: Response) -> None:
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,

View File

@ -31,6 +31,7 @@ def require_auth(
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
@ -399,21 +400,21 @@ async def websocket_chat(
token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
@ -422,19 +423,19 @@ async def websocket_chat(
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())

View File

@ -346,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@ -383,8 +385,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models
if isinstance(m, dict) and m.get("api_provider") not in provider_names
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"

View File

@ -83,16 +83,16 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
@ -101,19 +101,19 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")

View File

@ -99,16 +99,16 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
@ -117,19 +117,19 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")

View File

@ -3,7 +3,6 @@ from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path
import json
import re
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
@ -38,54 +37,54 @@ def get_token_from_cookie_or_header(
def validate_safe_path(user_path: str, base_path: Path) -> Path:
    """Resolve ``user_path`` beneath ``base_path`` and reject path traversal.

    Args:
        user_path: Path supplied by the user; must be relative.
        base_path: Directory the result has to stay inside.

    Returns:
        The resolved absolute path of ``base_path / user_path``.

    Raises:
        HTTPException: With status 400 when ``user_path`` contains ``..`` or a
            NUL byte, looks like an absolute path (leading ``/`` or ``\\``, or
            a ``:`` in second position as in a Windows drive prefix), or
            resolves to a location outside ``base_path``.
    """
    # Resolve the base first so the containment check compares like with like.
    root = base_path.resolve()

    # Reject suspicious substrings before touching the filesystem.
    for forbidden in ("..", "\x00"):
        if forbidden in user_path:
            logger.warning(f"检测到可疑路径: {user_path}")
            raise HTTPException(status_code=400, detail="路径包含非法字符")

    # Reject absolute paths in both Unix and Windows spellings.
    has_drive_prefix = len(user_path) > 1 and user_path[1] == ":"
    if user_path.startswith(("/", "\\")) or has_drive_prefix:
        logger.warning(f"检测到绝对路径: {user_path}")
        raise HTTPException(status_code=400, detail="不允许使用绝对路径")

    # Join and resolve, then confirm the result is still under the base.
    candidate = (base_path / user_path).resolve()
    try:
        candidate.relative_to(root)
    except ValueError as e:
        logger.warning(f"路径遍历攻击检测: {user_path} -> {candidate}")
        raise HTTPException(status_code=400, detail="路径超出允许范围") from e

    return candidate
def validate_plugin_id(plugin_id: str) -> str:
"""
验证插件 ID 格式是否安全
Args:
plugin_id: 插件 ID (支持 author.name 格式允许中文)
Returns:
验证通过的插件 ID
Raises:
HTTPException: 如果插件 ID 格式不安全
"""
@ -93,24 +92,24 @@ def validate_plugin_id(plugin_id: str) -> str:
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
# 禁止危险字符: 路径分隔符、空字节、控制字符等
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
for pattern in dangerous_patterns:
if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
# 禁止特殊名称
if plugin_id in (".", ".."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id
@ -556,10 +555,7 @@ async def fetch_raw_file(
if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
logger.info(
f"收到获取 Raw 文件请求: "
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
)
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
# 发送开始加载进度
await update_progress(
@ -688,7 +684,7 @@ async def install_plugin(
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始安装
await update_progress(
stage="loading",
@ -899,7 +895,7 @@ async def uninstall_plugin(
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始卸载
await update_progress(
stage="loading",
@ -1041,7 +1037,7 @@ async def update_plugin(
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始更新
await update_progress(
stage="loading",
@ -1494,7 +1490,7 @@ async def get_plugin_config_schema(
ui_type = "text"
item_type = None
item_fields = None
if isinstance(field_value, bool):
ui_type = "switch"
elif isinstance(field_value, (int, float)):

View File

@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
@ -32,26 +32,23 @@ class RateLimiter:
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [
(ts, count) for ts, count in self._requests[key]
if ts > cutoff
]
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
@ -59,65 +56,61 @@ class RateLimiter:
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
    """Report whether the requesting client's IP is currently blocked.

    Calls ``self._cleanup_expired_blocks()`` first, then looks up the IP
    returned by ``self._get_client_ip(request)`` in ``self._blocked``.

    Args:
        request: Incoming request used to determine the client IP.

    Returns:
        ``(True, seconds_remaining)`` when the IP is in ``self._blocked``,
        with the remaining time truncated to an int and clamped at 0;
        ``(False, None)`` otherwise.
    """
    self._cleanup_expired_blocks()

    client_ip = self._get_client_ip(request)
    if client_ip not in self._blocked:
        return False, None

    # The stored value is the unblock timestamp; never report a negative wait.
    seconds_left = int(self._blocked[client_ip] - time.time())
    return True, max(0, seconds_left)
def check_rate_limit(
self,
request: Request,
max_requests: int,
window_seconds: int,
key_suffix: str = ""
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间
key_suffix: 键后缀用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长
@ -125,55 +118,51 @@ class RateLimiter:
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self,
request: Request,
max_failures: int = 5,
window_seconds: int = 300,
block_duration: int = 600
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试如登录失败
如果在窗口期内失败次数过多自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口
block_duration: 封禁时长
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数认证成功后调用
@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter:
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth"
key_suffix="auth",
)
if not allowed:
raise HTTPException(
status_code=429,
detail="认证请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api"
key_suffix="api",
)
if not allowed:
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})

View File

@ -112,7 +112,7 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(
request_body: TokenVerifyRequest,
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
@ -131,7 +131,7 @@ async def verify_token(
try:
token_manager = get_token_manager()
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
@ -146,21 +146,18 @@ async def verify_token(
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600 # 封禁 10 分钟
block_duration=600, # 封禁 10 分钟
)
if blocked:
raise HTTPException(
status_code=429,
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
)
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise

View File

@ -34,7 +34,7 @@ class WebUIServer:
# 重要:先注册 API 路由,再设置静态文件
self._register_api_routes()
self._setup_static_files()
# 注册robots.txt路由
self._setup_robots_txt()
@ -115,7 +115,7 @@ class WebUIServer:
media_type = mimetypes.guess_type(str(file_path))[0]
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith('.html'):
if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
@ -130,23 +130,15 @@ class WebUIServer:
"""配置防爬虫中间件"""
try:
from src.webui.anti_crawler import AntiCrawlerMiddleware
# 从环境变量读取防爬虫模式false/strict/loose/basic
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware(
AntiCrawlerMiddleware,
mode=anti_crawler_mode
)
mode_descriptions = {
"false": "已禁用",
"strict": "严格模式",
"loose": "宽松模式",
"basic": "基础模式"
}
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e:
@ -156,12 +148,12 @@ class WebUIServer:
"""设置robots.txt路由"""
try:
from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回robots.txt禁止所有爬虫"""
return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册")
except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)

View File

@ -4,7 +4,7 @@
临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题
"""
from fastapi import APIRouter, Cookie, Header, HTTPException
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
@ -30,10 +30,10 @@ def _cleanup_expired_ws_tokens():
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
@ -46,10 +46,10 @@ def generate_ws_token(session_token: str) -> str:
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
@ -81,11 +81,11 @@ async def get_ws_token(
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证
临时 token 有效期 60 且只能使用一次
注意在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面
"""
# 获取当前 session token
@ -94,21 +94,21 @@ async def get_ws_token(
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}