mirror of https://github.com/Mai-with-u/MaiBot.git
Ruff Format
parent
ee4cb3dc67
commit
74a2f4346a
|
|
@ -126,6 +126,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
|||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
|
||||
|
||||
|
||||
# IP白名单配置(从环境变量读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
|
|
@ -135,10 +136,10 @@ ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
|
|||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
|
||||
Args:
|
||||
ip_string: 逗号分隔的IP字符串
|
||||
|
||||
|
||||
Returns:
|
||||
IP白名单列表,每个元素可能是:
|
||||
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||
|
|
@ -148,12 +149,12 @@ def _parse_allowed_ips(ip_string: str) -> list:
|
|||
allowed = []
|
||||
if not ip_string:
|
||||
return allowed
|
||||
|
||||
|
||||
for ip_entry in ip_string.split(","):
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
|
||||
# 检查通配符格式(包含*)
|
||||
if "*" in ip_entry:
|
||||
# 处理通配符
|
||||
|
|
@ -163,7 +164,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
|
|||
else:
|
||||
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 尝试解析为CIDR格式(包含/)
|
||||
if "/" in ip_entry:
|
||||
|
|
@ -173,39 +174,39 @@ def _parse_allowed_ips(ip_string: str) -> list:
|
|||
allowed.append(ipaddress.ip_address(ip_entry))
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||
|
||||
|
||||
return allowed
|
||||
|
||||
|
||||
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
"""
|
||||
将通配符IP模式转换为正则表达式
|
||||
|
||||
|
||||
支持的格式:
|
||||
- 192.168.*.* 或 192.168.*
|
||||
- 10.*.*.* 或 10.*
|
||||
- *.*.*.* 或 *
|
||||
|
||||
|
||||
Args:
|
||||
wildcard_pattern: 通配符模式字符串
|
||||
|
||||
|
||||
Returns:
|
||||
正则表达式字符串,如果格式无效则返回None
|
||||
"""
|
||||
# 去除空格
|
||||
pattern = wildcard_pattern.strip()
|
||||
|
||||
|
||||
# 处理单个*(匹配所有)
|
||||
if pattern == "*":
|
||||
return r".*"
|
||||
|
||||
|
||||
# 处理IPv4通配符格式
|
||||
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||
parts = pattern.split(".")
|
||||
|
||||
|
||||
if len(parts) > 4:
|
||||
return None # IPv4最多4段
|
||||
|
||||
|
||||
# 构建正则表达式
|
||||
regex_parts = []
|
||||
for part in parts:
|
||||
|
|
@ -221,15 +222,16 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
|||
return None # 无效的数字
|
||||
else:
|
||||
return None # 无效的格式
|
||||
|
||||
|
||||
# 如果部分少于4段,补充.*
|
||||
while len(regex_parts) < 4:
|
||||
regex_parts.append(r"\d+")
|
||||
|
||||
|
||||
# 组合成正则表达式
|
||||
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||
return regex
|
||||
|
||||
|
||||
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
|
||||
|
||||
# 信任的代理IP配置(从环境变量读取,逗号分隔)
|
||||
|
|
@ -250,7 +252,7 @@ def _get_mode_config(mode: str) -> dict:
|
|||
配置字典,包含所有相关参数
|
||||
"""
|
||||
mode = mode.lower()
|
||||
|
||||
|
||||
if mode == "false":
|
||||
return {
|
||||
"enabled": False,
|
||||
|
|
@ -320,7 +322,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
self.check_asset_scanner = config["check_asset_scanner"]
|
||||
self.check_rate_limit = config["check_rate_limit"]
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
|
|
@ -354,7 +356,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
return False
|
||||
|
||||
|
||||
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||
"""
|
||||
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||
|
|
@ -499,7 +500,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
empty_ips = []
|
||||
# 找到最久未访问的IP(最旧时间戳)
|
||||
oldest_ip = None
|
||||
oldest_time = float('inf')
|
||||
oldest_time = float("inf")
|
||||
|
||||
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||
for ip, times in self.request_times.items():
|
||||
|
|
@ -532,7 +533,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
"""
|
||||
if not TRUSTED_PROXIES or ip == "unknown":
|
||||
return False
|
||||
|
||||
|
||||
# 检查代理列表中的每个条目
|
||||
for trusted_entry in TRUSTED_PROXIES:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
|
|
@ -558,7 +559,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
|
|
@ -635,7 +636,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
"""
|
||||
if not ALLOWED_IPS or ip == "unknown":
|
||||
return False
|
||||
|
||||
|
||||
# 检查白名单中的每个条目
|
||||
for allowed_entry in ALLOWED_IPS:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
|
|
@ -664,7 +665,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
|
|
@ -689,16 +690,31 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
# 允许访问静态资源(CSS、JS、图片等)
|
||||
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||
static_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"}
|
||||
static_extensions = {
|
||||
".css",
|
||||
".js",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
".woff",
|
||||
".woff2",
|
||||
".ttf",
|
||||
".eot",
|
||||
}
|
||||
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||
|
||||
|
||||
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||
path = request.url.path
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||
path.endswith(ext) for ext in static_extensions
|
||||
)
|
||||
|
||||
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
|
||||
if is_static_path or is_root_static:
|
||||
return await call_next(request)
|
||||
|
||||
|
|
@ -729,9 +745,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
# 检测爬虫 User-Agent
|
||||
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||
logger.warning(
|
||||
f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
|
|
@ -741,9 +755,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
# 检查请求频率限制
|
||||
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||
logger.warning(
|
||||
f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
return PlainTextResponse(
|
||||
"Too Many Requests: Rate limit exceeded",
|
||||
status_code=429,
|
||||
|
|
@ -770,4 +782,3 @@ Disallow: /
|
|||
media_type="text/plain",
|
||||
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
|||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
|
|
@ -28,12 +28,12 @@ def _is_secure_environment() -> bool:
|
|||
return True
|
||||
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否是生产环境
|
||||
env = os.environ.get("WEBUI_MODE", "").lower()
|
||||
if env in ("production", "prod"):
|
||||
return True
|
||||
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
return False
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
|||
"""
|
||||
# 根据环境决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
|
|
@ -109,7 +109,7 @@ def clear_auth_cookie(response: Response) -> None:
|
|||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ def require_auth(
|
|||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
|
@ -399,21 +400,21 @@ async def websocket_chat(
|
|||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
|
|
@ -422,19 +423,19 @@ async def websocket_chat(
|
|||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
|
|
|
|||
|
|
@ -346,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen
|
|||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
|
|
@ -383,8 +385,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
|||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models
|
||||
if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
|
|
|
|||
|
|
@ -83,16 +83,16 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
|||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
|
|
@ -101,19 +101,19 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
|||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
|
|
|||
|
|
@ -99,16 +99,16 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
|||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
|
|
@ -117,19 +117,19 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
|||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from pydantic import BaseModel, Field
|
|||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
|
|
@ -38,54 +37,54 @@ def get_token_from_cookie_or_header(
|
|||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
"""
|
||||
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||
|
||||
|
||||
Args:
|
||||
user_path: 用户输入的路径(相对路径)
|
||||
base_path: 允许的基础目录
|
||||
|
||||
|
||||
Returns:
|
||||
安全的绝对路径
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果检测到路径遍历攻击
|
||||
"""
|
||||
# 规范化基础路径
|
||||
base_resolved = base_path.resolve()
|
||||
|
||||
|
||||
# 检查用户路径是否包含可疑字符
|
||||
# 禁止: .., 绝对路径开头, 空字节等
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
|
||||
# 检查是否为绝对路径(Windows 和 Unix)
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
|
||||
# 构建目标路径并解析
|
||||
target_path = (base_path / user_path).resolve()
|
||||
|
||||
|
||||
# 验证解析后的路径仍在基础目录内
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
"""
|
||||
验证插件 ID 格式是否安全
|
||||
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||
|
||||
|
||||
Returns:
|
||||
验证通过的插件 ID
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 如果插件 ID 格式不安全
|
||||
"""
|
||||
|
|
@ -93,24 +92,24 @@ def validate_plugin_id(plugin_id: str) -> str:
|
|||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
|
||||
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
|
||||
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
|
||||
# 禁止特殊名称
|
||||
if plugin_id in (".", ".."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
|
|
@ -556,10 +555,7 @@ async def fetch_raw_file(
|
|||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(
|
||||
f"收到获取 Raw 文件请求: "
|
||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
||||
)
|
||||
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
|
||||
|
||||
# 发送开始加载进度
|
||||
await update_progress(
|
||||
|
|
@ -688,7 +684,7 @@ async def install_plugin(
|
|||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
|
||||
# 推送进度:开始安装
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
|
|
@ -899,7 +895,7 @@ async def uninstall_plugin(
|
|||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
|
||||
# 推送进度:开始卸载
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
|
|
@ -1041,7 +1037,7 @@ async def update_plugin(
|
|||
try:
|
||||
# 验证插件 ID 格式安全性
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
|
||||
|
||||
# 推送进度:开始更新
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
|
|
@ -1494,7 +1490,7 @@ async def get_plugin_config_schema(
|
|||
ui_type = "text"
|
||||
item_type = None
|
||||
item_fields = None
|
||||
|
||||
|
||||
if isinstance(field_value, bool):
|
||||
ui_type = "switch"
|
||||
elif isinstance(field_value, (int, float)):
|
||||
|
|
|
|||
|
|
@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter")
|
|||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
|
|
@ -32,26 +32,23 @@ class RateLimiter:
|
|||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [
|
||||
(ts, count) for ts, count in self._requests[key]
|
||||
if ts > cutoff
|
||||
]
|
||||
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
|
|
@ -59,65 +56,61 @@ class RateLimiter:
|
|||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
def check_rate_limit(
|
||||
self,
|
||||
request: Request,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
key_suffix: str = ""
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
|
|
@ -125,55 +118,51 @@ class RateLimiter:
|
|||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
|
||||
def record_failed_attempt(
|
||||
self,
|
||||
request: Request,
|
||||
max_failures: int = 5,
|
||||
window_seconds: int = 300,
|
||||
block_duration: int = 600
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
|
|
@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter:
|
|||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth"
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="认证请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)}
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api"
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="请求过于频繁,请稍后重试",
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ async def health_check():
|
|||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(
|
||||
request_body: TokenVerifyRequest,
|
||||
request_body: TokenVerifyRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
|
|
@ -131,7 +131,7 @@ async def verify_token(
|
|||
try:
|
||||
token_manager = get_token_manager()
|
||||
rate_limiter = get_rate_limiter()
|
||||
|
||||
|
||||
is_valid = token_manager.verify_token(request_body.token)
|
||||
|
||||
if is_valid:
|
||||
|
|
@ -146,21 +146,18 @@ async def verify_token(
|
|||
# 记录失败尝试
|
||||
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||
request,
|
||||
max_failures=5, # 5 次失败
|
||||
max_failures=5, # 5 次失败
|
||||
window_seconds=300, # 5 分钟窗口
|
||||
block_duration=600 # 封禁 10 分钟
|
||||
block_duration=600, # 封禁 10 分钟
|
||||
)
|
||||
|
||||
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
|
||||
|
||||
message = "Token 无效或已过期"
|
||||
if remaining <= 2:
|
||||
message += f"(剩余 {remaining} 次尝试机会)"
|
||||
|
||||
|
||||
return TokenVerifyResponse(valid=False, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class WebUIServer:
|
|||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
|
||||
# 注册robots.txt路由
|
||||
self._setup_robots_txt()
|
||||
|
||||
|
|
@ -115,7 +115,7 @@ class WebUIServer:
|
|||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
response = FileResponse(file_path, media_type=media_type)
|
||||
# HTML 文件添加防索引头
|
||||
if str(file_path).endswith('.html'):
|
||||
if str(file_path).endswith(".html"):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
|
|
@ -130,23 +130,15 @@ class WebUIServer:
|
|||
"""配置防爬虫中间件"""
|
||||
try:
|
||||
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||
|
||||
|
||||
# 从环境变量读取防爬虫模式(false/strict/loose/basic)
|
||||
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
|
||||
|
||||
|
||||
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||
self.app.add_middleware(
|
||||
AntiCrawlerMiddleware,
|
||||
mode=anti_crawler_mode
|
||||
)
|
||||
|
||||
mode_descriptions = {
|
||||
"false": "已禁用",
|
||||
"strict": "严格模式",
|
||||
"loose": "宽松模式",
|
||||
"basic": "基础模式"
|
||||
}
|
||||
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||
|
||||
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
|
||||
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||
except Exception as e:
|
||||
|
|
@ -156,12 +148,12 @@ class WebUIServer:
|
|||
"""设置robots.txt路由"""
|
||||
try:
|
||||
from src.webui.anti_crawler import create_robots_txt_response
|
||||
|
||||
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回robots.txt,禁止所有爬虫"""
|
||||
return create_robots_txt_response()
|
||||
|
||||
|
||||
logger.debug("✅ robots.txt 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header, HTTPException
|
||||
from fastapi import APIRouter, Cookie, Header
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
|
|
@ -30,10 +30,10 @@ def _cleanup_expired_ws_tokens():
|
|||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
|
|
@ -46,10 +46,10 @@ def generate_ws_token(session_token: str) -> str:
|
|||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
|
|
@ -81,11 +81,11 @@ async def get_ws_token(
|
|||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
|
||||
|
||||
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
|
|
@ -94,21 +94,21 @@ async def get_ws_token(
|
|||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not session_token:
|
||||
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||
# 这在登录页面是正常情况,不应该触发错误处理
|
||||
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
|
||||
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
# 同样返回 200 但 success=False,避免前端刷新
|
||||
logger.debug("ws-token 请求:认证已过期")
|
||||
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
|
|
|
|||
Loading…
Reference in New Issue