feat: enhance anti-crawler middleware with trusted proxy support

Refactor the anti-crawler logic to store request timestamps in a deque, improving performance and memory management. Add support for trusted proxies configured via WEBUI_TRUSTED_PROXIES and WEBUI_TRUST_XFF, enabling selective trust of the X-Forwarded-For header. Restrict suspicious-header detection to a specific set of headers to reduce false positives. Update template.env with the new proxy-related environment variables.
pull/1439/head
陈曦 2025-12-14 19:47:07 +08:00
parent 16271718a7
commit 97c872f4f2
2 changed files with 157 additions and 97 deletions
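Reviewer note: the central change in this diff is swapping per-IP lists for deques in the sliding-window rate limiter. A minimal standalone sketch of that technique follows; the class and parameter names are illustrative, not the middleware's actual API:

import time
from collections import deque


class SlidingWindowLimiter:
    """Illustrative sliding-window limiter mirroring the deque approach in this diff."""

    def __init__(self, window_seconds: float, max_requests: int):
        self.window = window_seconds
        self.max_requests = max_requests
        # One bounded deque per client; maxlen caps memory even for abusive IPs.
        self.times: dict[str, deque] = {}

    def allow(self, client_ip: str) -> bool:
        now = time.time()
        bucket = self.times.setdefault(client_ip, deque(maxlen=self.max_requests * 2))
        # Expired timestamps sit at the left end, so each removal is an O(1)
        # popleft() instead of rebuilding the whole list on every request.
        while bucket and now - bucket[0] >= self.window:
            bucket.popleft()
        if len(bucket) >= self.max_requests:
            return False
        bucket.append(now)
        return True


limiter = SlidingWindowLimiter(window_seconds=60, max_requests=100)
print(limiter.allow("203.0.113.7"))  # True until the window fills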


@@ -7,13 +7,11 @@ import os
import time
import ipaddress
import re
from collections import defaultdict
from typing import Optional, Union
from functools import lru_cache
from collections import deque
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse
from fastapi import HTTPException
from starlette.responses import PlainTextResponse
from src.common.logger import get_logger
@@ -110,18 +108,15 @@ ASSET_SCANNER_HEADERS = {
# - "x-real-ip" (standard reverse proxy header, already used in _get_client_ip)
}
# Suspicious HTTP header value patterns (used to detect scanning tools)
SUSPICIOUS_HEADER_PATTERNS = {
"shodan",
"censys",
"zoomeye",
"fofa",
"quake",
"scanner",
"probe",
"scan",
"recon",
"reconnaissance",
# Only check specific HTTP headers for suspicious patterns (tightened match scope)
# Only these specific headers are checked, not all headers
SCANNER_SPECIFIC_HEADERS = {
"x-scan",
"x-scanner",
"x-probe",
"x-originating-ip",
"x-remote-ip",
"x-remote-addr",
}
# Anti-crawler mode configuration
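To see why restricting detection to SCANNER_SPECIFIC_HEADERS reduces false positives, compare it with the old substring match over every header. A small self-contained sketch, with the header and tool sets abridged from this diff and looks_like_scanner as an illustrative name:

SCANNER_SPECIFIC_HEADERS = {"x-scan", "x-scanner", "x-probe",
                            "x-originating-ip", "x-remote-ip", "x-remote-addr"}
SCANNER_TOOLS = {"shodan", "censys", "zoomeye", "fofa", "quake"}


def looks_like_scanner(headers: dict[str, str]) -> bool:
    """Only inspect known scanner headers, never arbitrary header names/values."""
    for name, value in headers.items():
        if name.lower() in SCANNER_SPECIFIC_HEADERS:
            value_lower = (value or "").lower()
            if any(tool in value_lower for tool in SCANNER_TOOLS):
                return True
    return False


# A benign value that the old broad patterns ("scan", "probe", ...) would have
# flagged no longer trips the detector:
print(looks_like_scanner({"x-request-id": "scan-42"}))  # False
print(looks_like_scanner({"x-scanner": "shodan/1.0"}))  # True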
@@ -237,6 +232,12 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
# Trusted proxy IP configuration (read from environment variable, comma-separated)
# The X-Forwarded-For header is only used when the request comes from a trusted proxy
# Disabled by default (empty): no proxies are trusted
TRUSTED_PROXIES = _parse_allowed_ips(os.getenv("WEBUI_TRUSTED_PROXIES", ""))
TRUST_XFF = os.getenv("WEBUI_TRUST_XFF", "false").lower() == "true"
def _get_mode_config(mode: str) -> dict:
"""
@@ -320,14 +321,13 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # whether to block detected malicious access
# Stores request timestamps for each IP
self.request_times: dict[str, list[float]] = defaultdict(list)
# Stores request timestamps for each IP (deque for better performance)
self.request_times: dict[str, deque] = {}
# Time of last cleanup
self.last_cleanup = time.time()
# Convert keyword lists to sets for faster lookups
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
self.suspicious_patterns_set = set(SUSPICIOUS_HEADER_PATTERNS)
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
"""
@@ -354,31 +354,10 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return False
def _is_asset_scanner_user_agent(self, user_agent: Optional[str]) -> bool:
"""
Detect whether the User-Agent belongs to an asset reconnaissance tool
Args:
user_agent: the User-Agent string
Returns:
True if it is an asset reconnaissance tool
"""
if not user_agent:
return False
user_agent_lower = user_agent.lower()
# Check for asset reconnaissance tool keywords
for scanner_keyword in ASSET_SCANNER_USER_AGENTS:
if scanner_keyword in user_agent_lower:
return True
return False
def _is_asset_scanner_header(self, request: Request) -> bool:
"""
Detect HTTP headers characteristic of asset reconnaissance tools
Detect HTTP headers characteristic of asset reconnaissance tools (only specific headers are checked; tightened matching)
Args:
request: the request object
@@ -386,7 +365,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
Returns:
True if an asset reconnaissance tool header is detected
"""
# Check all HTTP headers
# Only check specific scanner headers, not all headers
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
header_value_lower = header_value.lower() if header_value else ""
@@ -404,10 +383,12 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
if header_value_lower:
return True
# Use set lookup for performance (check whether header values contain suspicious patterns)
for pattern in self.suspicious_patterns_set:
if pattern in header_name_lower or pattern in header_value_lower:
return True
# Only check specific headers for suspicious patterns (tightened matching)
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
# Check whether the header value contains a known scanner tool name
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
return True
return False
@@ -441,16 +422,18 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
detected_tool = tool
break
# Check HTTP headers for tool identifiers
# Check HTTP headers for tool identifiers (specific headers only)
if not detected_tool:
for header_name, header_value in request.headers.items():
header_value_lower = (header_value or "").lower()
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
detected_tool = tool
header_name_lower = header_name.lower()
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
header_value_lower = (header_value or "").lower()
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
detected_tool = tool
break
if detected_tool:
break
if detected_tool:
break
return True, detected_tool or "unknown_scanner"
@@ -470,11 +453,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
if self._is_ip_allowed(client_ip):
return False
# Cap the number of tracked IPs (prevents memory leaks)
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
# Clean up the oldest records
self._cleanup_old_requests(time.time())
current_time = time.time()
# Periodically clean up expired request records (once every 5 minutes)
@@ -482,15 +460,20 @@
self._cleanup_old_requests(current_time)
self.last_cleanup = current_time
# Get the request time list for this IP
# Cap the number of tracked IPs (prevents memory leaks)
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
# Clean up the oldest records (evict the least recently seen IP)
self._cleanup_oldest_ips()
# Get or create the request-time deque for this IP
if client_ip not in self.request_times:
self.request_times[client_ip] = deque(maxlen=self.rate_limit_max_requests * 2)
request_times = self.request_times[client_ip]
# Remove request records outside the time window
request_times[:] = [
req_time
for req_time in request_times
if current_time - req_time < self.rate_limit_window
]
# Remove request records outside the time window (pop expired entries from the left)
while request_times and current_time - request_times[0] >= self.rate_limit_window:
request_times.popleft()
# Check whether the limit has been exceeded
if len(request_times) >= self.rate_limit_max_requests:
@@ -501,20 +484,84 @@
return False
def _cleanup_old_requests(self, current_time: float):
"""清理过期的请求记录"""
for ip in list(self.request_times.keys()):
self.request_times[ip] = [
req_time
for req_time in self.request_times[ip]
if current_time - req_time < self.rate_limit_window
]
# If the list is empty, delete this IP's record
if not self.request_times[ip]:
"""清理过期的请求记录只清理当前需要检查的IP不全量遍历"""
# 这个方法现在主要用于定期清理实际清理在_check_rate_limit中按需进行
# 清理最久未访问的IP记录
if len(self.request_times) > self.max_tracked_ips * 0.8:
self._cleanup_oldest_ips()
def _cleanup_oldest_ips(self):
"""清理最久未访问的IP记录避免全量遍历"""
if not self.request_times:
return
# Find the least recently seen IP (empty deque or oldest timestamp)
oldest_ip = None
oldest_time = float('inf')
# Only inspect a subset of the IPs, not the full map
check_count = min(100, len(self.request_times))
checked = 0
for ip, times in self.request_times.items():
if checked >= check_count:
break
checked += 1
if not times:
# Empty deque: delete it first
del self.request_times[ip]
return
if times[0] < oldest_time:
oldest_time = times[0]
oldest_ip = ip
# Delete the least recently seen IP
if oldest_ip:
del self.request_times[oldest_ip]
def _is_trusted_proxy(self, ip: str) -> bool:
"""
Check whether an IP is in the trusted proxy list
Args:
ip: the IP address string
Returns:
True if it is a trusted proxy
"""
if not TRUSTED_PROXIES or ip == "unknown":
return False
# Check each entry in the proxy list
for trusted_entry in TRUSTED_PROXIES:
# Wildcard pattern (string, matched as a regular expression)
if isinstance(trusted_entry, str):
try:
if re.match(trusted_entry, ip):
return True
except re.error:
continue
# CIDR format (network object)
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in trusted_entry:
return True
except (ValueError, AttributeError):
continue
# Exact IP (address object)
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == trusted_entry:
return True
except (ValueError, AttributeError):
continue
return False
def _get_client_ip(self, request: Request) -> str:
"""
Get the real client IP address (with basic validation)
Get the real client IP address (with basic validation and proxy trust checks)
Args:
request: 请求对象
@@ -522,27 +569,38 @@
Returns:
The client IP address
"""
# Prefer X-Forwarded-For (for reverse proxy setups)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For may contain multiple IPs; take the first one
ip = forwarded_for.split(",")[0].strip()
# Basic IP format validation
if self._validate_ip(ip):
return ip
# Fall back to X-Real-IP
real_ip = request.headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if self._validate_ip(ip):
return ip
# Use the client IP
# Get the directly connected client IP (used to validate the proxy)
direct_client_ip = None
if request.client:
ip = request.client.host
if self._validate_ip(ip):
return ip
direct_client_ip = request.client.host
# Decide whether to trust the X-Forwarded-For header
use_xff = TRUST_XFF
if not use_xff and TRUSTED_PROXIES and direct_client_ip:
# If a trusted proxy list is configured, check whether the directly connected IP is on it
use_xff = self._is_trusted_proxy(direct_client_ip)
# If the proxy is trusted, prefer X-Forwarded-For
if use_xff:
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For may contain multiple IPs; take the first one
ip = forwarded_for.split(",")[0].strip()
# Basic IP format validation
if self._validate_ip(ip):
return ip
# Fall back to X-Real-IP (if the proxy is trusted)
if use_xff:
real_ip = request.headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if self._validate_ip(ip):
return ip
# Use the directly connected client IP
if direct_client_ip and self._validate_ip(direct_client_ip):
return direct_client_ip
return "unknown"


@@ -12,4 +12,6 @@ WEBUI_PORT=8001 # WebUI server port
WEBUI_ANTI_CRAWLER_MODE=basic # Anti-crawler mode: false (disabled) / strict / loose / basic (log only, no blocking)
WEBUI_ALLOWED_IPS=127.0.0.1 # IP allowlist, comma-separated; supports exact IPs, CIDR notation, and wildcards
# Example: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
# Note: do not use *.*.*.* or *; that would disable the anti-crawler protection entirely
WEBUI_TRUSTED_PROXIES= # Trusted proxy IP list, comma-separated; X-Forwarded-For is only trusted when sent from these IPs
# Example: 127.0.0.1,192.168.1.1,172.17.0.1
WEBUI_TRUST_XFF=false # Whether to always trust the X-Forwarded-For header; defaults to false, intended to be used together with TRUSTED_PROXIES
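For completeness, a sketch of how entries parsed from WEBUI_TRUSTED_PROXIES are matched, assuming _parse_allowed_ips yields ipaddress objects for exact IPs and CIDR ranges as the diff suggests (the regex branch for wildcard strings is omitted):

import ipaddress

# Hypothetical parsed form of WEBUI_TRUSTED_PROXIES=127.0.0.1,172.17.0.0/16
TRUSTED = [
    ipaddress.ip_address("127.0.0.1"),
    ipaddress.ip_network("172.17.0.0/16"),
]


def is_trusted(ip: str) -> bool:
    """Mirror of the CIDR / exact-IP branches in _is_trusted_proxy."""
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return False
    for entry in TRUSTED:
        if isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            if addr in entry:
                return True
        elif addr == entry:
            return True
    return False


print(is_trusted("172.17.0.5"))   # True: inside the Docker-style /16
print(is_trusted("203.0.113.7"))  # False: not a configured proxy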