Ruff Format

pull/1443/head
墨梓柒 2025-12-14 21:39:09 +08:00
parent ee4cb3dc67
commit 74a2f4346a
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
11 changed files with 176 additions and 197 deletions

View File

@ -126,6 +126,7 @@ SCANNER_SPECIFIC_HEADERS = {
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP # basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower() ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# IP白名单配置从环境变量读取逗号分隔 # IP白名单配置从环境变量读取逗号分隔
# 支持格式: # 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100 # - 精确IP127.0.0.1, 192.168.1.100
@ -135,10 +136,10 @@ ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
def _parse_allowed_ips(ip_string: str) -> list: def _parse_allowed_ips(ip_string: str) -> list:
""" """
解析IP白名单字符串支持精确IPCIDR格式和通配符 解析IP白名单字符串支持精确IPCIDR格式和通配符
Args: Args:
ip_string: 逗号分隔的IP字符串 ip_string: 逗号分隔的IP字符串
Returns: Returns:
IP白名单列表每个元素可能是 IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式 - ipaddress.IPv4Network/IPv6Network对象CIDR格式
@ -148,12 +149,12 @@ def _parse_allowed_ips(ip_string: str) -> list:
allowed = [] allowed = []
if not ip_string: if not ip_string:
return allowed return allowed
for ip_entry in ip_string.split(","): for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格 ip_entry = ip_entry.strip() # 去除空格
if not ip_entry: if not ip_entry:
continue continue
# 检查通配符格式(包含* # 检查通配符格式(包含*
if "*" in ip_entry: if "*" in ip_entry:
# 处理通配符 # 处理通配符
@ -163,7 +164,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
else: else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}") logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue continue
try: try:
# 尝试解析为CIDR格式包含/ # 尝试解析为CIDR格式包含/
if "/" in ip_entry: if "/" in ip_entry:
@ -173,39 +174,39 @@ def _parse_allowed_ips(ip_string: str) -> list:
allowed.append(ipaddress.ip_address(ip_entry)) allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e: except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})") logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]: def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
""" """
将通配符IP模式转换为正则表达式 将通配符IP模式转换为正则表达式
支持的格式 支持的格式
- 192.168.*.* 192.168.* - 192.168.*.* 192.168.*
- 10.*.*.* 10.* - 10.*.*.* 10.*
- *.*.*.* * - *.*.*.* *
Args: Args:
wildcard_pattern: 通配符模式字符串 wildcard_pattern: 通配符模式字符串
Returns: Returns:
正则表达式字符串如果格式无效则返回None 正则表达式字符串如果格式无效则返回None
""" """
# 去除空格 # 去除空格
pattern = wildcard_pattern.strip() pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有) # 处理单个*(匹配所有)
if pattern == "*": if pattern == "*":
return r".*" return r".*"
# 处理IPv4通配符格式 # 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等 # 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".") parts = pattern.split(".")
if len(parts) > 4: if len(parts) > 4:
return None # IPv4最多4段 return None # IPv4最多4段
# 构建正则表达式 # 构建正则表达式
regex_parts = [] regex_parts = []
for part in parts: for part in parts:
@ -221,15 +222,16 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
return None # 无效的数字 return None # 无效的数字
else: else:
return None # 无效的格式 return None # 无效的格式
# 如果部分少于4段补充.* # 如果部分少于4段补充.*
while len(regex_parts) < 4: while len(regex_parts) < 4:
regex_parts.append(r"\d+") regex_parts.append(r"\d+")
# 组合成正则表达式 # 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$" regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex return regex
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", "")) ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
# 信任的代理IP配置从环境变量读取逗号分隔 # 信任的代理IP配置从环境变量读取逗号分隔
@ -250,7 +252,7 @@ def _get_mode_config(mode: str) -> dict:
配置字典包含所有相关参数 配置字典包含所有相关参数
""" """
mode = mode.lower() mode = mode.lower()
if mode == "false": if mode == "false":
return { return {
"enabled": False, "enabled": False,
@ -320,7 +322,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
self.check_asset_scanner = config["check_asset_scanner"] self.check_asset_scanner = config["check_asset_scanner"]
self.check_rate_limit = config["check_rate_limit"] self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问 self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能 # 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {} self.request_times: dict[str, deque] = {}
# 上次清理时间 # 上次清理时间
@ -354,7 +356,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return False return False
def _is_asset_scanner_header(self, request: Request) -> bool: def _is_asset_scanner_header(self, request: Request) -> bool:
""" """
检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配 检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配
@ -499,7 +500,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
empty_ips = [] empty_ips = []
# 找到最久未访问的IP最旧时间戳 # 找到最久未访问的IP最旧时间戳
oldest_ip = None oldest_ip = None
oldest_time = float('inf') oldest_time = float("inf")
# 全量遍历找真正的oldest超限时性能可接受 # 全量遍历找真正的oldest超限时性能可接受
for ip, times in self.request_times.items(): for ip, times in self.request_times.items():
@ -532,7 +533,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
""" """
if not TRUSTED_PROXIES or ip == "unknown": if not TRUSTED_PROXIES or ip == "unknown":
return False return False
# 检查代理列表中的每个条目 # 检查代理列表中的每个条目
for trusted_entry in TRUSTED_PROXIES: for trusted_entry in TRUSTED_PROXIES:
# 通配符模式(字符串,正则表达式) # 通配符模式(字符串,正则表达式)
@ -558,7 +559,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return True return True
except (ValueError, AttributeError): except (ValueError, AttributeError):
continue continue
return False return False
def _get_client_ip(self, request: Request) -> str: def _get_client_ip(self, request: Request) -> str:
@ -635,7 +636,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
""" """
if not ALLOWED_IPS or ip == "unknown": if not ALLOWED_IPS or ip == "unknown":
return False return False
# 检查白名单中的每个条目 # 检查白名单中的每个条目
for allowed_entry in ALLOWED_IPS: for allowed_entry in ALLOWED_IPS:
# 通配符模式(字符串,正则表达式) # 通配符模式(字符串,正则表达式)
@ -664,7 +665,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
except (ValueError, AttributeError): except (ValueError, AttributeError):
# IP格式无效跳过 # IP格式无效跳过
continue continue
return False return False
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
@ -689,16 +690,31 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 允许访问静态资源CSS、JS、图片等 # 允许访问静态资源CSS、JS、图片等
# 注意:.json 已移除,避免 API 路径绕过防护 # 注意:.json 已移除,避免 API 路径绕过防护
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/ # 静态资源只在特定前缀下放行(/static/、/assets/、/dist/
static_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"} static_extensions = {
".css",
".js",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
".woff",
".woff2",
".ttf",
".eot",
}
static_prefixes = {"/static/", "/assets/", "/dist/"} static_prefixes = {"/static/", "/assets/", "/dist/"}
# 检查是否是静态资源路径(特定前缀下的静态文件) # 检查是否是静态资源路径(特定前缀下的静态文件)
path = request.url.path path = request.url.path
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(path.endswith(ext) for ext in static_extensions) is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
path.endswith(ext) for ext in static_extensions
)
# 也允许根路径下的静态文件(如 /favicon.ico # 也允许根路径下的静态文件(如 /favicon.ico
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions) is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
if is_static_path or is_root_static: if is_static_path or is_root_static:
return await call_next(request) return await call_next(request)
@ -729,9 +745,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 检测爬虫 User-Agent # 检测爬虫 User-Agent
if self.check_user_agent and self._is_crawler_user_agent(user_agent): if self.check_user_agent and self._is_crawler_user_agent(user_agent):
logger.warning( logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
)
# 根据配置决定是否阻止 # 根据配置决定是否阻止
if self.block_on_detect: if self.block_on_detect:
return PlainTextResponse( return PlainTextResponse(
@ -741,9 +755,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
# 检查请求频率限制 # 检查请求频率限制
if self.check_rate_limit and self._check_rate_limit(client_ip): if self.check_rate_limit and self._check_rate_limit(client_ip):
logger.warning( logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
)
return PlainTextResponse( return PlainTextResponse(
"Too Many Requests: Rate limit exceeded", "Too Many Requests: Rate limit exceeded",
status_code=429, status_code=429,
@ -770,4 +782,3 @@ Disallow: /
media_type="text/plain", media_type="text/plain",
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时 headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
) )

View File

@ -19,7 +19,7 @@ COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool: def _is_secure_environment() -> bool:
""" """
检测是否应该启用安全 CookieHTTPS 检测是否应该启用安全 CookieHTTPS
Returns: Returns:
bool: 如果应该使用 secure cookie 则返回 True bool: 如果应该使用 secure cookie 则返回 True
""" """
@ -28,12 +28,12 @@ def _is_secure_environment() -> bool:
return True return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"): if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
return False return False
# 检查是否是生产环境 # 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower() env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"): if env in ("production", "prod"):
return True return True
# 默认:开发环境不启用(因为通常是 HTTP # 默认:开发环境不启用(因为通常是 HTTP
return False return False
@ -87,7 +87,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
""" """
# 根据环境决定安全设置 # 根据环境决定安全设置
is_secure = _is_secure_environment() is_secure = _is_secure_environment()
response.set_cookie( response.set_cookie(
key=COOKIE_NAME, key=COOKIE_NAME,
value=token, value=token,
@ -109,7 +109,7 @@ def clear_auth_cookie(response: Response) -> None:
""" """
# 保持与 set_auth_cookie 相同的安全设置 # 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment() is_secure = _is_secure_environment()
response.delete_cookie( response.delete_cookie(
key=COOKIE_NAME, key=COOKIE_NAME,
httponly=True, httponly=True,

View File

@ -31,6 +31,7 @@ def require_auth(
"""认证依赖:验证用户是否已登录""" """认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization) return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID # WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat" WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui" WEBUI_CHAT_PLATFORM = "webui"
@ -399,21 +400,21 @@ async def websocket_chat(
token: 认证 token可选也可从 Cookie 获取 token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置 虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
支持三种认证方式按优先级 支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token 1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session 2. Cookie 中的 maibot_session
3. 直接使用 session token兼容 3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx 示例ws://host/api/chat/ws?token=xxx
""" """
is_authenticated = False is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式 # 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token): if token and verify_ws_token(token):
is_authenticated = True is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功") logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token # 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated: if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session") cookie_token = websocket.cookies.get("maibot_session")
@ -422,19 +423,19 @@ async def websocket_chat(
if token_manager.verify_token(cookie_token): if token_manager.verify_token(cookie_token):
is_authenticated = True is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功") logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式 # 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token: if not is_authenticated and token:
token_manager = get_token_manager() token_manager = get_token_manager()
if token_manager.verify_token(token): if token_manager.verify_token(token):
is_authenticated = True is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功") logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated: if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败") logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录") await websocket.close(code=4001, reason="认证失败,请重新登录")
return return
# 生成会话 ID每次连接都是新的 # 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())

View File

@ -346,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen
@router.post("/model/section/{section_name}") @router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)): async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)""" """更新模型配置的指定节(保留注释和格式)"""
try: try:
# 读取现有配置 # 读取现有配置
@ -383,8 +385,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)} provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", []) models = config_data.get("models", [])
orphaned_models = [ orphaned_models = [
m.get("name") for m in models m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
if isinstance(m, dict) and m.get("api_provider") not in provider_names
] ]
if orphaned_models: if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。" error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"

View File

@ -83,16 +83,16 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token 1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session 2. Cookie 中的 maibot_session
3. 直接使用 session token兼容 3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx 示例ws://host/ws/logs?token=xxx
""" """
is_authenticated = False is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式 # 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token): if token and verify_ws_token(token):
is_authenticated = True is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功") logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token # 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated: if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session") cookie_token = websocket.cookies.get("maibot_session")
@ -101,19 +101,19 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
if token_manager.verify_token(cookie_token): if token_manager.verify_token(cookie_token):
is_authenticated = True is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功") logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式 # 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token: if not is_authenticated and token:
token_manager = get_token_manager() token_manager = get_token_manager()
if token_manager.verify_token(token): if token_manager.verify_token(token):
is_authenticated = True is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功") logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated: if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败") logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录") await websocket.close(code=4001, reason="认证失败,请重新登录")
return return
await websocket.accept() await websocket.accept()
active_connections.add(websocket) active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")

View File

@ -99,16 +99,16 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token 1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session 2. Cookie 中的 maibot_session
3. 直接使用 session token兼容 3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx 示例ws://host/ws/plugin-progress?token=xxx
""" """
is_authenticated = False is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式 # 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token): if token and verify_ws_token(token):
is_authenticated = True is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功") logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token # 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated: if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session") cookie_token = websocket.cookies.get("maibot_session")
@ -117,19 +117,19 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
if token_manager.verify_token(cookie_token): if token_manager.verify_token(cookie_token):
is_authenticated = True is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功") logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式 # 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token: if not is_authenticated and token:
token_manager = get_token_manager() token_manager = get_token_manager()
if token_manager.verify_token(token): if token_manager.verify_token(token):
is_authenticated = True is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功") logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated: if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败") logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录") await websocket.close(code=4001, reason="认证失败,请重新登录")
return return
await websocket.accept() await websocket.accept()
active_connections.add(websocket) active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")

View File

@ -3,7 +3,6 @@ from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, get_origin from typing import Optional, List, Dict, Any, get_origin
from pathlib import Path from pathlib import Path
import json import json
import re
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION from src.config.config import MMC_VERSION
@ -38,54 +37,54 @@ def get_token_from_cookie_or_header(
def validate_safe_path(user_path: str, base_path: Path) -> Path: def validate_safe_path(user_path: str, base_path: Path) -> Path:
""" """
验证用户提供的路径是否安全防止路径遍历攻击 验证用户提供的路径是否安全防止路径遍历攻击
Args: Args:
user_path: 用户输入的路径相对路径 user_path: 用户输入的路径相对路径
base_path: 允许的基础目录 base_path: 允许的基础目录
Returns: Returns:
安全的绝对路径 安全的绝对路径
Raises: Raises:
HTTPException: 如果检测到路径遍历攻击 HTTPException: 如果检测到路径遍历攻击
""" """
# 规范化基础路径 # 规范化基础路径
base_resolved = base_path.resolve() base_resolved = base_path.resolve()
# 检查用户路径是否包含可疑字符 # 检查用户路径是否包含可疑字符
# 禁止: .., 绝对路径开头, 空字节等 # 禁止: .., 绝对路径开头, 空字节等
if any(pattern in user_path for pattern in ["..", "\x00"]): if any(pattern in user_path for pattern in ["..", "\x00"]):
logger.warning(f"检测到可疑路径: {user_path}") logger.warning(f"检测到可疑路径: {user_path}")
raise HTTPException(status_code=400, detail="路径包含非法字符") raise HTTPException(status_code=400, detail="路径包含非法字符")
# 检查是否为绝对路径Windows 和 Unix # 检查是否为绝对路径Windows 和 Unix
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"): if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
logger.warning(f"检测到绝对路径: {user_path}") logger.warning(f"检测到绝对路径: {user_path}")
raise HTTPException(status_code=400, detail="不允许使用绝对路径") raise HTTPException(status_code=400, detail="不允许使用绝对路径")
# 构建目标路径并解析 # 构建目标路径并解析
target_path = (base_path / user_path).resolve() target_path = (base_path / user_path).resolve()
# 验证解析后的路径仍在基础目录内 # 验证解析后的路径仍在基础目录内
try: try:
target_path.relative_to(base_resolved) target_path.relative_to(base_resolved)
except ValueError as e: except ValueError as e:
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}") logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
raise HTTPException(status_code=400, detail="路径超出允许范围") from e raise HTTPException(status_code=400, detail="路径超出允许范围") from e
return target_path return target_path
def validate_plugin_id(plugin_id: str) -> str: def validate_plugin_id(plugin_id: str) -> str:
""" """
验证插件 ID 格式是否安全 验证插件 ID 格式是否安全
Args: Args:
plugin_id: 插件 ID (支持 author.name 格式允许中文) plugin_id: 插件 ID (支持 author.name 格式允许中文)
Returns: Returns:
验证通过的插件 ID 验证通过的插件 ID
Raises: Raises:
HTTPException: 如果插件 ID 格式不安全 HTTPException: 如果插件 ID 格式不安全
""" """
@ -93,24 +92,24 @@ def validate_plugin_id(plugin_id: str) -> str:
if not plugin_id or not plugin_id.strip(): if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串") logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空") raise HTTPException(status_code=400, detail="插件 ID 不能为空")
# 禁止危险字符: 路径分隔符、空字节、控制字符等 # 禁止危险字符: 路径分隔符、空字节、控制字符等
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"] dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
for pattern in dangerous_patterns: for pattern in dangerous_patterns:
if pattern in plugin_id: if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)") logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符") raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
# 禁止以点开头或结尾(防止隐藏文件和路径问题) # 禁止以点开头或结尾(防止隐藏文件和路径问题)
if plugin_id.startswith(".") or plugin_id.endswith("."): if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}") logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾") raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
# 禁止特殊名称 # 禁止特殊名称
if plugin_id in (".", ".."): if plugin_id in (".", ".."):
logger.warning(f"非法插件 ID: {plugin_id}") logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名") raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id return plugin_id
@ -556,10 +555,7 @@ async def fetch_raw_file(
if not token or not token_manager.verify_token(token): if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
logger.info( logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
f"收到获取 Raw 文件请求: "
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
)
# 发送开始加载进度 # 发送开始加载进度
await update_progress( await update_progress(
@ -688,7 +684,7 @@ async def install_plugin(
try: try:
# 验证插件 ID 格式安全性 # 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id) plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始安装 # 推送进度:开始安装
await update_progress( await update_progress(
stage="loading", stage="loading",
@ -899,7 +895,7 @@ async def uninstall_plugin(
try: try:
# 验证插件 ID 格式安全性 # 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id) plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始卸载 # 推送进度:开始卸载
await update_progress( await update_progress(
stage="loading", stage="loading",
@ -1041,7 +1037,7 @@ async def update_plugin(
try: try:
# 验证插件 ID 格式安全性 # 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id) plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始更新 # 推送进度:开始更新
await update_progress( await update_progress(
stage="loading", stage="loading",
@ -1494,7 +1490,7 @@ async def get_plugin_config_schema(
ui_type = "text" ui_type = "text"
item_type = None item_type = None
item_fields = None item_fields = None
if isinstance(field_value, bool): if isinstance(field_value, bool):
ui_type = "switch" ui_type = "switch"
elif isinstance(field_value, (int, float)): elif isinstance(field_value, (int, float)):

View File

@ -15,16 +15,16 @@ logger = get_logger("webui.rate_limiter")
class RateLimiter: class RateLimiter:
""" """
简单的内存请求频率限制器 简单的内存请求频率限制器
使用滑动窗口算法实现 使用滑动窗口算法实现
""" """
def __init__(self): def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]} # 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list) self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp} # 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {} self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str: def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址""" """获取客户端 IP 地址"""
# 检查代理头 # 检查代理头
@ -32,26 +32,23 @@ class RateLimiter:
if forwarded: if forwarded:
# 取第一个 IP最原始的客户端 # 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip() return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP") real_ip = request.headers.get("X-Real-IP")
if real_ip: if real_ip:
return real_ip return real_ip
# 直接连接的客户端 # 直接连接的客户端
if request.client: if request.client:
return request.client.host return request.client.host
return "unknown" return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int): def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录""" """清理过期的请求记录"""
now = time.time() now = time.time()
cutoff = now - window_seconds cutoff = now - window_seconds
self._requests[key] = [ self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
(ts, count) for ts, count in self._requests[key]
if ts > cutoff
]
def _cleanup_expired_blocks(self): def _cleanup_expired_blocks(self):
"""清理过期的封禁""" """清理过期的封禁"""
now = time.time() now = time.time()
@ -59,65 +56,61 @@ class RateLimiter:
for ip in expired: for ip in expired:
del self._blocked[ip] del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除") logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]: def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
""" """
检查 IP 是否被封禁 检查 IP 是否被封禁
Returns: Returns:
(是否被封禁, 剩余封禁秒数) (是否被封禁, 剩余封禁秒数)
""" """
self._cleanup_expired_blocks() self._cleanup_expired_blocks()
ip = self._get_client_ip(request) ip = self._get_client_ip(request)
if ip in self._blocked: if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time()) remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining) return True, max(0, remaining)
return False, None return False, None
def check_rate_limit( def check_rate_limit(
self, self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
request: Request,
max_requests: int,
window_seconds: int,
key_suffix: str = ""
) -> Tuple[bool, int]: ) -> Tuple[bool, int]:
""" """
检查请求是否超过频率限制 检查请求是否超过频率限制
Args: Args:
request: FastAPI Request 对象 request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数 max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间 window_seconds: 窗口时间
key_suffix: 键后缀用于区分不同的限制规则 key_suffix: 键后缀用于区分不同的限制规则
Returns: Returns:
(是否允许, 剩余请求数) (是否允许, 剩余请求数)
""" """
ip = self._get_client_ip(request) ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录 # 清理过期记录
self._cleanup_old_requests(key, window_seconds) self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数 # 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key]) current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests: if current_count >= max_requests:
return False, 0 return False, 0
# 记录新请求 # 记录新请求
now = time.time() now = time.time()
self._requests[key].append((now, 1)) self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1 remaining = max_requests - current_count - 1
return True, remaining return True, remaining
def block_ip(self, request: Request, duration_seconds: int): def block_ip(self, request: Request, duration_seconds: int):
""" """
封禁 IP 封禁 IP
Args: Args:
request: FastAPI Request 对象 request: FastAPI Request 对象
duration_seconds: 封禁时长 duration_seconds: 封禁时长
@ -125,55 +118,51 @@ class RateLimiter:
ip = self._get_client_ip(request) ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}") logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt( def record_failed_attempt(
self, self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
request: Request,
max_failures: int = 5,
window_seconds: int = 300,
block_duration: int = 600
) -> Tuple[bool, int]: ) -> Tuple[bool, int]:
""" """
记录失败尝试如登录失败 记录失败尝试如登录失败
如果在窗口期内失败次数过多自动封禁 IP 如果在窗口期内失败次数过多自动封禁 IP
Args: Args:
request: FastAPI Request 对象 request: FastAPI Request 对象
max_failures: 允许的最大失败次数 max_failures: 允许的最大失败次数
window_seconds: 统计窗口 window_seconds: 统计窗口
block_duration: 封禁时长 block_duration: 封禁时长
Returns: Returns:
(是否被封禁, 剩余尝试次数) (是否被封禁, 剩余尝试次数)
""" """
ip = self._get_client_ip(request) ip = self._get_client_ip(request)
key = f"{ip}:auth_failures" key = f"{ip}:auth_failures"
# 清理过期记录 # 清理过期记录
self._cleanup_old_requests(key, window_seconds) self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数 # 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key]) current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败 # 记录本次失败
now = time.time() now = time.time()
self._requests[key].append((now, 1)) self._requests[key].append((now, 1))
current_failures += 1 current_failures += 1
remaining = max_failures - current_failures remaining = max_failures - current_failures
# 检查是否需要封禁 # 检查是否需要封禁
if current_failures >= max_failures: if current_failures >= max_failures:
self.block_ip(request, block_duration) self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁") logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0 return True, 0
if current_failures >= max_failures - 2: if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}") logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining) return False, max(0, remaining)
def reset_failures(self, request: Request): def reset_failures(self, request: Request):
""" """
重置失败计数认证成功后调用 重置失败计数认证成功后调用
@ -199,66 +188,58 @@ def get_rate_limiter() -> RateLimiter:
async def check_auth_rate_limit(request: Request): async def check_auth_rate_limit(request: Request):
""" """
认证接口的频率限制依赖 认证接口的频率限制依赖
规则 规则
- 每个 IP 每分钟最多 10 次认证请求 - 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟 - 连续失败 5 次后封禁 10 分钟
""" """
limiter = get_rate_limiter() limiter = get_rate_limiter()
# 检查是否被封禁 # 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request) blocked, remaining_block = limiter.is_blocked(request)
if blocked: if blocked:
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)} headers={"Retry-After": str(remaining_block)},
) )
# 检查频率限制 # 检查频率限制
allowed, remaining = limiter.check_rate_limit( allowed, remaining = limiter.check_rate_limit(
request, request,
max_requests=10, # 每分钟 10 次 max_requests=10, # 每分钟 10 次
window_seconds=60, window_seconds=60,
key_suffix="auth" key_suffix="auth",
) )
if not allowed: if not allowed:
raise HTTPException( raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
status_code=429,
detail="认证请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
async def check_api_rate_limit(request: Request): async def check_api_rate_limit(request: Request):
""" """
普通 API 的频率限制依赖 普通 API 的频率限制依赖
规则每个 IP 每分钟最多 100 次请求 规则每个 IP 每分钟最多 100 次请求
""" """
limiter = get_rate_limiter() limiter = get_rate_limiter()
# 检查是否被封禁 # 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request) blocked, remaining_block = limiter.is_blocked(request)
if blocked: if blocked:
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试", detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)} headers={"Retry-After": str(remaining_block)},
) )
# 检查频率限制 # 检查频率限制
allowed, _ = limiter.check_rate_limit( allowed, _ = limiter.check_rate_limit(
request, request,
max_requests=100, # 每分钟 100 次 max_requests=100, # 每分钟 100 次
window_seconds=60, window_seconds=60,
key_suffix="api" key_suffix="api",
) )
if not allowed: if not allowed:
raise HTTPException( raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
status_code=429,
detail="请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)

View File

@ -112,7 +112,7 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse) @router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token( async def verify_token(
request_body: TokenVerifyRequest, request_body: TokenVerifyRequest,
request: Request, request: Request,
response: Response, response: Response,
_rate_limit: None = Depends(check_auth_rate_limit), _rate_limit: None = Depends(check_auth_rate_limit),
@ -131,7 +131,7 @@ async def verify_token(
try: try:
token_manager = get_token_manager() token_manager = get_token_manager()
rate_limiter = get_rate_limiter() rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token) is_valid = token_manager.verify_token(request_body.token)
if is_valid: if is_valid:
@ -146,21 +146,18 @@ async def verify_token(
# 记录失败尝试 # 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt( blocked, remaining = rate_limiter.record_failed_attempt(
request, request,
max_failures=5, # 5 次失败 max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口 window_seconds=300, # 5 分钟窗口
block_duration=600 # 封禁 10 分钟 block_duration=600, # 封禁 10 分钟
) )
if blocked: if blocked:
raise HTTPException( raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
status_code=429,
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
)
message = "Token 无效或已过期" message = "Token 无效或已过期"
if remaining <= 2: if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)" message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message) return TokenVerifyResponse(valid=False, message=message)
except HTTPException: except HTTPException:
raise raise

View File

@ -34,7 +34,7 @@ class WebUIServer:
# 重要:先注册 API 路由,再设置静态文件 # 重要:先注册 API 路由,再设置静态文件
self._register_api_routes() self._register_api_routes()
self._setup_static_files() self._setup_static_files()
# 注册robots.txt路由 # 注册robots.txt路由
self._setup_robots_txt() self._setup_robots_txt()
@ -115,7 +115,7 @@ class WebUIServer:
media_type = mimetypes.guess_type(str(file_path))[0] media_type = mimetypes.guess_type(str(file_path))[0]
response = FileResponse(file_path, media_type=media_type) response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头 # HTML 文件添加防索引头
if str(file_path).endswith('.html'): if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive" response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response return response
@ -130,23 +130,15 @@ class WebUIServer:
"""配置防爬虫中间件""" """配置防爬虫中间件"""
try: try:
from src.webui.anti_crawler import AntiCrawlerMiddleware from src.webui.anti_crawler import AntiCrawlerMiddleware
# 从环境变量读取防爬虫模式false/strict/loose/basic # 从环境变量读取防爬虫模式false/strict/loose/basic
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower() anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行 # 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行 # 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware( self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
AntiCrawlerMiddleware,
mode=anti_crawler_mode mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
)
mode_descriptions = {
"false": "已禁用",
"strict": "严格模式",
"loose": "宽松模式",
"basic": "基础模式"
}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式") mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}") logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e: except Exception as e:
@ -156,12 +148,12 @@ class WebUIServer:
"""设置robots.txt路由""" """设置robots.txt路由"""
try: try:
from src.webui.anti_crawler import create_robots_txt_response from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False) @self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt(): async def robots_txt():
"""返回robots.txt禁止所有爬虫""" """返回robots.txt禁止所有爬虫"""
return create_robots_txt_response() return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册") logger.debug("✅ robots.txt 路由已注册")
except Exception as e: except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True) logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)

View File

@ -4,7 +4,7 @@
临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题 临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题
""" """
from fastapi import APIRouter, Cookie, Header, HTTPException from fastapi import APIRouter, Cookie, Header
from typing import Optional from typing import Optional
import secrets import secrets
import time import time
@ -30,10 +30,10 @@ def _cleanup_expired_ws_tokens():
def generate_ws_token(session_token: str) -> str: def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token """生成 WebSocket 临时 token
Args: Args:
session_token: 原始的 session token session_token: 原始的 session token
Returns: Returns:
临时 token 字符串 临时 token 字符串
""" """
@ -46,10 +46,10 @@ def generate_ws_token(session_token: str) -> str:
def verify_ws_token(temp_token: str) -> bool: def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用 """验证并消费 WebSocket 临时 token一次性使用
Args: Args:
temp_token: 临时 token temp_token: 临时 token
Returns: Returns:
验证是否通过 验证是否通过
""" """
@ -81,11 +81,11 @@ async def get_ws_token(
): ):
""" """
获取 WebSocket 连接用的临时 token 获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie Authorization header 此端点验证当前会话的 Cookie Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证 然后返回一个临时 token 用于 WebSocket 握手认证
临时 token 有效期 60 且只能使用一次 临时 token 有效期 60 且只能使用一次
注意在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面 注意在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面
""" """
# 获取当前 session token # 获取当前 session token
@ -94,21 +94,21 @@ async def get_ws_token(
session_token = maibot_session session_token = maibot_session
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "") session_token = authorization.replace("Bearer ", "")
if not session_token: if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面 # 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理 # 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)") logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0} return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token # 验证 session token
token_manager = get_token_manager() token_manager = get_token_manager()
if not token_manager.verify_token(session_token): if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新 # 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期") logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0} return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token # 生成临时 WebSocket token
ws_token = generate_ws_token(session_token) ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS} return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}