diff --git a/main.py b/main.py index 64d8c32..1e29ec2 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import asyncio import sys import json +import http import websockets as Server from src.logger import logger from src.recv_handler.message_handler import message_handler @@ -49,10 +50,17 @@ async def main(): message_send_instance.maibot_router = router _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response()) +def check_napcat_server_token(path, request_headers): + token = global_config.napcat_server.token + if not token or token.strip() == "": + return + auth_header = request_headers.get("Authorization") + if auth_header != f"Bearer {token}": + return http.HTTPStatus.UNAUTHORIZED, [], b"Unauthorized\n" async def napcat_server(): logger.info("正在启动adapter...") - async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26) as server: + async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26, process_headers=check_napcat_server_token) as server: logger.info( f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" )