mirror of https://github.com/Mai-with-u/MaiBot.git
feat: enhance the anti-crawler middleware with trusted-proxy support
Refactor the anti-crawler logic to store per-IP request timestamps in a deque, improving performance and memory management. Add support for trusted proxies configured through WEBUI_TRUSTED_PROXIES and WEBUI_TRUST_XFF, so the X-Forwarded-For header is honored selectively. Restrict suspicious-header detection to a specific set of headers to reduce false positives. Update template.env with the new proxy-related environment variables.

pull/1439/head
parent 16271718a7
commit 97c872f4f2
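For context, the middleware subclasses Starlette's BaseHTTPMiddleware, so it is mounted on the FastAPI app in the standard way. A minimal sketch, assuming a hypothetical module path and app wiring (neither appears in this diff):

```python
# Minimal mounting sketch; the import path and app object are assumptions,
# not part of this commit.
from fastapi import FastAPI

from anti_crawler import AntiCrawlerMiddleware  # hypothetical module path

app = FastAPI()
# BaseHTTPMiddleware subclasses register like any other Starlette middleware;
# constructor arguments beyond `app` are not shown in this diff.
app.add_middleware(AntiCrawlerMiddleware)
```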
@@ -7,13 +7,11 @@ import os
 import time
 import ipaddress
 import re
-from collections import defaultdict
-from typing import Optional, Union
-from functools import lru_cache
+from collections import deque
+from typing import Optional
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.requests import Request
-from starlette.responses import Response, PlainTextResponse
-from fastapi import HTTPException
+from starlette.responses import PlainTextResponse

 from src.common.logger import get_logger

@@ -110,18 +108,15 @@ ASSET_SCANNER_HEADERS = {
     # - "x-real-ip" (standard reverse-proxy header, already used in _get_client_ip)
 }

-# Suspicious HTTP header value patterns (used to detect scanning tools)
-SUSPICIOUS_HEADER_PATTERNS = {
-    "shodan",
-    "censys",
-    "zoomeye",
-    "fofa",
-    "quake",
-    "scanner",
-    "probe",
-    "scan",
-    "recon",
-    "reconnaissance",
+# Check only specific HTTP headers for suspicious patterns (tightened matching)
+# Only these headers are inspected, not all of them
+SCANNER_SPECIFIC_HEADERS = {
+    "x-scan",
+    "x-scanner",
+    "x-probe",
+    "x-originating-ip",
+    "x-remote-ip",
+    "x-remote-addr",
 }

 # Anti-crawler mode configuration
@@ -237,6 +232,12 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:

 ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))

+# Trusted proxy IPs (read from the environment, comma-separated)
+# X-Forwarded-For is only honored when the request comes from a trusted proxy
+# Defaults to empty: no proxy is trusted
+TRUSTED_PROXIES = _parse_allowed_ips(os.getenv("WEBUI_TRUSTED_PROXIES", ""))
+TRUST_XFF = os.getenv("WEBUI_TRUST_XFF", "false").lower() == "true"
+

 def _get_mode_config(mode: str) -> dict:
     """
@@ -320,14 +321,13 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         self.check_rate_limit = config["check_rate_limit"]
         self.block_on_detect = config["block_on_detect"]  # whether to block detected malicious access

-        # Per-IP request timestamps
-        self.request_times: dict[str, list[float]] = defaultdict(list)
+        # Per-IP request timestamps (deque for better performance)
+        self.request_times: dict[str, deque] = {}
         # Time of the last cleanup
         self.last_cleanup = time.time()
         # Convert the keyword lists to sets for faster lookup
         self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
         self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
-        self.suspicious_patterns_set = set(SUSPICIOUS_HEADER_PATTERNS)

     def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
         """
@@ -354,31 +354,10 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):

         return False

-    def _is_asset_scanner_user_agent(self, user_agent: Optional[str]) -> bool:
-        """
-        Detect an asset-mapping (internet scanning) tool User-Agent
-
-        Args:
-            user_agent: the User-Agent string
-
-        Returns:
-            True if it is an asset-mapping tool
-        """
-        if not user_agent:
-            return False
-
-        user_agent_lower = user_agent.lower()
-
-        # Check for asset-mapping tool keywords
-        for scanner_keyword in ASSET_SCANNER_USER_AGENTS:
-            if scanner_keyword in user_agent_lower:
-                return True
-
-        return False
-
     def _is_asset_scanner_header(self, request: Request) -> bool:
         """
-        Detect asset-mapping tool HTTP headers
+        Detect asset-mapping tool HTTP headers (only specific headers, tightened matching)

         Args:
             request: the request object

@@ -386,7 +365,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         Returns:
             True if asset-mapping tool headers are detected
         """
-        # Check all HTTP headers
+        # Check only the specific scanner headers, not all of them
         for header_name, header_value in request.headers.items():
             header_name_lower = header_name.lower()
             header_value_lower = header_value.lower() if header_value else ""
@@ -404,10 +383,12 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
                 if header_value_lower:
                     return True

-            # Use set lookup for speed (check whether header values contain suspicious patterns)
-            for pattern in self.suspicious_patterns_set:
-                if pattern in header_name_lower or pattern in header_value_lower:
-                    return True
+            # Check suspicious patterns only in specific headers (tightened matching)
+            if header_name_lower in SCANNER_SPECIFIC_HEADERS:
+                # Check whether the header value contains a known scanner name
+                for tool in self.scanner_keywords_set:
+                    if tool in header_value_lower:
+                        return True

         return False

@@ -441,16 +422,18 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
                     detected_tool = tool
                     break

-        # Check HTTP headers for tool signatures
+        # Check HTTP headers for tool signatures (only specific headers)
         if not detected_tool:
             for header_name, header_value in request.headers.items():
-                header_value_lower = (header_value or "").lower()
-                for tool in self.scanner_keywords_set:
-                    if tool in header_value_lower:
-                        detected_tool = tool
-                        break
+                header_name_lower = header_name.lower()
+                if header_name_lower in SCANNER_SPECIFIC_HEADERS:
+                    header_value_lower = (header_value or "").lower()
+                    for tool in self.scanner_keywords_set:
+                        if tool in header_value_lower:
+                            detected_tool = tool
+                            break
+                    if detected_tool:
+                        break
                 if detected_tool:
                     break

         return True, detected_tool or "unknown_scanner"

@@ -470,11 +453,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         if self._is_ip_allowed(client_ip):
            return False

-        # Cap the number of tracked IPs to prevent memory leaks
-        if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
-            # Clean up the oldest records
-            self._cleanup_old_requests(time.time())
-
         current_time = time.time()

         # Periodically clean up expired request records (every 5 minutes)
@@ -482,15 +460,20 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
             self._cleanup_old_requests(current_time)
             self.last_cleanup = current_time

-        # Get this IP's list of request times
+        # Cap the number of tracked IPs to prevent memory leaks
+        if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
+            # Clean up the oldest records (evict the least recently seen IPs)
+            self._cleanup_oldest_ips()
+
+        # Get or create this IP's deque of request times
+        if client_ip not in self.request_times:
+            self.request_times[client_ip] = deque(maxlen=self.rate_limit_max_requests * 2)
+
         request_times = self.request_times[client_ip]

-        # Drop request records outside the time window
-        request_times[:] = [
-            req_time
-            for req_time in request_times
-            if current_time - req_time < self.rate_limit_window
-        ]
+        # Drop request records outside the time window (pop expired entries from the left)
+        while request_times and current_time - request_times[0] >= self.rate_limit_window:
+            request_times.popleft()

         # Check whether the limit is exceeded
         if len(request_times) >= self.rate_limit_max_requests:
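The deque-based window above trades the old per-request O(n) list rebuild for O(1) popleft operations. A standalone sketch of the same sliding-window check, with illustrative names that mirror the diff (the helper itself is not part of the patch):

```python
# Standalone illustration of the sliding-window check introduced above;
# names mirror the diff, but this helper is not part of the patch.
import time
from collections import deque

def allow_request(times: deque, window: float, max_requests: int) -> bool:
    now = time.time()
    # popleft() is O(1), so expiring old entries is cheap compared with
    # rebuilding the whole list on every request
    while times and now - times[0] >= window:
        times.popleft()
    if len(times) >= max_requests:
        return False  # over the limit inside the window
    times.append(now)
    return True

times: deque = deque(maxlen=10 * 2)  # 2x headroom, as in the middleware
for _ in range(12):
    print(allow_request(times, window=60.0, max_requests=10))
# The first 10 calls print True, the rest False until the window slides
```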
@@ -501,20 +484,84 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         return False

     def _cleanup_old_requests(self, current_time: float):
-        """Clean up expired request records"""
-        for ip in list(self.request_times.keys()):
-            self.request_times[ip] = [
-                req_time
-                for req_time in self.request_times[ip]
-                if current_time - req_time < self.rate_limit_window
-            ]
-            # Delete this IP's record if its list is empty
-            if not self.request_times[ip]:
-                del self.request_times[ip]
+        """Clean up expired request records (only the IPs that need checking, no full scan)"""
+        # This method is now mainly for periodic cleanup; the actual pruning
+        # happens on demand in _check_rate_limit
+        # Evict the records of the least recently seen IPs
+        if len(self.request_times) > self.max_tracked_ips * 0.8:
+            self._cleanup_oldest_ips()
+
+    def _cleanup_oldest_ips(self):
+        """Evict the least recently seen IP records (avoids a full scan)"""
+        if not self.request_times:
+            return
+
+        # Find the least recently seen IP (an empty deque or the oldest timestamp)
+        oldest_ip = None
+        oldest_time = float('inf')
+
+        # Inspect only a subset of IPs rather than all of them
+        check_count = min(100, len(self.request_times))
+        checked = 0
+        for ip, times in self.request_times.items():
+            if checked >= check_count:
+                break
+            checked += 1
+            if not times:
+                # Empty deque: delete it first
+                del self.request_times[ip]
+                return
+            if times[0] < oldest_time:
+                oldest_time = times[0]
+                oldest_ip = ip
+
+        # Delete the least recently seen IP
+        if oldest_ip:
+            del self.request_times[oldest_ip]
+
+    def _is_trusted_proxy(self, ip: str) -> bool:
+        """
+        Check whether an IP is on the trusted proxy list
+
+        Args:
+            ip: the IP address string
+
+        Returns:
+            True if it is a trusted proxy
+        """
+        if not TRUSTED_PROXIES or ip == "unknown":
+            return False
+
+        # Check every entry in the proxy list
+        for trusted_entry in TRUSTED_PROXIES:
+            # Wildcard pattern (stored as a regex string)
+            if isinstance(trusted_entry, str):
+                try:
+                    if re.match(trusted_entry, ip):
+                        return True
+                except re.error:
+                    continue
+            # CIDR format (network object)
+            elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
+                try:
+                    client_ip_obj = ipaddress.ip_address(ip)
+                    if client_ip_obj in trusted_entry:
+                        return True
+                except (ValueError, AttributeError):
+                    continue
+            # Exact IP (address object)
+            elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+                try:
+                    client_ip_obj = ipaddress.ip_address(ip)
+                    if client_ip_obj == trusted_entry:
+                        return True
+                except (ValueError, AttributeError):
+                    continue
+
+        return False

     def _get_client_ip(self, request: Request) -> str:
         """
-        Get the real client IP address (with basic validation)
+        Get the real client IP address (with basic validation and proxy trust checking)

         Args:
             request: the request object

@@ -522,27 +569,38 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         Returns:
             the client IP address
         """
-        # Prefer X-Forwarded-For (for reverse-proxy setups)
-        forwarded_for = request.headers.get("X-Forwarded-For")
-        if forwarded_for:
-            # X-Forwarded-For may contain several IPs; take the first one
-            ip = forwarded_for.split(",")[0].strip()
-            # Basic IP format validation
-            if self._validate_ip(ip):
-                return ip
-
-        # Fall back to X-Real-IP
-        real_ip = request.headers.get("X-Real-IP")
-        if real_ip:
-            ip = real_ip.strip()
-            if self._validate_ip(ip):
-                return ip
-
-        # Use the connecting client's IP
+        # Get the directly connected client IP (used to validate the proxy)
+        direct_client_ip = None
         if request.client:
-            ip = request.client.host
-            if self._validate_ip(ip):
-                return ip
+            direct_client_ip = request.client.host
+
+        # Decide whether the X-Forwarded-For header should be trusted
+        use_xff = TRUST_XFF
+        if not use_xff and TRUSTED_PROXIES and direct_client_ip:
+            # With a trusted-proxy list configured, check whether the
+            # directly connected IP is on it
+            use_xff = self._is_trusted_proxy(direct_client_ip)
+
+        # If the proxy is trusted, prefer X-Forwarded-For
+        if use_xff:
+            forwarded_for = request.headers.get("X-Forwarded-For")
+            if forwarded_for:
+                # X-Forwarded-For may contain several IPs; take the first one
+                ip = forwarded_for.split(",")[0].strip()
+                # Basic IP format validation
+                if self._validate_ip(ip):
+                    return ip
+
+        # Fall back to X-Real-IP (if the proxy is trusted)
+        if use_xff:
+            real_ip = request.headers.get("X-Real-IP")
+            if real_ip:
+                ip = real_ip.strip()
+                if self._validate_ip(ip):
+                    return ip
+
+        # Use the directly connected client IP
+        if direct_client_ip and self._validate_ip(direct_client_ip):
+            return direct_client_ip

         return "unknown"

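Because `_parse_allowed_ips` is reused for WEBUI_TRUSTED_PROXIES, trusted entries can be regex strings (converted wildcards), CIDR networks, or exact addresses, which is why `_is_trusted_proxy` branches on the entry type. A standalone sketch of that matching order, with an assumed mix of entry types:

```python
# Standalone sketch of the trusted-proxy matching above; the TRUSTED mix
# (regex string, CIDR network, exact address) assumes what
# _parse_allowed_ips produces, as it does for the IP whitelist.
import ipaddress
import re

TRUSTED = [
    ipaddress.ip_address("127.0.0.1"),      # exact IP
    ipaddress.ip_network("172.17.0.0/16"),  # CIDR range
    r"192\.168\.1\.\d+",                    # wildcard converted to a regex
]

def is_trusted(ip: str) -> bool:
    for entry in TRUSTED:
        if isinstance(entry, str):  # regex entry
            if re.match(entry, ip):
                return True
        elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            try:
                if ipaddress.ip_address(ip) in entry:
                    return True
            except ValueError:  # unparseable client IP, as in the middleware
                continue
        else:  # exact address entry
            try:
                if ipaddress.ip_address(ip) == entry:
                    return True
            except ValueError:
                continue
    return False

print(is_trusted("172.17.3.9"))   # True: matches the CIDR entry
print(is_trusted("203.0.113.5"))  # False: its X-Forwarded-For stays untrusted
```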
@@ -12,4 +12,6 @@ WEBUI_PORT=8001 # WebUI server port
 WEBUI_ANTI_CRAWLER_MODE=basic # anti-crawler mode: false (disabled) / strict / loose / basic (log only, no blocking)
 WEBUI_ALLOWED_IPS=127.0.0.1 # IP whitelist (comma-separated; supports exact IPs, CIDR, and wildcards)
 # Example: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
 # Note: do not use *.*.*.* or *, which would disable the anti-crawler protection entirely
+WEBUI_TRUSTED_PROXIES= # trusted proxy IPs (comma-separated); X-Forwarded-For is trusted only from these IPs
+# Example: 127.0.0.1,192.168.1.1,172.17.0.1
+WEBUI_TRUST_XFF=false # whether to trust the X-Forwarded-For header (default false; use together with TRUSTED_PROXIES)
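Behind a single reverse proxy on the same host, the intended configuration would presumably look like the following (the nginx example is an assumption, not a project default). X-Forwarded-For is then honored only when the TCP peer is the proxy itself, so direct clients cannot spoof their IP by sending the header:

```
WEBUI_TRUSTED_PROXIES=127.0.0.1 # e.g. nginx running on the same host
WEBUI_TRUST_XFF=false # leave false; trust is derived from the proxy list
WEBUI_ALLOWED_IPS=127.0.0.1,192.168.1.0/24
```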