diff --git a/src/webui/auth.py b/src/webui/auth.py new file mode 100644 index 00000000..8d52a5e3 --- /dev/null +++ b/src/webui/auth.py @@ -0,0 +1,127 @@ +""" +WebUI 认证模块 +提供统一的认证依赖,支持 Cookie 和 Header 两种方式 +""" + +from typing import Optional +from fastapi import HTTPException, Cookie, Header, Response, Request +from src.common.logger import get_logger +from .token_manager import get_token_manager + +logger = get_logger("webui.auth") + +# Cookie 配置 +COOKIE_NAME = "maibot_session" +COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天 + + +def get_current_token( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> str: + """ + 获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取 + + Args: + request: FastAPI Request 对象 + maibot_session: Cookie 中的 token + authorization: Authorization Header (Bearer token) + + Returns: + 验证通过的 token + + Raises: + HTTPException: 认证失败时抛出 401 错误 + """ + token = None + + # 优先从 Cookie 获取 + if maibot_session: + token = maibot_session + # 其次从 Header 获取(兼容旧版本) + elif authorization and authorization.startswith("Bearer "): + token = authorization.replace("Bearer ", "") + + if not token: + raise HTTPException(status_code=401, detail="未提供有效的认证信息") + + # 验证 token + token_manager = get_token_manager() + if not token_manager.verify_token(token): + raise HTTPException(status_code=401, detail="Token 无效或已过期") + + return token + + +def set_auth_cookie(response: Response, token: str) -> None: + """ + 设置认证 Cookie + + Args: + response: FastAPI Response 对象 + token: 要设置的 token + """ + response.set_cookie( + key=COOKIE_NAME, + value=token, + max_age=COOKIE_MAX_AGE, + httponly=True, # 防止 JS 读取 + samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理) + secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True + path="/", # 确保 Cookie 在所有路径下可用 + ) + logger.debug(f"已设置认证 Cookie: {token[:8]}...") + + +def clear_auth_cookie(response: Response) -> None: + """ + 清除认证 Cookie + + Args: + response: FastAPI Response 对象 + """ + response.delete_cookie( + key=COOKIE_NAME, + httponly=True, + samesite="lax", + path="/", + ) + logger.debug("已清除认证 Cookie") + + +def verify_auth_token_from_cookie_or_header( + maibot_session: Optional[str] = None, + authorization: Optional[str] = None, +) -> bool: + """ + 验证认证 Token,支持从 Cookie 或 Header 获取 + + Args: + maibot_session: Cookie 中的 token + authorization: Authorization header (Bearer token) + + Returns: + 验证成功返回 True + + Raises: + HTTPException: 认证失败时抛出 401 错误 + """ + token = None + + # 优先从 Cookie 获取 + if maibot_session: + token = maibot_session + # 其次从 Header 获取(兼容旧版本) + elif authorization and authorization.startswith("Bearer "): + token = authorization.replace("Bearer ", "") + + if not token: + raise HTTPException(status_code=401, detail="未提供有效的认证信息") + + # 验证 token + token_manager = get_token_manager() + if not token_manager.verify_token(token): + raise HTTPException(status_code=401, detail="Token 无效或已过期") + + return True diff --git a/src/webui/emoji_routes.py b/src/webui/emoji_routes.py index 94f77b95..c4d90ea2 100644 --- a/src/webui/emoji_routes.py +++ b/src/webui/emoji_routes.py @@ -1,12 +1,13 @@ """表情包管理 API 路由""" -from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form +from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie from fastapi.responses import FileResponse from pydantic import BaseModel from typing import Optional, List, Annotated from src.common.logger import get_logger from src.common.database.database_model import Emoji from .token_manager import get_token_manager +from .auth import verify_auth_token_from_cookie_or_header import time import os import hashlib @@ -101,18 +102,12 @@ class BatchDeleteResponse(BaseModel): failed_ids: List[int] = [] -def verify_auth_token(authorization: Optional[str]) -> bool: - """验证认证 Token""" - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - - token = authorization.replace("Bearer ", "") - token_manager = get_token_manager() - - if not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="Token 无效或已过期") - - return True +def verify_auth_token( + maibot_session: Optional[str] = None, + authorization: Optional[str] = None, +) -> bool: + """验证认证 Token,支持 Cookie 和 Header""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) def emoji_to_response(emoji: Emoji) -> EmojiResponse: @@ -144,6 +139,7 @@ async def get_emoji_list( format: Optional[str] = Query(None, description="格式筛选"), sort_by: Optional[str] = Query("usage_count", description="排序字段"), sort_order: Optional[str] = Query("desc", description="排序方向"), + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -164,7 +160,7 @@ async def get_emoji_list( 表情包列表 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) # 构建查询 query = Emoji.select() @@ -222,7 +218,7 @@ async def get_emoji_list( @router.get("/{emoji_id}", response_model=EmojiDetailResponse) -async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)): +async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取表情包详细信息 @@ -234,7 +230,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header( 表情包详细信息 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -251,7 +247,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header( @router.patch("/{emoji_id}", response_model=EmojiUpdateResponse) -async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)): +async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 增量更新表情包(只更新提供的字段) @@ -264,7 +260,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization 更新结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -303,7 +299,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) -async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): +async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 删除表情包 @@ -315,7 +311,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None 删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -340,7 +336,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None @router.get("/stats/summary") -async def get_emoji_stats(authorization: Optional[str] = Header(None)): +async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取表情包统计数据 @@ -351,7 +347,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)): 统计数据 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) total = Emoji.select().count() registered = Emoji.select().where(Emoji.is_registered).count() @@ -395,7 +391,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)): @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) -async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): +async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 注册表情包(快捷操作) @@ -407,7 +403,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No 更新结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -435,7 +431,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) -async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): +async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 禁用表情包(快捷操作) @@ -447,7 +443,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): 更新结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -474,6 +470,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): async def get_emoji_thumbnail( emoji_id: int, token: Optional[str] = Query(None, description="访问令牌"), + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -481,21 +478,31 @@ async def get_emoji_thumbnail( Args: emoji_id: 表情包ID - token: 访问令牌(通过 query parameter) + token: 访问令牌(通过 query parameter,用于向后兼容) + maibot_session: Cookie 中的 token authorization: Authorization header Returns: 表情包图片文件 """ try: - # 优先使用 query parameter 中的 token(用于 img 标签) - if token: - token_manager = get_token_manager() - if not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="Token 无效或已过期") - else: - # 如果没有 query token,则验证 Authorization header - verify_auth_token(authorization) + token_manager = get_token_manager() + is_valid = False + + # 1. 优先使用 Cookie + if maibot_session and token_manager.verify_token(maibot_session): + is_valid = True + # 2. 其次使用 query parameter(用于向后兼容 img 标签) + elif token and token_manager.verify_token(token): + is_valid = True + # 3. 最后使用 Authorization header + elif authorization and authorization.startswith("Bearer "): + auth_token = authorization.replace("Bearer ", "") + if token_manager.verify_token(auth_token): + is_valid = True + + if not is_valid: + raise HTTPException(status_code=401, detail="Token 无效或已过期") emoji = Emoji.get_or_none(Emoji.id == emoji_id) @@ -528,7 +535,7 @@ async def get_emoji_thumbnail( @router.post("/batch/delete", response_model=BatchDeleteResponse) -async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)): +async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 批量删除表情包 @@ -540,7 +547,7 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option 批量删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) if not request.emoji_ids: raise HTTPException(status_code=400, detail="未提供要删除的表情包ID") @@ -601,6 +608,7 @@ async def upload_emoji( description: DescriptionForm = "", emotion: EmotionForm = "", is_registered: IsRegisteredForm = True, + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -617,7 +625,7 @@ async def upload_emoji( 上传结果和表情包信息 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) # 验证文件类型 if not file.content_type: @@ -721,6 +729,7 @@ async def batch_upload_emoji( files: EmojiFiles, emotion: EmotionForm = "", is_registered: IsRegisteredForm = True, + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -736,7 +745,7 @@ async def batch_upload_emoji( 批量上传结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) results = { "success": True, diff --git a/src/webui/expression_routes.py b/src/webui/expression_routes.py index 40f87489..f92219ab 100644 --- a/src/webui/expression_routes.py +++ b/src/webui/expression_routes.py @@ -1,11 +1,12 @@ """表达方式管理 API 路由""" -from fastapi import APIRouter, HTTPException, Header, Query +from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict from src.common.logger import get_logger from src.common.database.database_model import Expression, ChatStreams from .token_manager import get_token_manager +from .auth import verify_auth_token_from_cookie_or_header import time logger = get_logger("webui.expression") @@ -87,18 +88,12 @@ class ExpressionCreateResponse(BaseModel): data: ExpressionResponse -def verify_auth_token(authorization: Optional[str]) -> bool: - """验证认证 Token""" - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - - token = authorization.replace("Bearer ", "") - token_manager = get_token_manager() - - if not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="Token 无效或已过期") - - return True +def verify_auth_token( + maibot_session: Optional[str] = None, + authorization: Optional[str] = None, +) -> bool: + """验证认证 Token,支持 Cookie 和 Header""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) def expression_to_response(expression: Expression) -> ExpressionResponse: @@ -162,7 +157,7 @@ class ChatListResponse(BaseModel): @router.get("/chats", response_model=ChatListResponse) -async def get_chat_list(authorization: Optional[str] = Header(None)): +async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取所有聊天列表(用于下拉选择) @@ -173,7 +168,7 @@ async def get_chat_list(authorization: Optional[str] = Header(None)): 聊天列表 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) chat_list = [] for cs in ChatStreams.select(): @@ -205,6 +200,7 @@ async def get_expression_list( page_size: int = Query(20, ge=1, le=100, description="每页数量"), search: Optional[str] = Query(None, description="搜索关键词"), chat_id: Optional[str] = Query(None, description="聊天ID筛选"), + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -221,7 +217,7 @@ async def get_expression_list( 表达方式列表 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) # 构建查询 query = Expression.select() @@ -265,7 +261,7 @@ async def get_expression_list( @router.get("/{expression_id}", response_model=ExpressionDetailResponse) -async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)): +async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取表达方式详细信息 @@ -277,7 +273,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str] 表达方式详细信息 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) expression = Expression.get_or_none(Expression.id == expression_id) @@ -294,7 +290,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str] @router.post("/", response_model=ExpressionCreateResponse) -async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)): +async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 创建新的表达方式 @@ -306,7 +302,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt 创建结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) current_time = time.time() @@ -336,7 +332,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt @router.patch("/{expression_id}", response_model=ExpressionUpdateResponse) async def update_expression( - expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None) + expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ): """ 增量更新表达方式(只更新提供的字段) @@ -350,7 +346,7 @@ async def update_expression( 更新结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) expression = Expression.get_or_none(Expression.id == expression_id) @@ -386,7 +382,7 @@ async def update_expression( @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) -async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)): +async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 删除表达方式 @@ -398,7 +394,7 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = H 删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) expression = Expression.get_or_none(Expression.id == expression_id) @@ -429,7 +425,7 @@ class BatchDeleteRequest(BaseModel): @router.post("/batch/delete", response_model=ExpressionDeleteResponse) -async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)): +async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 批量删除表达方式 @@ -441,7 +437,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O 删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) if not request.ids: raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID") @@ -470,7 +466,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O @router.get("/stats/summary") -async def get_expression_stats(authorization: Optional[str] = Header(None)): +async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取表达方式统计数据 @@ -481,7 +477,7 @@ async def get_expression_stats(authorization: Optional[str] = Header(None)): 统计数据 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) total = Expression.select().count() diff --git a/src/webui/person_routes.py b/src/webui/person_routes.py index 5935a2fa..0b70a3a2 100644 --- a/src/webui/person_routes.py +++ b/src/webui/person_routes.py @@ -1,11 +1,12 @@ """人物信息管理 API 路由""" -from fastapi import APIRouter, HTTPException, Header, Query +from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict from src.common.logger import get_logger from src.common.database.database_model import PersonInfo from .token_manager import get_token_manager +from .auth import verify_auth_token_from_cookie_or_header import json import time @@ -91,18 +92,12 @@ class BatchDeleteResponse(BaseModel): failed_ids: List[str] = [] -def verify_auth_token(authorization: Optional[str]) -> bool: - """验证认证 Token""" - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - - token = authorization.replace("Bearer ", "") - token_manager = get_token_manager() - - if not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="Token 无效或已过期") - - return True +def verify_auth_token( + maibot_session: Optional[str] = None, + authorization: Optional[str] = None, +) -> bool: + """验证认证 Token,支持 Cookie 和 Header""" + return verify_auth_token_from_cookie_or_header(maibot_session, authorization) def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]: @@ -141,6 +136,7 @@ async def get_person_list( search: Optional[str] = Query(None, description="搜索关键词"), is_known: Optional[bool] = Query(None, description="是否已认识筛选"), platform: Optional[str] = Query(None, description="平台筛选"), + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None), ): """ @@ -158,7 +154,7 @@ async def get_person_list( 人物信息列表 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) # 构建查询 query = PersonInfo.select() @@ -205,7 +201,7 @@ async def get_person_list( @router.get("/{person_id}", response_model=PersonDetailResponse) -async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)): +async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取人物详细信息 @@ -217,7 +213,7 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade 人物详细信息 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) @@ -234,7 +230,7 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade @router.patch("/{person_id}", response_model=PersonUpdateResponse) -async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)): +async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 增量更新人物信息(只更新提供的字段) @@ -247,7 +243,7 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat 更新结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) @@ -283,7 +279,7 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat @router.delete("/{person_id}", response_model=PersonDeleteResponse) -async def delete_person(person_id: str, authorization: Optional[str] = Header(None)): +async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 删除人物信息 @@ -295,7 +291,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No 删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) @@ -320,7 +316,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No @router.get("/stats/summary") -async def get_person_stats(authorization: Optional[str] = Header(None)): +async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 获取人物信息统计数据 @@ -331,7 +327,7 @@ async def get_person_stats(authorization: Optional[str] = Header(None)): 统计数据 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) total = PersonInfo.select().count() known = PersonInfo.select().where(PersonInfo.is_known).count() @@ -353,7 +349,7 @@ async def get_person_stats(authorization: Optional[str] = Header(None)): @router.post("/batch/delete", response_model=BatchDeleteResponse) -async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)): +async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): """ 批量删除人物信息 @@ -365,7 +361,7 @@ async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optio 批量删除结果 """ try: - verify_auth_token(authorization) + verify_auth_token(maibot_session, authorization) if not request.person_ids: raise HTTPException(status_code=400, detail="未提供要删除的人物ID") diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index bf4784df..1f2d85da 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Cookie from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any from pathlib import Path @@ -19,6 +19,20 @@ router = APIRouter(prefix="/plugins", tags=["插件管理"]) set_update_progress_callback(update_progress) +def get_token_from_cookie_or_header( + maibot_session: Optional[str] = None, + authorization: Optional[str] = None, +) -> Optional[str]: + """从 Cookie 或 Header 获取 token""" + # 优先从 Cookie 获取 + if maibot_session: + return maibot_session + # 其次从 Header 获取 + if authorization and authorization.startswith("Bearer "): + return authorization.replace("Bearer ", "") + return None + + def parse_version(version_str: str) -> tuple[int, int, int]: """ 解析版本号字符串 @@ -210,12 +224,12 @@ async def check_git_status() -> GitStatusResponse: @router.get("/mirrors", response_model=AvailableMirrorsResponse) -async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: +async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: """ 获取所有可用的镜像源配置 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -240,12 +254,12 @@ async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> @router.post("/mirrors", response_model=MirrorConfigResponse) -async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: +async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: """ 添加新的镜像源 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -280,13 +294,13 @@ async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = H @router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) async def update_mirror( - mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None) + mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ) -> MirrorConfigResponse: """ 更新镜像源配置 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -323,12 +337,12 @@ async def update_mirror( @router.delete("/mirrors/{mirror_id}") -async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 删除镜像源 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -346,7 +360,7 @@ async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(No @router.post("/fetch-raw", response_model=FetchRawFileResponse) async def fetch_raw_file( - request: FetchRawFileRequest, authorization: Optional[str] = Header(None) + request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ) -> FetchRawFileResponse: """ 获取 GitHub 仓库的 Raw 文件内容 @@ -356,7 +370,7 @@ async def fetch_raw_file( 注意:此接口可公开访问,用于获取插件仓库等公开资源 """ # Token 验证(可选,用于日志记录) - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() is_authenticated = token and token_manager.verify_token(token) @@ -431,7 +445,7 @@ async def fetch_raw_file( @router.post("/clone", response_model=CloneRepositoryResponse) async def clone_repository( - request: CloneRepositoryRequest, authorization: Optional[str] = Header(None) + request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ) -> CloneRepositoryResponse: """ 克隆 GitHub 仓库到本地 @@ -439,7 +453,7 @@ async def clone_repository( 支持多镜像源自动切换和错误重试 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -471,14 +485,14 @@ async def clone_repository( @router.post("/install") -async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 安装插件 从 Git 仓库克隆插件到本地插件目录 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -675,7 +689,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[ @router.post("/uninstall") async def uninstall_plugin( - request: UninstallPluginRequest, authorization: Optional[str] = Header(None) + request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ) -> Dict[str, Any]: """ 卸载插件 @@ -683,7 +697,7 @@ async def uninstall_plugin( 删除插件目录及其所有文件 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -810,14 +824,14 @@ async def uninstall_plugin( @router.post("/update") -async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 更新插件 删除旧版本,重新克隆新版本 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1029,14 +1043,14 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st @router.get("/installed") -async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取已安装的插件列表 扫描 plugins 目录,返回所有已安装插件的 ID 和基本信息 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1169,7 +1183,7 @@ class UpdatePluginConfigRequest(BaseModel): @router.get("/config/{plugin_id}/schema") -async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取插件配置 Schema @@ -1177,7 +1191,7 @@ async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] 用于前端动态生成配置表单。 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1302,14 +1316,14 @@ async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] @router.get("/config/{plugin_id}") -async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取插件当前配置值 返回插件的当前配置值。 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1358,7 +1372,7 @@ async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Heade @router.put("/config/{plugin_id}") async def update_plugin_config( - plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None) + plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) ) -> Dict[str, Any]: """ 更新插件配置 @@ -1366,7 +1380,7 @@ async def update_plugin_config( 保存新的配置值到插件的配置文件。 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1431,14 +1445,14 @@ async def update_plugin_config( @router.post("/config/{plugin_id}/reset") -async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 重置插件配置为默认值 删除当前配置文件,下次加载插件时将使用默认配置。 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") @@ -1491,14 +1505,14 @@ async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Hea @router.post("/config/{plugin_id}/toggle") -async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 切换插件启用状态 切换插件配置中的 enabled 字段。 """ # Token 验证 - token = authorization.replace("Bearer ", "") if authorization else None + token = get_token_from_cookie_or_header(maibot_session, authorization) token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") diff --git a/src/webui/routes.py b/src/webui/routes.py index 26eaf553..c3d6fd9e 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -1,10 +1,11 @@ """WebUI API 路由""" -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie from pydantic import BaseModel, Field from typing import Optional from src.common.logger import get_logger from .token_manager import get_token_manager +from .auth import set_auth_cookie, clear_auth_cookie from .config_routes import router as config_router from .statistics_routes import router as statistics_router from .person_routes import router as person_router @@ -51,6 +52,7 @@ class TokenVerifyResponse(BaseModel): valid: bool = Field(..., description="Token 是否有效") message: str = Field(..., description="验证结果消息") + is_first_setup: bool = Field(False, description="是否为首次设置") class TokenUpdateRequest(BaseModel): @@ -102,22 +104,27 @@ async def health_check(): @router.post("/auth/verify", response_model=TokenVerifyResponse) -async def verify_token(request: TokenVerifyRequest): +async def verify_token(request: TokenVerifyRequest, response: Response): """ - 验证访问令牌 + 验证访问令牌,验证成功后设置 HttpOnly Cookie Args: request: 包含 token 的验证请求 + response: FastAPI Response 对象 Returns: - 验证结果 + 验证结果(包含首次配置状态) """ try: token_manager = get_token_manager() is_valid = token_manager.verify_token(request.token) if is_valid: - return TokenVerifyResponse(valid=True, message="Token 验证成功") + # 设置 HttpOnly Cookie + set_auth_cookie(response, request.token) + # 同时返回首次配置状态,避免额外请求 + is_first_setup = token_manager.is_first_setup() + return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup) else: return TokenVerifyResponse(valid=False, message="Token 无效或已过期") except Exception as e: @@ -125,24 +132,86 @@ async def verify_token(request: TokenVerifyRequest): raise HTTPException(status_code=500, detail="Token 验证失败") from e +@router.post("/auth/logout") +async def logout(response: Response): + """ + 登出并清除认证 Cookie + + Args: + response: FastAPI Response 对象 + + Returns: + 登出结果 + """ + clear_auth_cookie(response) + return {"success": True, "message": "已成功登出"} + + +@router.get("/auth/check") +async def check_auth_status( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): + """ + 检查当前认证状态(用于前端判断是否已登录) + + Returns: + 认证状态 + """ + try: + token = None + + # 优先从 Cookie 获取 + if maibot_session: + token = maibot_session + # 其次从 Header 获取 + elif authorization and authorization.startswith("Bearer "): + token = authorization.replace("Bearer ", "") + + if not token: + return {"authenticated": False} + + token_manager = get_token_manager() + if token_manager.verify_token(token): + return {"authenticated": True} + else: + return {"authenticated": False} + except Exception: + return {"authenticated": False} + + @router.post("/auth/update", response_model=TokenUpdateResponse) -async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)): +async def update_token( + request: TokenUpdateRequest, + response: Response, + req: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 更新访问令牌(需要当前有效的 token) Args: request: 包含新 token 的更新请求 + response: FastAPI Response 对象 + maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) Returns: 更新结果 """ try: - # 验证当前 token - if not authorization or not authorization.startswith("Bearer "): + # 验证当前 token(优先 Cookie,其次 Header) + current_token = None + if maibot_session: + current_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + current_token = authorization.replace("Bearer ", "") + + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() if not token_manager.verify_token(current_token): @@ -150,6 +219,10 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str] # 更新 token success, message = token_manager.update_token(request.new_token) + + # 如果更新成功,更新 Cookie + if success: + set_auth_cookie(response, request.new_token) return TokenUpdateResponse(success=success, message=message) except HTTPException: @@ -160,22 +233,34 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str] @router.post("/auth/regenerate", response_model=TokenRegenerateResponse) -async def regenerate_token(authorization: Optional[str] = Header(None)): +async def regenerate_token( + response: Response, + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 重新生成访问令牌(需要当前有效的 token) Args: + response: FastAPI Response 对象 + maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) Returns: 新生成的 token """ try: - # 验证当前 token - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="未提供有效的认证信息") + # 验证当前 token(优先 Cookie,其次 Header) + current_token = None + if maibot_session: + current_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + current_token = authorization.replace("Bearer ", "") - current_token = authorization.replace("Bearer ", "") + if not current_token: + raise HTTPException(status_code=401, detail="未提供有效的认证信息") + token_manager = get_token_manager() if not token_manager.verify_token(current_token): @@ -183,6 +268,9 @@ async def regenerate_token(authorization: Optional[str] = Header(None)): # 重新生成 token new_token = token_manager.regenerate_token() + + # 更新 Cookie + set_auth_cookie(response, new_token) return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成") except HTTPException: @@ -193,22 +281,32 @@ async def regenerate_token(authorization: Optional[str] = Header(None)): @router.get("/setup/status", response_model=FirstSetupStatusResponse) -async def get_setup_status(authorization: Optional[str] = Header(None)): +async def get_setup_status( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 获取首次配置状态 Args: + maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) Returns: 首次配置状态 """ try: - # 验证 token - if not authorization or not authorization.startswith("Bearer "): + # 验证 token(优先 Cookie,其次 Header) + current_token = None + if maibot_session: + current_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + current_token = authorization.replace("Bearer ", "") + + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() if not token_manager.verify_token(current_token): @@ -226,22 +324,32 @@ async def get_setup_status(authorization: Optional[str] = Header(None)): @router.post("/setup/complete", response_model=CompleteSetupResponse) -async def complete_setup(authorization: Optional[str] = Header(None)): +async def complete_setup( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 标记首次配置完成 Args: + maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) Returns: 完成结果 """ try: - # 验证 token - if not authorization or not authorization.startswith("Bearer "): + # 验证 token(优先 Cookie,其次 Header) + current_token = None + if maibot_session: + current_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + current_token = authorization.replace("Bearer ", "") + + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() if not token_manager.verify_token(current_token): @@ -259,22 +367,32 @@ async def complete_setup(authorization: Optional[str] = Header(None)): @router.post("/setup/reset", response_model=ResetSetupResponse) -async def reset_setup(authorization: Optional[str] = Header(None)): +async def reset_setup( + request: Request, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 重置首次配置状态,允许重新进入配置向导 Args: + maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) Returns: 重置结果 """ try: - # 验证 token - if not authorization or not authorization.startswith("Bearer "): + # 验证 token(优先 Cookie,其次 Header) + current_token = None + if maibot_session: + current_token = maibot_session + elif authorization and authorization.startswith("Bearer "): + current_token = authorization.replace("Bearer ", "") + + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() if not token_manager.verify_token(current_token): diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 5997c3ba..ba6b7480 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -5,6 +5,7 @@ import asyncio import mimetypes from pathlib import Path from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from uvicorn import Config, Server as UvicornServer from src.common.logger import get_logger @@ -21,6 +22,9 @@ class WebUIServer: self.app = FastAPI(title="MaiBot WebUI") self._server = None + # 配置 CORS(支持开发环境跨域请求) + self._setup_cors() + # 显示 Access Token self._show_access_token() @@ -28,6 +32,23 @@ class WebUIServer: self._register_api_routes() self._setup_static_files() + def _setup_cors(self): + """配置 CORS 中间件""" + # 开发环境需要允许前端开发服务器的跨域请求 + self.app.add_middleware( + CORSMiddleware, + allow_origins=[ + "http://localhost:5173", # Vite 开发服务器 + "http://127.0.0.1:5173", + "http://localhost:8001", # 生产环境 + "http://127.0.0.1:8001", + ], + allow_credentials=True, # 允许携带 Cookie + allow_methods=["*"], + allow_headers=["*"], + ) + logger.debug("✅ CORS 中间件已配置") + def _show_access_token(self): """显示 WebUI Access Token""" try: