mirror of https://github.com/Mai-with-u/MaiBot.git
添加认证依赖和请求频率限制模块,增强安全性和防止API滥用
parent
071bf96e85
commit
ea420f9f59
|
|
@ -3,6 +3,7 @@ WebUI 认证模块
|
||||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
|
||||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||||
|
|
||||||
|
|
||||||
|
def _is_secure_environment() -> bool:
|
||||||
|
"""
|
||||||
|
检测是否应该启用安全 Cookie(HTTPS)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果应该使用 secure cookie 则返回 True
|
||||||
|
"""
|
||||||
|
# 检查环境变量
|
||||||
|
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否是生产环境
|
||||||
|
env = os.environ.get("WEBUI_MODE", "").lower()
|
||||||
|
if env in ("production", "prod"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_current_token(
|
def get_current_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
maibot_session: Optional[str] = Cookie(None),
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
|
@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
token: 要设置的 token
|
token: 要设置的 token
|
||||||
"""
|
"""
|
||||||
|
# 根据环境决定安全设置
|
||||||
|
is_secure = _is_secure_environment()
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key=COOKIE_NAME,
|
key=COOKIE_NAME,
|
||||||
value=token,
|
value=token,
|
||||||
max_age=COOKIE_MAX_AGE,
|
max_age=COOKIE_MAX_AGE,
|
||||||
httponly=True, # 防止 JS 读取
|
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
|
||||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
secure=is_secure, # 生产环境强制 HTTPS
|
||||||
path="/", # 确保 Cookie 在所有路径下可用
|
path="/", # 确保 Cookie 在所有路径下可用
|
||||||
)
|
)
|
||||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
|
||||||
|
|
||||||
|
|
||||||
def clear_auth_cookie(response: Response) -> None:
|
def clear_auth_cookie(response: Response) -> None:
|
||||||
|
|
@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None:
|
||||||
Args:
|
Args:
|
||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
"""
|
"""
|
||||||
|
# 保持与 set_auth_cookie 相同的安全设置
|
||||||
|
is_secure = _is_secure_environment()
|
||||||
|
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
key=COOKIE_NAME,
|
key=COOKIE_NAME,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite="lax",
|
samesite="strict" if is_secure else "lax",
|
||||||
|
secure=is_secure,
|
||||||
path="/",
|
path="/",
|
||||||
)
|
)
|
||||||
logger.debug("已清除认证 Cookie")
|
logger.debug("已清除认证 Cookie")
|
||||||
|
|
|
||||||
|
|
@ -8,18 +8,28 @@
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Messages, PersonInfo
|
from src.common.database.database_model import Messages, PersonInfo
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
logger = get_logger("webui.chat")
|
logger = get_logger("webui.chat")
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
# WebUI 聊天的虚拟群组 ID
|
# WebUI 聊天的虚拟群组 ID
|
||||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||||
WEBUI_CHAT_PLATFORM = "webui"
|
WEBUI_CHAT_PLATFORM = "webui"
|
||||||
|
|
@ -256,6 +266,7 @@ async def get_chat_history(
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取聊天历史记录
|
"""获取聊天历史记录
|
||||||
|
|
||||||
|
|
@ -272,7 +283,7 @@ async def get_chat_history(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/platforms")
|
@router.get("/platforms")
|
||||||
async def get_available_platforms():
|
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||||
"""获取可用平台列表
|
"""获取可用平台列表
|
||||||
|
|
||||||
从 PersonInfo 表中获取所有已知的平台
|
从 PersonInfo 表中获取所有已知的平台
|
||||||
|
|
@ -303,6 +314,7 @@ async def get_persons_by_platform(
|
||||||
platform: str = Query(..., description="平台名称"),
|
platform: str = Query(..., description="平台名称"),
|
||||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取指定平台的用户列表
|
"""获取指定平台的用户列表
|
||||||
|
|
||||||
|
|
@ -350,7 +362,7 @@ async def get_persons_by_platform(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/history")
|
@router.delete("/history")
|
||||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||||
"""清空聊天历史记录
|
"""清空聊天历史记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -372,6 +384,7 @@ async def websocket_chat(
|
||||||
person_id: Optional[str] = Query(default=None),
|
person_id: Optional[str] = Query(default=None),
|
||||||
group_name: Optional[str] = Query(default=None),
|
group_name: Optional[str] = Query(default=None),
|
||||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||||
|
token: Optional[str] = Query(default=None), # 认证 token
|
||||||
):
|
):
|
||||||
"""WebSocket 聊天端点
|
"""WebSocket 聊天端点
|
||||||
|
|
||||||
|
|
@ -382,9 +395,28 @@ async def websocket_chat(
|
||||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||||
group_name: 虚拟身份模式的群名(可选)
|
group_name: 虚拟身份模式的群名(可选)
|
||||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||||
|
token: 认证 token(可选,也可从 Cookie 获取)
|
||||||
|
|
||||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||||
"""
|
"""
|
||||||
|
# 认证检查
|
||||||
|
auth_token = token
|
||||||
|
if not auth_token:
|
||||||
|
# 尝试从 Cookie 获取 token
|
||||||
|
auth_token = websocket.cookies.get("maibot_session")
|
||||||
|
|
||||||
|
if not auth_token:
|
||||||
|
logger.warning("WebSocket 聊天连接被拒绝:未提供认证 token")
|
||||||
|
await websocket.close(code=4001, reason="未提供认证信息")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证 token
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(auth_token):
|
||||||
|
logger.warning("WebSocket 聊天连接被拒绝:token 无效")
|
||||||
|
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||||
|
return
|
||||||
|
|
||||||
# 生成会话 ID(每次连接都是新的)
|
# 生成会话 ID(每次连接都是新的)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
@ -712,7 +744,7 @@ async def websocket_chat(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
@router.get("/info")
|
||||||
async def get_chat_info():
|
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||||
"""获取聊天室信息"""
|
"""获取聊天室信息"""
|
||||||
return {
|
return {
|
||||||
"bot_name": global_config.bot.nickname,
|
"bot_name": global_config.bot.nickname,
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,11 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, HTTPException, Body
|
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
|
|
@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
|
||||||
router = APIRouter(prefix="/config", tags=["config"])
|
router = APIRouter(prefix="/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
# ===== 架构获取接口 =====
|
# ===== 架构获取接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/bot")
|
@router.get("/schema/bot")
|
||||||
async def get_bot_config_schema():
|
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置架构"""
|
"""获取麦麦主程序配置架构"""
|
||||||
try:
|
try:
|
||||||
# Config 类包含所有子配置
|
# Config 类包含所有子配置
|
||||||
|
|
@ -65,7 +74,7 @@ async def get_bot_config_schema():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/model")
|
@router.get("/schema/model")
|
||||||
async def get_model_config_schema():
|
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||||
try:
|
try:
|
||||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||||
|
|
@ -79,7 +88,7 @@ async def get_model_config_schema():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/section/{section_name}")
|
@router.get("/schema/section/{section_name}")
|
||||||
async def get_config_section_schema(section_name: str):
|
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取指定配置节的架构
|
获取指定配置节的架构
|
||||||
|
|
||||||
|
|
@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/bot")
|
@router.get("/bot")
|
||||||
async def get_bot_config():
|
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置"""
|
"""获取麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
|
@ -168,7 +177,7 @@ async def get_bot_config():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model")
|
@router.get("/model")
|
||||||
async def get_model_config():
|
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
|
|
@ -190,7 +199,7 @@ async def get_model_config():
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot")
|
@router.post("/bot")
|
||||||
async def update_bot_config(config_data: ConfigBody):
|
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置"""
|
"""更新麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
|
|
@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model")
|
@router.post("/model")
|
||||||
async def update_model_config(config_data: ConfigBody):
|
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
|
|
@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/section/{section_name}")
|
@router.post("/bot/section/{section_name}")
|
||||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
|
|
@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||||
|
|
||||||
|
|
||||||
@router.get("/bot/raw")
|
@router.get("/bot/raw")
|
||||||
async def get_bot_config_raw():
|
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
|
@ -307,7 +316,7 @@ async def get_bot_config_raw():
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/raw")
|
@router.post("/bot/raw")
|
||||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||||
try:
|
try:
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
|
|
@ -337,7 +346,7 @@ async def update_bot_config_raw(raw_content: RawContentBody):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/section/{section_name}")
|
@router.post("/model/section/{section_name}")
|
||||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
async def update_model_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新模型配置的指定节(保留注释和格式)"""
|
"""更新模型配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
|
|
@ -430,7 +439,7 @@ def _to_relative_path(path: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config/path")
|
@router.get("/adapter-config/path")
|
||||||
async def get_adapter_config_path():
|
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||||
"""获取保存的适配器配置文件路径"""
|
"""获取保存的适配器配置文件路径"""
|
||||||
try:
|
try:
|
||||||
# 从 data/webui.json 读取路径偏好
|
# 从 data/webui.json 读取路径偏好
|
||||||
|
|
@ -469,7 +478,7 @@ async def get_adapter_config_path():
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config/path")
|
@router.post("/adapter-config/path")
|
||||||
async def save_adapter_config_path(data: PathBody):
|
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||||
"""保存适配器配置文件路径偏好"""
|
"""保存适配器配置文件路径偏好"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
|
|
@ -512,7 +521,7 @@ async def save_adapter_config_path(data: PathBody):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config")
|
@router.get("/adapter-config")
|
||||||
async def get_adapter_config(path: str):
|
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||||
"""从指定路径读取适配器配置文件"""
|
"""从指定路径读取适配器配置文件"""
|
||||||
try:
|
try:
|
||||||
if not path:
|
if not path:
|
||||||
|
|
@ -544,7 +553,7 @@ async def get_adapter_config(path: str):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config")
|
@router.post("/adapter-config")
|
||||||
async def save_adapter_config(data: PathBody):
|
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||||
"""保存适配器配置到指定路径"""
|
"""保存适配器配置到指定路径"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,24 @@
|
||||||
"""知识库图谱可视化 API 路由"""
|
"""知识库图谱可视化 API 路由"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
import logging
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeNode(BaseModel):
|
class KnowledgeNode(BaseModel):
|
||||||
"""知识节点"""
|
"""知识节点"""
|
||||||
|
|
||||||
|
|
@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取知识图谱(限制节点数量)
|
"""获取知识图谱(限制节点数量)
|
||||||
|
|
||||||
|
|
@ -199,7 +209,7 @@ async def get_knowledge_graph(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats", response_model=KnowledgeStats)
|
@router.get("/stats", response_model=KnowledgeStats)
|
||||||
async def get_knowledge_stats():
|
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||||
"""获取知识库统计信息
|
"""获取知识库统计信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -248,7 +258,7 @@ async def get_knowledge_stats():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search", response_model=List[KnowledgeNode])
|
@router.get("/search", response_model=List[KnowledgeNode])
|
||||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||||
"""搜索知识节点
|
"""搜索知识节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
"""WebSocket 日志推送模块"""
|
"""WebSocket 日志推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||||
from typing import Set
|
from typing import Set, Optional
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
logger = get_logger("webui.logs_ws")
|
logger = get_logger("webui.logs_ws")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -73,14 +74,32 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/logs")
|
@router.websocket("/ws/logs")
|
||||||
async def websocket_logs(websocket: WebSocket):
|
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||||
"""WebSocket 日志推送端点
|
"""WebSocket 日志推送端点
|
||||||
|
|
||||||
客户端连接后会持续接收服务器端的日志消息
|
客户端连接后会持续接收服务器端的日志消息
|
||||||
|
需要通过 query 参数传递 token 进行认证,例如:ws://host/ws/logs?token=xxx
|
||||||
"""
|
"""
|
||||||
|
# 认证检查
|
||||||
|
if not token:
|
||||||
|
# 尝试从 Cookie 获取 token
|
||||||
|
token = websocket.cookies.get("maibot_session")
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
logger.warning("WebSocket 连接被拒绝:未提供认证 token")
|
||||||
|
await websocket.close(code=4001, reason="未提供认证信息")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证 token
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(token):
|
||||||
|
logger.warning("WebSocket 连接被拒绝:token 无效")
|
||||||
|
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
active_connections.add(websocket)
|
active_connections.add(websocket)
|
||||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||||
|
|
||||||
# 连接建立后,立即发送历史日志
|
# 连接建立后,立即发送历史日志
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,27 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import CONFIG_DIR
|
from src.config.config import CONFIG_DIR
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui")
|
logger = get_logger("webui")
|
||||||
|
|
||||||
router = APIRouter(prefix="/models", tags=["models"])
|
router = APIRouter(prefix="/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
# 模型获取器配置
|
# 模型获取器配置
|
||||||
MODEL_FETCHER_CONFIG = {
|
MODEL_FETCHER_CONFIG = {
|
||||||
# OpenAI 兼容格式的提供商
|
# OpenAI 兼容格式的提供商
|
||||||
|
|
@ -184,6 +193,7 @@ async def get_provider_models(
|
||||||
provider_name: str = Query(..., description="提供商名称"),
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定提供商的可用模型列表
|
获取指定提供商的可用模型列表
|
||||||
|
|
@ -228,6 +238,7 @@ async def get_models_by_url(
|
||||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||||
|
|
@ -251,6 +262,7 @@ async def get_models_by_url(
|
||||||
async def test_provider_connection(
|
async def test_provider_connection(
|
||||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
测试提供商连接状态
|
测试提供商连接状态
|
||||||
|
|
@ -337,6 +349,7 @@ async def test_provider_connection(
|
||||||
@router.post("/test-connection-by-name")
|
@router.post("/test-connection-by-name")
|
||||||
async def test_provider_connection_by_name(
|
async def test_provider_connection_by_name(
|
||||||
provider_name: str = Query(..., description="提供商名称"),
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
通过提供商名称测试连接(从配置文件读取信息)
|
通过提供商名称测试连接(从配置文件读取信息)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
"""WebSocket 插件加载进度推送模块"""
|
"""WebSocket 插件加载进度推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||||
from typing import Set, Dict, Any
|
from typing import Set, Dict, Any, Optional
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
logger = get_logger("webui.plugin_progress")
|
logger = get_logger("webui.plugin_progress")
|
||||||
|
|
||||||
|
|
@ -89,14 +90,33 @@ async def update_progress(
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/plugin-progress")
|
@router.websocket("/ws/plugin-progress")
|
||||||
async def websocket_plugin_progress(websocket: WebSocket):
|
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||||
"""WebSocket 插件加载进度推送端点
|
"""WebSocket 插件加载进度推送端点
|
||||||
|
|
||||||
客户端连接后会立即收到当前进度状态
|
客户端连接后会立即收到当前进度状态
|
||||||
|
需要通过 query 参数或 Cookie 传递 token 进行认证
|
||||||
"""
|
"""
|
||||||
|
# 认证检查
|
||||||
|
auth_token = token
|
||||||
|
if not auth_token:
|
||||||
|
# 尝试从 Cookie 获取 token
|
||||||
|
auth_token = websocket.cookies.get("maibot_session")
|
||||||
|
|
||||||
|
if not auth_token:
|
||||||
|
logger.warning("插件进度 WebSocket 连接被拒绝:未提供认证 token")
|
||||||
|
await websocket.close(code=4001, reason="未提供认证信息")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证 token
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(auth_token):
|
||||||
|
logger.warning("插件进度 WebSocket 连接被拒绝:token 无效")
|
||||||
|
await websocket.close(code=4003, reason="Token 无效或已过期")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
active_connections.add(websocket)
|
active_connections.add(websocket)
|
||||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 发送当前进度状态
|
# 发送当前进度状态
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,264 @@
|
||||||
|
"""
|
||||||
|
WebUI 请求频率限制模块
|
||||||
|
防止暴力破解和 API 滥用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
from fastapi import Request, HTTPException
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("webui.rate_limiter")
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""
|
||||||
|
简单的内存请求频率限制器
|
||||||
|
|
||||||
|
使用滑动窗口算法实现
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 存储格式: {key: [(timestamp, count), ...]}
|
||||||
|
self._requests: Dict[str, list] = defaultdict(list)
|
||||||
|
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||||
|
self._blocked: Dict[str, float] = {}
|
||||||
|
|
||||||
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
|
"""获取客户端 IP 地址"""
|
||||||
|
# 检查代理头
|
||||||
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded:
|
||||||
|
# 取第一个 IP(最原始的客户端)
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
|
||||||
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
|
||||||
|
# 直接连接的客户端
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||||
|
"""清理过期的请求记录"""
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - window_seconds
|
||||||
|
self._requests[key] = [
|
||||||
|
(ts, count) for ts, count in self._requests[key]
|
||||||
|
if ts > cutoff
|
||||||
|
]
|
||||||
|
|
||||||
|
def _cleanup_expired_blocks(self):
|
||||||
|
"""清理过期的封禁"""
|
||||||
|
now = time.time()
|
||||||
|
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||||
|
for ip in expired:
|
||||||
|
del self._blocked[ip]
|
||||||
|
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||||
|
|
||||||
|
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||||
|
"""
|
||||||
|
检查 IP 是否被封禁
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否被封禁, 剩余封禁秒数)
|
||||||
|
"""
|
||||||
|
self._cleanup_expired_blocks()
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
|
||||||
|
if ip in self._blocked:
|
||||||
|
remaining = int(self._blocked[ip] - time.time())
|
||||||
|
return True, max(0, remaining)
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def check_rate_limit(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
max_requests: int,
|
||||||
|
window_seconds: int,
|
||||||
|
key_suffix: str = ""
|
||||||
|
) -> Tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
检查请求是否超过频率限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
max_requests: 窗口期内允许的最大请求数
|
||||||
|
window_seconds: 窗口时间(秒)
|
||||||
|
key_suffix: 键后缀,用于区分不同的限制规则
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否允许, 剩余请求数)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
self._cleanup_old_requests(key, window_seconds)
|
||||||
|
|
||||||
|
# 计算当前窗口内的请求数
|
||||||
|
current_count = sum(count for _, count in self._requests[key])
|
||||||
|
|
||||||
|
if current_count >= max_requests:
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 记录新请求
|
||||||
|
now = time.time()
|
||||||
|
self._requests[key].append((now, 1))
|
||||||
|
|
||||||
|
remaining = max_requests - current_count - 1
|
||||||
|
return True, remaining
|
||||||
|
|
||||||
|
def block_ip(self, request: Request, duration_seconds: int):
|
||||||
|
"""
|
||||||
|
封禁 IP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
duration_seconds: 封禁时长(秒)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
self._blocked[ip] = time.time() + duration_seconds
|
||||||
|
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||||
|
|
||||||
|
def record_failed_attempt(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
max_failures: int = 5,
|
||||||
|
window_seconds: int = 300,
|
||||||
|
block_duration: int = 600
|
||||||
|
) -> Tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
记录失败尝试(如登录失败)
|
||||||
|
|
||||||
|
如果在窗口期内失败次数过多,自动封禁 IP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
max_failures: 允许的最大失败次数
|
||||||
|
window_seconds: 统计窗口(秒)
|
||||||
|
block_duration: 封禁时长(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否被封禁, 剩余尝试次数)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:auth_failures"
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
self._cleanup_old_requests(key, window_seconds)
|
||||||
|
|
||||||
|
# 计算当前失败次数
|
||||||
|
current_failures = sum(count for _, count in self._requests[key])
|
||||||
|
|
||||||
|
# 记录本次失败
|
||||||
|
now = time.time()
|
||||||
|
self._requests[key].append((now, 1))
|
||||||
|
current_failures += 1
|
||||||
|
|
||||||
|
remaining = max_failures - current_failures
|
||||||
|
|
||||||
|
# 检查是否需要封禁
|
||||||
|
if current_failures >= max_failures:
|
||||||
|
self.block_ip(request, block_duration)
|
||||||
|
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||||
|
return True, 0
|
||||||
|
|
||||||
|
if current_failures >= max_failures - 2:
|
||||||
|
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||||
|
|
||||||
|
return False, max(0, remaining)
|
||||||
|
|
||||||
|
def reset_failures(self, request: Request):
|
||||||
|
"""
|
||||||
|
重置失败计数(认证成功后调用)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:auth_failures"
|
||||||
|
if key in self._requests:
|
||||||
|
del self._requests[key]
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_rate_limiter: Optional[RateLimiter] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limiter() -> RateLimiter:
|
||||||
|
"""获取 RateLimiter 单例"""
|
||||||
|
global _rate_limiter
|
||||||
|
if _rate_limiter is None:
|
||||||
|
_rate_limiter = RateLimiter()
|
||||||
|
return _rate_limiter
|
||||||
|
|
||||||
|
|
||||||
|
async def check_auth_rate_limit(request: Request):
|
||||||
|
"""
|
||||||
|
认证接口的频率限制依赖
|
||||||
|
|
||||||
|
规则:
|
||||||
|
- 每个 IP 每分钟最多 10 次认证请求
|
||||||
|
- 连续失败 5 次后封禁 10 分钟
|
||||||
|
"""
|
||||||
|
limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
# 检查是否被封禁
|
||||||
|
blocked, remaining_block = limiter.is_blocked(request)
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||||
|
headers={"Retry-After": str(remaining_block)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查频率限制
|
||||||
|
allowed, remaining = limiter.check_rate_limit(
|
||||||
|
request,
|
||||||
|
max_requests=10, # 每分钟 10 次
|
||||||
|
window_seconds=60,
|
||||||
|
key_suffix="auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="认证请求过于频繁,请稍后重试",
|
||||||
|
headers={"Retry-After": "60"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_api_rate_limit(request: Request):
|
||||||
|
"""
|
||||||
|
普通 API 的频率限制依赖
|
||||||
|
|
||||||
|
规则:每个 IP 每分钟最多 100 次请求
|
||||||
|
"""
|
||||||
|
limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
# 检查是否被封禁
|
||||||
|
blocked, remaining_block = limiter.is_blocked(request)
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||||
|
headers={"Retry-After": str(remaining_block)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查频率限制
|
||||||
|
allowed, _ = limiter.check_rate_limit(
|
||||||
|
request,
|
||||||
|
max_requests=100, # 每分钟 100 次
|
||||||
|
window_seconds=60,
|
||||||
|
key_suffix="api"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="请求过于频繁,请稍后重试",
|
||||||
|
headers={"Retry-After": "60"}
|
||||||
|
)
|
||||||
|
|
@ -7,10 +7,12 @@
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, HTTPException
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
router = APIRouter(prefix="/system", tags=["system"])
|
router = APIRouter(prefix="/system", tags=["system"])
|
||||||
logger = get_logger("webui_system")
|
logger = get_logger("webui_system")
|
||||||
|
|
@ -19,6 +21,14 @@ logger = get_logger("webui_system")
|
||||||
_start_time = time.time()
|
_start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class RestartResponse(BaseModel):
|
class RestartResponse(BaseModel):
|
||||||
"""重启响应"""
|
"""重启响应"""
|
||||||
|
|
||||||
|
|
@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/restart", response_model=RestartResponse)
|
@router.post("/restart", response_model=RestartResponse)
|
||||||
async def restart_maibot():
|
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
重启麦麦主程序
|
重启麦麦主程序
|
||||||
|
|
||||||
|
|
@ -67,7 +77,7 @@ async def restart_maibot():
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status", response_model=StatusResponse)
|
@router.get("/status", response_model=StatusResponse)
|
||||||
async def get_maibot_status():
|
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取麦麦运行状态
|
获取麦麦运行状态
|
||||||
|
|
||||||
|
|
@ -90,7 +100,7 @@ async def get_maibot_status():
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reload-config")
|
@router.post("/reload-config")
|
||||||
async def reload_config():
|
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
热重载配置(不重启进程)
|
热重载配置(不重启进程)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
"""WebUI API 路由"""
|
"""WebUI API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
from .auth import set_auth_cookie, clear_auth_cookie
|
from .auth import set_auth_cookie, clear_auth_cookie
|
||||||
|
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||||
from .config_routes import router as config_router
|
from .config_routes import router as config_router
|
||||||
from .statistics_routes import router as statistics_router
|
from .statistics_routes import router as statistics_router
|
||||||
from .person_routes import router as person_router
|
from .person_routes import router as person_router
|
||||||
|
|
@ -107,12 +108,18 @@ async def health_check():
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
async def verify_token(
|
||||||
|
request_body: TokenVerifyRequest,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 包含 token 的验证请求
|
request_body: 包含 token 的验证请求
|
||||||
|
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -120,16 +127,40 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
token_manager = get_token_manager()
|
token_manager = get_token_manager()
|
||||||
is_valid = token_manager.verify_token(request.token)
|
rate_limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
is_valid = token_manager.verify_token(request_body.token)
|
||||||
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
|
# 认证成功,重置失败计数
|
||||||
|
rate_limiter.reset_failures(request)
|
||||||
# 设置 HttpOnly Cookie
|
# 设置 HttpOnly Cookie
|
||||||
set_auth_cookie(response, request.token)
|
set_auth_cookie(response, request_body.token)
|
||||||
# 同时返回首次配置状态,避免额外请求
|
# 同时返回首次配置状态,避免额外请求
|
||||||
is_first_setup = token_manager.is_first_setup()
|
is_first_setup = token_manager.is_first_setup()
|
||||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||||
else:
|
else:
|
||||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
# 记录失败尝试
|
||||||
|
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||||
|
request,
|
||||||
|
max_failures=5, # 5 次失败
|
||||||
|
window_seconds=300, # 5 分钟窗口
|
||||||
|
block_duration=600 # 封禁 10 分钟
|
||||||
|
)
|
||||||
|
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟"
|
||||||
|
)
|
||||||
|
|
||||||
|
message = "Token 无效或已过期"
|
||||||
|
if remaining <= 2:
|
||||||
|
message += f"(剩余 {remaining} 次尝试机会)"
|
||||||
|
|
||||||
|
return TokenVerifyResponse(valid=False, message=message)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token 验证失败: {e}")
|
logger.error(f"Token 验证失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,28 @@
|
||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui.statistics")
|
logger = get_logger("webui.statistics")
|
||||||
|
|
||||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class StatisticsSummary(BaseModel):
|
class StatisticsSummary(BaseModel):
|
||||||
"""统计数据摘要"""
|
"""统计数据摘要"""
|
||||||
|
|
||||||
|
|
@ -58,7 +67,7 @@ class DashboardData(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard", response_model=DashboardData)
|
@router.get("/dashboard", response_model=DashboardData)
|
||||||
async def get_dashboard_data(hours: int = 24):
|
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取仪表盘统计数据
|
获取仪表盘统计数据
|
||||||
|
|
||||||
|
|
@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
async def get_summary(hours: int = 24):
|
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取统计摘要
|
获取统计摘要
|
||||||
|
|
||||||
|
|
@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/models")
|
@router.get("/models")
|
||||||
async def get_model_stats(hours: int = 24):
|
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取模型统计
|
获取模型统计
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,15 @@ class WebUIServer:
|
||||||
"http://127.0.0.1:8001",
|
"http://127.0.0.1:8001",
|
||||||
],
|
],
|
||||||
allow_credentials=True, # 允许携带 Cookie
|
allow_credentials=True, # 允许携带 Cookie
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||||
allow_headers=["*"],
|
allow_headers=[
|
||||||
|
"Content-Type",
|
||||||
|
"Authorization",
|
||||||
|
"Accept",
|
||||||
|
"Origin",
|
||||||
|
"X-Requested-With",
|
||||||
|
], # 明确指定允许的头
|
||||||
|
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||||
)
|
)
|
||||||
logger.debug("✅ CORS 中间件已配置")
|
logger.debug("✅ CORS 中间件已配置")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue