diff --git a/src/webui/anti_crawler.py b/src/webui/anti_crawler.py
index c8c3c318..cc1cb202 100644
--- a/src/webui/anti_crawler.py
+++ b/src/webui/anti_crawler.py
@@ -7,13 +7,11 @@ import os
 import time
 import ipaddress
 import re
-from collections import defaultdict
-from typing import Optional, Union
-from functools import lru_cache
+from collections import deque
+from typing import Optional
 
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.requests import Request
-from starlette.responses import Response, PlainTextResponse
-from fastapi import HTTPException
+from starlette.responses import PlainTextResponse
 
 from src.common.logger import get_logger
@@ -110,18 +108,15 @@
     # - "x-real-ip" (standard reverse-proxy header, already used in _get_client_ip)
 }
 
-# Suspicious HTTP header value patterns (used to detect scanning tools)
-SUSPICIOUS_HEADER_PATTERNS = {
-    "shodan",
-    "censys",
-    "zoomeye",
-    "fofa",
-    "quake",
-    "scanner",
-    "probe",
-    "scan",
-    "recon",
-    "reconnaissance",
+# Only check specific HTTP headers for suspicious patterns (tightened matching scope)
+# Only these headers are inspected, never the full header set
+SCANNER_SPECIFIC_HEADERS = {
+    "x-scan",
+    "x-scanner",
+    "x-probe",
+    "x-originating-ip",
+    "x-remote-ip",
+    "x-remote-addr",
 }
 
 # Anti-crawler mode configuration
@@ -237,6 +232,12 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
 
 ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
 
+# Trusted proxy IPs (read from an environment variable, comma-separated)
+# X-Forwarded-For is only honored when the request arrives through a trusted proxy
+# Disabled by default (empty): no proxy is trusted
+TRUSTED_PROXIES = _parse_allowed_ips(os.getenv("WEBUI_TRUSTED_PROXIES", ""))
+TRUST_XFF = os.getenv("WEBUI_TRUST_XFF", "false").lower() == "true"
+
 
 def _get_mode_config(mode: str) -> dict:
     """
@@ -320,14 +321,13 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         self.check_rate_limit = config["check_rate_limit"]
         self.block_on_detect = config["block_on_detect"]  # whether to block detected malicious traffic
 
-        # Request timestamps per IP
-        self.request_times: dict[str, list[float]] = defaultdict(list)
+        # Request timestamps per IP (deque for fast expiry at both ends)
+        self.request_times: dict[str, deque] = {}
         # Time of the last cleanup pass
        self.last_cleanup = time.time()
         # Convert keyword lists to sets for faster lookups
         self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
         self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
-        self.suspicious_patterns_set = set(SUSPICIOUS_HEADER_PATTERNS)
 
     def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
         """
@@ -354,31 +354,10 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
 
         return False
 
-    def _is_asset_scanner_user_agent(self, user_agent: Optional[str]) -> bool:
-        """
-        Detect whether the User-Agent belongs to an asset-mapping scanner
-
-        Args:
-            user_agent: the User-Agent string
-
-        Returns:
-            True if it is an asset-mapping scanner
-        """
-        if not user_agent:
-            return False
-
-        user_agent_lower = user_agent.lower()
-
-        # Check for asset-mapping scanner keywords
-        for scanner_keyword in ASSET_SCANNER_USER_AGENTS:
-            if scanner_keyword in user_agent_lower:
-                return True
-
-        return False
-
     def _is_asset_scanner_header(self, request: Request) -> bool:
         """
-        Detect asset-mapping scanner HTTP headers
+        Detect asset-mapping scanner HTTP headers (specific headers only, tightened matching)
 
         Args:
             request: the request object
@@ -386,7 +365,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         Returns:
             True if an asset-scanner header is detected
         """
-        # Check every HTTP header
+        # Iterate the headers, but only act on the specific scanner headers
        for header_name, header_value in request.headers.items():
             header_name_lower = header_name.lower()
             header_value_lower = header_value.lower() if header_value else ""
@@ -404,10 +383,12 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
                 if header_value_lower:
                     return True
 
-            # Set lookup for performance (check header values for suspicious patterns)
-            for pattern in self.suspicious_patterns_set:
-                if pattern in header_name_lower or pattern in header_value_lower:
-                    return True
+            # Only match suspicious patterns in the specific headers (tightened matching)
+            if header_name_lower in SCANNER_SPECIFIC_HEADERS:
+                # Check whether the header value names a known scanner tool
+                for tool in self.scanner_keywords_set:
+                    if tool in header_value_lower:
+                        return True
 
         return False
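Review note: the tightened check, sketched stand-alone for comparison with the removed any-header scan. The header set and scanner keywords below are illustrative assumptions, not the module's full lists.

# Minimal sketch, assuming a subset of SCANNER_SPECIFIC_HEADERS and scanner keywords.
SCANNER_SPECIFIC_HEADERS = {"x-scan", "x-scanner", "x-probe"}
SCANNER_KEYWORDS = {"zgrab", "masscan", "nmap"}

def is_scanner_header(headers: dict) -> bool:
    # Only the named headers are inspected, and only for known tool names
    for name, value in headers.items():
        if name.lower() in SCANNER_SPECIFIC_HEADERS:
            value_lower = (value or "").lower()
            if any(tool in value_lower for tool in SCANNER_KEYWORDS):
                return True
    return False

assert is_scanner_header({"X-Scanner": "zgrab/0.1"})
assert not is_scanner_header({"Accept": "zgrab"})  # non-listed headers are ignored now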
@@ -441,16 +422,18 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
                 detected_tool = tool
                 break
 
-        # Check HTTP headers for tool identifiers
+        # Check HTTP headers for tool identifiers (specific headers only)
         if not detected_tool:
             for header_name, header_value in request.headers.items():
-                header_value_lower = (header_value or "").lower()
-                for tool in self.scanner_keywords_set:
-                    if tool in header_value_lower:
-                        detected_tool = tool
-                        break
-                if detected_tool:
-                    break
+                header_name_lower = header_name.lower()
+                if header_name_lower in SCANNER_SPECIFIC_HEADERS:
+                    header_value_lower = (header_value or "").lower()
+                    for tool in self.scanner_keywords_set:
+                        if tool in header_value_lower:
+                            detected_tool = tool
+                            break
+                    if detected_tool:
+                        break
 
         return True, detected_tool or "unknown_scanner"
@@ -470,11 +453,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         if self._is_ip_allowed(client_ip):
             return False
 
-        # Cap the number of tracked IPs to prevent memory leaks
-        if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
-            # Evict the oldest records
-            self._cleanup_old_requests(time.time())
-
         current_time = time.time()
 
         # Periodically clean up expired request records (every 5 minutes)
@@ -482,15 +460,20 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
             self._cleanup_old_requests(current_time)
             self.last_cleanup = current_time
 
-        # Get the request-time list for this IP
+        # Cap the number of tracked IPs to prevent memory leaks
+        if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
+            # Evict the oldest records (drop the least recently active IP)
+            self._cleanup_oldest_ips()
+
+        # Get or create the request-time deque for this IP
+        if client_ip not in self.request_times:
+            self.request_times[client_ip] = deque(maxlen=self.rate_limit_max_requests * 2)
+
         request_times = self.request_times[client_ip]
 
-        # Drop request records outside the time window
-        request_times[:] = [
-            req_time
-            for req_time in request_times
-            if current_time - req_time < self.rate_limit_window
-        ]
+        # Drop records outside the time window (pop expired entries from the left)
+        while request_times and current_time - request_times[0] >= self.rate_limit_window:
+            request_times.popleft()
 
         # Check whether the limit has been exceeded
         if len(request_times) >= self.rate_limit_max_requests:
@@ -501,20 +484,84 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         return False
 
     def _cleanup_old_requests(self, current_time: float):
-        """Clean up expired request records"""
-        for ip in list(self.request_times.keys()):
-            self.request_times[ip] = [
-                req_time
-                for req_time in self.request_times[ip]
-                if current_time - req_time < self.rate_limit_window
-            ]
-            # Drop this IP's entry once its list is empty
-            if not self.request_times[ip]:
-                del self.request_times[ip]
+        """Periodic cleanup pass (per-IP expiry happens on demand, no full scan)"""
+        # Per-IP expiry is now done in _check_rate_limit via popleft(); this
+        # method only evicts whole IPs when the tracking dict grows too large
+        if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips * 0.8:
+            self._cleanup_oldest_ips()
+
+    def _cleanup_oldest_ips(self):
+        """Evict the least recently active IP record (avoids a full scan)"""
+        if not self.request_times:
+            return
+
+        # Find the least recently active IP (empty deque, or oldest last-request time)
+        oldest_ip = None
+        oldest_time = float("inf")
+
+        # Only sample a bounded number of IPs instead of scanning the whole dict
+        check_count = min(100, len(self.request_times))
+        checked = 0
+        for ip, times in self.request_times.items():
+            if checked >= check_count:
+                break
+            checked += 1
+            if not times:
+                # Empty deque: evict it first (return immediately, so the dict
+                # is never mutated while iteration continues)
+                del self.request_times[ip]
+                return
+            if times[-1] < oldest_time:
+                oldest_time = times[-1]
+                oldest_ip = ip
+
+        # Evict the least recently active IP
+        if oldest_ip:
+            del self.request_times[oldest_ip]
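Review note: the sliding-window logic in isolation, a runnable sketch of the deque-based limiter this hunk switches to. Window, limit, and the allow() helper are illustrative, not part of the middleware.

import time
from collections import deque

WINDOW = 60.0      # seconds, example value
MAX_REQUESTS = 5   # per window, example value

times: deque = deque(maxlen=MAX_REQUESTS * 2)

def allow(now: float) -> bool:
    # Timestamps are appended in order, so expired entries sit at the left end
    while times and now - times[0] >= WINDOW:
        times.popleft()
    if len(times) >= MAX_REQUESTS:
        return False  # over the limit for this window
    times.append(now)
    return True

for _ in range(7):
    print(allow(time.time()))  # True five times, then False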
+
+    def _is_trusted_proxy(self, ip: str) -> bool:
+        """
+        Check whether an IP is on the trusted proxy list
+
+        Args:
+            ip: IP address string
+
+        Returns:
+            True if the IP is a trusted proxy
+        """
+        if not TRUSTED_PROXIES or ip == "unknown":
+            return False
+
+        # Check every entry in the proxy list
+        for trusted_entry in TRUSTED_PROXIES:
+            # Wildcard pattern (a string, matched as a regular expression)
+            if isinstance(trusted_entry, str):
+                try:
+                    if re.match(trusted_entry, ip):
+                        return True
+                except re.error:
+                    continue
+            # CIDR range (network object)
+            elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
+                try:
+                    client_ip_obj = ipaddress.ip_address(ip)
+                    if client_ip_obj in trusted_entry:
+                        return True
+                except (ValueError, AttributeError):
+                    continue
+            # Exact IP (address object)
+            elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+                try:
+                    client_ip_obj = ipaddress.ip_address(ip)
+                    if client_ip_obj == trusted_entry:
+                        return True
+                except (ValueError, AttributeError):
+                    continue
+
+        return False
 
     def _get_client_ip(self, request: Request) -> str:
         """
-        Get the real client IP address (with basic validation)
+        Get the real client IP address (with basic validation and a proxy trust check)
 
         Args:
             request: the request object
@@ -522,27 +569,38 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         Returns:
             The client IP address
         """
-        # Prefer X-Forwarded-For (for reverse proxies)
-        forwarded_for = request.headers.get("X-Forwarded-For")
-        if forwarded_for:
-            # X-Forwarded-For may contain several IPs; take the first one
-            ip = forwarded_for.split(",")[0].strip()
-            # Basic IP format validation
-            if self._validate_ip(ip):
-                return ip
-
-        # Fall back to X-Real-IP
-        real_ip = request.headers.get("X-Real-IP")
-        if real_ip:
-            ip = real_ip.strip()
-            if self._validate_ip(ip):
-                return ip
-
-        # Use the client IP
+        # Get the directly connected client IP (used to validate the proxy)
+        direct_client_ip = None
         if request.client:
-            ip = request.client.host
-            if self._validate_ip(ip):
-                return ip
+            direct_client_ip = request.client.host
+
+        # Decide whether to trust the X-Forwarded-For header
+        use_xff = TRUST_XFF
+        if not use_xff and TRUSTED_PROXIES and direct_client_ip:
+            # With a trusted proxy list configured, honor XFF only when the
+            # direct peer is on that list
+            use_xff = self._is_trusted_proxy(direct_client_ip)
+
+        # If the proxy is trusted, prefer X-Forwarded-For
+        if use_xff:
+            forwarded_for = request.headers.get("X-Forwarded-For")
+            if forwarded_for:
+                # X-Forwarded-For may contain several IPs; take the first one
+                ip = forwarded_for.split(",")[0].strip()
+                # Basic IP format validation
+                if self._validate_ip(ip):
+                    return ip
+
+        # Fall back to X-Real-IP (again only if the proxy is trusted)
+        if use_xff:
+            real_ip = request.headers.get("X-Real-IP")
+            if real_ip:
+                ip = real_ip.strip()
+                if self._validate_ip(ip):
+                    return ip
+
+        # Use the directly connected client IP
+        if direct_client_ip and self._validate_ip(direct_client_ip):
+            return direct_client_ip
 
         return "unknown"
diff --git a/template/template.env b/template/template.env
index a08635fb..b08fecf0 100644
--- a/template/template.env
+++ b/template/template.env
@@ -12,4 +12,6 @@ WEBUI_PORT=8001 # WebUI server port
 WEBUI_ANTI_CRAWLER_MODE=basic # Anti-crawler mode: false (disabled) / strict / loose / basic (log only, never block)
 WEBUI_ALLOWED_IPS=127.0.0.1 # IP allowlist (comma-separated; exact IPs, CIDR ranges and wildcards are supported)
                             # Example: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
-                            # Note: never use *.*.*.* or *, which would disable the anti-crawler protection entirely
\ No newline at end of file
+WEBUI_TRUSTED_PROXIES= # Trusted proxy IPs (comma-separated); X-Forwarded-For is only honored from these addresses
+                       # Example: 127.0.0.1,192.168.1.1,172.17.0.1
+WEBUI_TRUST_XFF=false # Trust X-Forwarded-For even without a trusted-proxy match (default false; prefer WEBUI_TRUSTED_PROXIES)
\ No newline at end of file
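Usage note: how the two new settings interact. A stand-alone sketch of the trust decision that mirrors the middleware's logic; the helper name and all addresses are illustrative.

import ipaddress
from typing import Optional

TRUSTED_PROXIES = [ipaddress.ip_network(e, strict=False)
                   for e in ("127.0.0.1", "172.17.0.0/16")]  # example entries
TRUST_XFF = False  # WEBUI_TRUST_XFF left at its default

def client_ip(direct_peer: str, xff: Optional[str]) -> str:
    # XFF is honored only for a trusted direct peer (or when TRUST_XFF is set)
    use_xff = TRUST_XFF or any(
        ipaddress.ip_address(direct_peer) in net for net in TRUSTED_PROXIES)
    if use_xff and xff:
        return xff.split(",")[0].strip()  # first hop, as in _get_client_ip
    return direct_peer

print(client_ip("172.17.0.5", "203.0.113.9"))    # trusted proxy  -> 203.0.113.9
print(client_ip("198.51.100.7", "203.0.113.9"))  # untrusted peer -> 198.51.100.7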