Ruff Format

pull/1443/head
墨梓柒 2025-12-14 21:39:09 +08:00
parent ee4cb3dc67
commit 74a2f4346a
11 changed files with 176 additions and 197 deletions

View File

@@ -126,6 +126,7 @@ SCANNER_SPECIFIC_HEADERS = {
 # basic: log malicious access only; no blocking, no rate limiting, no IP tracking
 ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
+
 # IP allowlist, read from an environment variable (comma-separated)
 # Supported formats:
 # - exact IP: 127.0.0.1, 192.168.1.100
@@ -230,6 +231,7 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
     regex = r"^" + r"\.".join(regex_parts) + r"$"
     return regex

+
 ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
 # Trusted proxy IPs, read from an environment variable (comma-separated)
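The helper's body is only partially visible in this hunk; below is a minimal sketch of the wildcard-to-regex idea it implements, assuming "*" stands for a single octet (a hypothetical simplification, not the repository's exact code):

    import re
    from typing import Optional

    def wildcard_ip_to_regex(pattern: str) -> Optional[str]:
        """Turn a wildcard IPv4 pattern like '192.168.1.*' into an anchored regex."""
        parts = pattern.split(".")
        if len(parts) != 4:
            return None  # not a dotted quad
        # '*' matches a 1-3 digit octet; literal octets are escaped verbatim.
        regex_parts = [r"\d{1,3}" if p == "*" else re.escape(p) for p in parts]
        return r"^" + r"\.".join(regex_parts) + r"$"

    # wildcard_ip_to_regex("192.168.1.*") -> '^192\\.168\\.1\\.\\d{1,3}$'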
@@ -354,7 +356,6 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         return False

-
     def _is_asset_scanner_header(self, request: Request) -> bool:
         """
         Detect asset-mapping tool HTTP headers (check only specific headers; tightened matching)

@@ -499,7 +500,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         empty_ips = []

         # Find the least-recently-seen IP (oldest timestamp)
         oldest_ip = None
-        oldest_time = float('inf')
+        oldest_time = float("inf")
         # Full scan to find the true oldest; performance is acceptable when over the limit
         for ip, times in self.request_times.items():
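The hunk cuts off at the loop header; here is a sketch of what such an oldest-first eviction scan typically looks like, assuming request_times maps an IP to a chronologically ordered list of timestamps (a hypothetical reconstruction, not the repository's exact code):

    def evict_oldest(request_times: dict[str, list[float]]) -> None:
        """Drop IPs with no recorded requests, then evict the least-recently-seen IP."""
        for ip in [ip for ip, times in request_times.items() if not times]:
            del request_times[ip]
        oldest_ip, oldest_time = None, float("inf")
        for ip, times in request_times.items():
            # Timestamps are appended in order, so times[0] is that IP's oldest.
            if times[0] < oldest_time:
                oldest_time = times[0]
                oldest_ip = ip
        if oldest_ip is not None:
            del request_times[oldest_ip]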
@@ -689,12 +690,27 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
         # Allow access to static assets (CSS, JS, images, etc.)
         # Note: .json has been removed so API paths cannot bypass the protection
         # Static assets are only allowed under specific prefixes (/static/, /assets/, /dist/)
-        static_extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"}
+        static_extensions = {
+            ".css",
+            ".js",
+            ".png",
+            ".jpg",
+            ".jpeg",
+            ".gif",
+            ".svg",
+            ".ico",
+            ".woff",
+            ".woff2",
+            ".ttf",
+            ".eot",
+        }
         static_prefixes = {"/static/", "/assets/", "/dist/"}

         # Check whether this is a static-asset path (a static file under an allowed prefix)
         path = request.url.path
-        is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(path.endswith(ext) for ext in static_extensions)
+        is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
+            path.endswith(ext) for ext in static_extensions
+        )

         # Also allow root-level static files (e.g. /favicon.ico)
         is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
@@ -729,9 +745,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):

         # Detect crawler User-Agents
         if self.check_user_agent and self._is_crawler_user_agent(user_agent):
-            logger.warning(
-                f"🚫 Crawler request detected - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
-            )
+            logger.warning(f"🚫 Crawler request detected - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
             # Decide whether to block based on configuration
             if self.block_on_detect:
                 return PlainTextResponse(
@@ -741,9 +755,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):

         # Check the request rate limit
         if self.check_rate_limit and self._check_rate_limit(client_ip):
-            logger.warning(
-                f"🚫 Request rate too high - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}"
-            )
+            logger.warning(f"🚫 Request rate too high - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
             return PlainTextResponse(
                 "Too Many Requests: Rate limit exceeded",
                 status_code=429,
@@ -770,4 +782,3 @@ Disallow: /
         media_type="text/plain",
         headers={"Cache-Control": "public, max-age=86400"},  # cache for 24 hours
     )
-
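Almost every hunk in this commit follows two Ruff/Black formatting rules worth knowing when reading the rest of the diff: a call or collection that fits within the line-length limit is collapsed onto one line, and one that does not is exploded one element per line with a trailing comma (the "magic trailing comma" then keeps it exploded on later runs). A small illustration, assuming Ruff's default 88-column limit:

    # Before: over the line-length limit.
    extensions = {".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot"}

    # After `ruff format`: one element per line, with a trailing comma added so the
    # magic trailing comma keeps it exploded on future runs.
    extensions = {
        ".css",
        ".js",
        # ...remaining elements, one per line...
        ".eot",
    }

The same rule explains the reverse direction: once a wrapped call fits under the limit and has no trailing comma, the formatter collapses it back to a single line.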

View File

@@ -31,6 +31,7 @@ def require_auth(
     """Auth dependency: verify that the user is logged in"""
     return verify_auth_token_from_cookie_or_header(maibot_session, authorization)

+
 # Virtual group ID for WebUI chat
 WEBUI_CHAT_GROUP_ID = "webui_local_chat"
 WEBUI_CHAT_PLATFORM = "webui"

View File

@@ -346,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):


 @router.post("/model/section/{section_name}")
-async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
+async def update_model_config_section(
+    section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
+):
     """Update the given section of the model config (preserving comments and formatting)"""
     try:
         # Read the existing config
@@ -383,8 +385,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
         provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
         models = config_data.get("models", [])
         orphaned_models = [
-            m.get("name") for m in models
-            if isinstance(m, dict) and m.get("api_provider") not in provider_names
+            m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
         ]
         if orphaned_models:
             error_msg = f"The following models reference a deleted provider: {', '.join(orphaned_models)}. Delete these models on the model management page first, or reassign their provider."

View File

@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field
 from typing import Optional, List, Dict, Any, get_origin
 from pathlib import Path
 import json
-import re
 from src.common.logger import get_logger
 from src.common.toml_utils import save_toml_with_format
 from src.config.config import MMC_VERSION
@@ -556,10 +555,7 @@ async def fetch_raw_file(
     if not token or not token_manager.verify_token(token):
         raise HTTPException(status_code=401, detail="Unauthorized: invalid access token")

-    logger.info(
-        f"Received raw-file fetch request: "
-        f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
-    )
+    logger.info(f"Received raw-file fetch request: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")

     # Send initial loading progress
     await update_progress(

View File

@@ -47,10 +47,7 @@ class RateLimiter:
         """Clean up expired request records"""
         now = time.time()
         cutoff = now - window_seconds
-        self._requests[key] = [
-            (ts, count) for ts, count in self._requests[key]
-            if ts > cutoff
-        ]
+        self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]

     def _cleanup_expired_blocks(self):
         """Clean up expired blocks"""
@@ -77,11 +74,7 @@ class RateLimiter:
         return False, None

     def check_rate_limit(
-        self,
-        request: Request,
-        max_requests: int,
-        window_seconds: int,
-        key_suffix: str = ""
+        self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
     ) -> Tuple[bool, int]:
         """
         Check whether the request exceeds the rate limit
@@ -127,11 +120,7 @@ class RateLimiter:
         logger.warning(f"🔒 IP {ip} has been blocked for {duration_seconds} seconds")

     def record_failed_attempt(
-        self,
-        request: Request,
-        max_failures: int = 5,
-        window_seconds: int = 300,
-        block_duration: int = 600
+        self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
     ) -> Tuple[bool, int]:
         """
         Record a failed attempt (e.g. a failed login)
@@ -212,7 +201,7 @@ async def check_auth_rate_limit(request: Request):
         raise HTTPException(
             status_code=429,
             detail=f"Too many requests; retry in {remaining_block} seconds",
-            headers={"Retry-After": str(remaining_block)}
+            headers={"Retry-After": str(remaining_block)},
         )

     # Check the rate limit
@@ -220,15 +209,11 @@ async def check_auth_rate_limit(request: Request):
         request,
         max_requests=10,  # 10 per minute
         window_seconds=60,
-        key_suffix="auth"
+        key_suffix="auth",
     )

     if not allowed:
-        raise HTTPException(
-            status_code=429,
-            detail="Too many authentication requests; please try again later",
-            headers={"Retry-After": "60"}
-        )
+        raise HTTPException(status_code=429, detail="Too many authentication requests; please try again later", headers={"Retry-After": "60"})


 async def check_api_rate_limit(request: Request):
@@ -245,7 +230,7 @@ async def check_api_rate_limit(request: Request):
         raise HTTPException(
             status_code=429,
             detail=f"Too many requests; retry in {remaining_block} seconds",
-            headers={"Retry-After": str(remaining_block)}
+            headers={"Retry-After": str(remaining_block)},
         )

     # Check the rate limit
@@ -253,12 +238,8 @@ async def check_api_rate_limit(request: Request):
         request,
         max_requests=100,  # 100 per minute
         window_seconds=60,
-        key_suffix="api"
+        key_suffix="api",
     )

     if not allowed:
-        raise HTTPException(
-            status_code=429,
-            detail="Too many requests; please try again later",
-            headers={"Retry-After": "60"}
-        )
+        raise HTTPException(status_code=429, detail="Too many requests; please try again later", headers={"Retry-After": "60"})
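The RateLimiter itself appears only in fragments across these hunks; here is a self-contained sketch of the sliding-window technique they reformat, simplified to one timestamp per request (a hypothetical illustration; the names are not the repository's):

    import time

    class SlidingWindowLimiter:
        def __init__(self) -> None:
            self._requests: dict[str, list[float]] = {}

        def allow(self, key: str, max_requests: int, window_seconds: int) -> bool:
            now = time.time()
            cutoff = now - window_seconds
            # Keep only timestamps still inside the window, then count them.
            kept = [ts for ts in self._requests.get(key, []) if ts > cutoff]
            allowed = len(kept) < max_requests
            if allowed:
                kept.append(now)
            self._requests[key] = kept
            return allowed

    limiter = SlidingWindowLimiter()
    assert limiter.allow("203.0.113.7", max_requests=10, window_seconds=60)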

View File

@@ -146,16 +146,13 @@ async def verify_token(
         # Record the failed attempt
         blocked, remaining = rate_limiter.record_failed_attempt(
             request,
             max_failures=5,  # 5 failures
             window_seconds=300,  # 5-minute window
-            block_duration=600  # block for 10 minutes
+            block_duration=600,  # block for 10 minutes
         )

         if blocked:
-            raise HTTPException(
-                status_code=429,
-                detail="Too many failed authentication attempts; your IP has been temporarily blocked for 10 minutes"
-            )
+            raise HTTPException(status_code=429, detail="Too many failed authentication attempts; your IP has been temporarily blocked for 10 minutes")

         message = "Token invalid or expired"
         if remaining <= 2:

View File

@@ -115,7 +115,7 @@ class WebUIServer:

             media_type = mimetypes.guess_type(str(file_path))[0]
             response = FileResponse(file_path, media_type=media_type)
             # Add an anti-indexing header for HTML files
-            if str(file_path).endswith('.html'):
+            if str(file_path).endswith(".html"):
                 response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
             return response
@@ -136,17 +136,9 @@ class WebUIServer:
             # Note: middleware executes in reverse registration order, so middleware registered first runs last
             # We register it before CORS so the anti-crawler check runs before CORS
-            self.app.add_middleware(
-                AntiCrawlerMiddleware,
-                mode=anti_crawler_mode
-            )
+            self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)

-            mode_descriptions = {
-                "false": "disabled",
-                "strict": "strict mode",
-                "loose": "loose mode",
-                "basic": "basic mode"
-            }
+            mode_descriptions = {"false": "disabled", "strict": "strict mode", "loose": "loose mode", "basic": "basic mode"}
             mode_desc = mode_descriptions.get(anti_crawler_mode, "basic mode")
             logger.info(f"🛡️ Anti-crawler middleware configured: {mode_desc}")
         except Exception as e:
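The two ordering comments above describe Starlette's middleware semantics: each add_middleware call wraps the existing stack, so the middleware registered last is outermost and sees the request first. A minimal sketch of that behavior, assuming Starlette's BaseHTTPMiddleware (a hypothetical demo, not the repository's code):

    from starlette.applications import Starlette
    from starlette.middleware.base import BaseHTTPMiddleware

    class Tag(BaseHTTPMiddleware):
        def __init__(self, app, name: str):
            super().__init__(app)
            self.name = name

        async def dispatch(self, request, call_next):
            print(f"-> {self.name}")  # request enters this layer
            response = await call_next(request)
            print(f"<- {self.name}")  # response leaves this layer
            return response

    app = Starlette()
    app.add_middleware(Tag, name="first")   # registered first -> innermost, runs last
    app.add_middleware(Tag, name="second")  # registered last -> outermost, runs first
    # A request prints: -> second, -> first, <- first, <- second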

View File

@@ -4,7 +4,7 @@
 Temporary tokens are valid for 60 seconds and single-use; they work around cookies being unavailable during the WebSocket handshake
 """

-from fastapi import APIRouter, Cookie, Header, HTTPException
+from fastapi import APIRouter, Cookie, Header
 from typing import Optional
 import secrets
 import time
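The module docstring describes a one-time-token scheme; here is a minimal sketch of that idea using only the imports shown above (a hypothetical illustration, not the repository's implementation):

    import secrets
    import time

    _temp_tokens: dict[str, float] = {}  # token -> expiry timestamp

    def issue_temp_token(ttl: float = 60.0) -> str:
        """Mint a random token valid for `ttl` seconds."""
        token = secrets.token_urlsafe(32)
        _temp_tokens[token] = time.time() + ttl
        return token

    def consume_temp_token(token: str) -> bool:
        """Validate and invalidate in one step, so a token can be used only once."""
        expiry = _temp_tokens.pop(token, None)  # pop() enforces single use
        return expiry is not None and time.time() < expiry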