添加认证依赖和请求频率限制模块,增强安全性和防止API滥用

pull/1438/head
墨梓柒 2025-12-14 19:39:56 +08:00
parent 071bf96e85
commit ea420f9f59
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
12 changed files with 509 additions and 55 deletions

View File

@ -3,6 +3,7 @@ WebUI 认证模块
提供统一的认证依赖支持 Cookie Header 两种方式
"""
import os
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 检查环境变量
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
return False
# 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"):
return True
# 默认:开发环境不启用(因为通常是 HTTP
return False
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None:
response: FastAPI Response 对象
token: 要设置的 token
"""
# 根据环境决定安全设置
is_secure = _is_secure_environment()
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取
samesite="lax", # 允许同站导航时发送 Cookie兼容开发环境代理
secure=False, # 本地开发不强制 HTTPS生产环境建议设为 True
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
secure=is_secure, # 生产环境强制 HTTPS
path="/", # 确保 Cookie 在所有路径下可用
)
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
def clear_auth_cookie(response: Response) -> None:
@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None:
Args:
response: FastAPI Response 对象
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,
samesite="lax",
samesite="strict" if is_secure else "lax",
secure=is_secure,
path="/",
)
logger.debug("已清除认证 Cookie")

View File

@ -8,18 +8,28 @@
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
@ -256,6 +266,7 @@ async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
@ -272,7 +283,7 @@ async def get_chat_history(
@router.get("/platforms")
async def get_available_platforms():
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
PersonInfo 表中获取所有已知的平台
@ -303,6 +314,7 @@ async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
@ -350,7 +362,7 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
@ -372,6 +384,7 @@ async def websocket_chat(
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
@ -382,9 +395,28 @@ async def websocket_chat(
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名可选
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
"""
# 认证检查
auth_token = token
if not auth_token:
# 尝试从 Cookie 获取 token
auth_token = websocket.cookies.get("maibot_session")
if not auth_token:
logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(auth_token):
logger.warning("WebSocket 聊天连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
@ -712,7 +744,7 @@ async def websocket_chat(
@router.get("/info")
async def get_chat_info():
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,

View File

@ -4,10 +4,11 @@
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
@ -65,7 +74,7 @@ async def get_bot_config_schema():
@router.get("/schema/model")
async def get_model_config_schema():
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
@ -79,7 +88,7 @@ async def get_model_config_schema():
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
@router.get("/bot")
async def get_bot_config():
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -168,7 +177,7 @@ async def get_bot_config():
@router.get("/model")
async def get_model_config():
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@ -190,7 +199,7 @@ async def get_model_config():
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody):
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
@router.post("/model")
async def update_model_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody):
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
@router.get("/bot/raw")
async def get_bot_config_raw():
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -307,7 +316,7 @@ async def get_bot_config_raw():
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody):
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
@ -337,7 +346,7 @@ async def update_bot_config_raw(raw_content: RawContentBody):
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody):
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@ -430,7 +439,7 @@ def _to_relative_path(path: str) -> str:
@router.get("/adapter-config/path")
async def get_adapter_config_path():
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
@ -469,7 +478,7 @@ async def get_adapter_config_path():
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody):
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@ -512,7 +521,7 @@ async def save_adapter_config_path(data: PathBody):
@router.get("/adapter-config")
async def get_adapter_config(path: str):
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
@ -544,7 +553,7 @@ async def get_adapter_config(path: str):
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody):
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")

View File

@ -1,15 +1,24 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
@ -199,7 +209,7 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
@ -248,7 +258,7 @@ async def get_knowledge_stats():
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:

View File

@ -1,10 +1,11 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@ -73,14 +74,32 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
需要通过 query 参数传递 token 进行认证例如ws://host/ws/logs?token=xxx
"""
# 认证检查
if not token:
# 尝试从 Cookie 获取 token
token = websocket.cookies.get("maibot_session")
if not token:
logger.warning("WebSocket 连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
logger.warning("WebSocket 连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:

View File

@ -6,18 +6,27 @@
import os
import httpx
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
@ -184,6 +193,7 @@ async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
@ -228,6 +238,7 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表用于自定义提供商
@ -251,6 +262,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
@ -337,6 +349,7 @@ async def test_provider_connection(
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接从配置文件读取信息

View File

@ -1,10 +1,11 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.plugin_progress")
@ -89,14 +90,33 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
需要通过 query 参数或 Cookie 传递 token 进行认证
"""
# 认证检查
auth_token = token
if not auth_token:
# 尝试从 Cookie 获取 token
auth_token = websocket.cookies.get("maibot_session")
if not auth_token:
logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token")
await websocket.close(code=4001, reason="未提供认证信息")
return
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(auth_token):
logger.warning("插件进度 WebSocket 连接被拒绝token 无效")
await websocket.close(code=4003, reason="Token 无效或已过期")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态

View File

@ -0,0 +1,264 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [
(ts, count) for ts, count in self._requests[key]
if ts > cutoff
]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self,
request: Request,
max_requests: int,
window_seconds: int,
key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间
key_suffix: 键后缀用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长
"""
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self,
request: Request,
max_failures: int = 5,
window_seconds: int = 300,
block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试如登录失败
如果在窗口期内失败次数过多自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口
block_duration: 封禁时长
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数认证成功后调用
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
if key in self._requests:
del self._requests[key]
# 全局单例
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""获取 RateLimiter 单例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth"
)
if not allowed:
raise HTTPException(
status_code=429,
detail="认证请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)}
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api"
)
if not allowed:
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后重试",
headers={"Retry-After": "60"}
)

View File

@ -7,10 +7,12 @@
import os
import time
from datetime import datetime
from fastapi import APIRouter, HTTPException
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
@ -19,6 +21,14 @@ logger = get_logger("webui_system")
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot():
async def restart_maibot(_auth: bool = Depends(require_auth)):
"""
重启麦麦主程序
@ -67,7 +77,7 @@ async def restart_maibot():
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
async def get_maibot_status(_auth: bool = Depends(require_auth)):
"""
获取麦麦运行状态
@ -90,7 +100,7 @@ async def get_maibot_status():
@router.post("/reload-config")
async def reload_config():
async def reload_config(_auth: bool = Depends(require_auth)):
"""
热重载配置不重启进程

View File

@ -1,11 +1,12 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
@ -107,12 +108,18 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(request: TokenVerifyRequest, response: Response):
async def verify_token(
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
):
"""
验证访问令牌验证成功后设置 HttpOnly Cookie
Args:
request: 包含 token 的验证请求
request_body: 包含 token 的验证请求
request: FastAPI Request 对象用于获取客户端 IP
response: FastAPI Response 对象
Returns:
@ -120,16 +127,40 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
"""
try:
token_manager = get_token_manager()
is_valid = token_manager.verify_token(request.token)
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie
set_auth_cookie(response, request.token)
set_auth_cookie(response, request_body.token)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
else:
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600 # 封禁 10 分钟
)
if blocked:
raise HTTPException(
status_code=429,
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
)
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e

View File

@ -1,19 +1,28 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
@ -58,7 +67,7 @@ class DashboardData(BaseModel):
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24):
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
@router.get("/summary")
async def get_summary(hours: int = 24):
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
@router.get("/models")
async def get_model_stats(hours: int = 24):
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计

View File

@ -44,8 +44,15 @@ class WebUIServer:
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")