fix: 修复“优雅退出”需要按两次 Ctrl+C 的问题

pull/61/head
Ronifue 2025-09-18 16:30:02 +08:00
parent 178912375d
commit 5302c55184
1 changed files with 67 additions and 25 deletions

92
main.py
View File

@ -2,6 +2,7 @@ import asyncio
import sys import sys
import json import json
import http import http
import signal
import websockets as Server import websockets as Server
from src.logger import logger from src.logger import logger
from src.recv_handler.message_handler import message_handler from src.recv_handler.message_handler import message_handler
@ -14,6 +15,7 @@ from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
from src.response_pool import put_response, check_timeout_response from src.response_pool import put_response, check_timeout_response
message_queue = asyncio.Queue() message_queue = asyncio.Queue()
server: Server.Server = None
async def message_recv(server_connection: Server.ServerConnection): async def message_recv(server_connection: Server.ServerConnection):
@ -50,6 +52,7 @@ async def main():
message_send_instance.maibot_router = router message_send_instance.maibot_router = router
_ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response()) _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response())
def check_napcat_server_token(conn, request): def check_napcat_server_token(conn, request):
token = global_config.napcat_server.token token = global_config.napcat_server.token
if not token or token.strip() == "": if not token or token.strip() == "":
@ -59,45 +62,84 @@ def check_napcat_server_token(conn, request):
return Server.Response( return Server.Response(
status=http.HTTPStatus.UNAUTHORIZED, status=http.HTTPStatus.UNAUTHORIZED,
headers=Server.Headers([("Content-Type", "text/plain")]), headers=Server.Headers([("Content-Type", "text/plain")]),
body=b"Unauthorized\n" body=b"Unauthorized\n",
) )
return None return None
async def napcat_server(): async def napcat_server():
global server
logger.info("正在启动adapter...") logger.info("正在启动adapter...")
async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26, process_request=check_napcat_server_token) as server: server = await Server.serve(
logger.info( message_recv,
f"Adapter已启动监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" global_config.napcat_server.host,
) global_config.napcat_server.port,
await server.serve_forever() max_size=2**26,
process_request=check_napcat_server_token,
)
logger.info(f"Adapter已启动监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}")
await server.wait_closed()
async def graceful_shutdown(): async def graceful_shutdown(loop: asyncio.AbstractEventLoop, timeout: float = 10.0):
try: logger.info("正在关闭adapter...")
logger.info("正在关闭adapter...")
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] # 主动关闭所有 websocket 连接
if server and server.is_serving() and server.connections:
logger.info(f"正在关闭 {len(server.connections)} 个客户端连接...")
close_tasks = [conn.close() for conn in server.connections]
await asyncio.gather(*close_tasks, return_exceptions=True)
# 关闭服务器,停止接受新连接
if server and server.is_serving():
logger.info("正在关闭 Websocket 服务器...")
server.close()
await server.wait_closed()
logger.info("Websocket 服务器已关闭")
# 关闭 aiohttp 客户端
await mmc_stop_com()
logger.info("MMC com layer 已停止")
# 取消所有其他任务
tasks = [t for t in asyncio.all_tasks(loop=loop) if t is not asyncio.current_task()]
if tasks:
logger.info(f"正在取消 {len(tasks)} 个剩余任务...")
for task in tasks: for task in tasks:
if not task.done(): task.cancel()
task.cancel()
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15) await asyncio.gather(*tasks, return_exceptions=True)
await mmc_stop_com() # 后置避免神秘exception logger.info("所有剩余任务已处理完毕")
logger.info("Adapter已成功关闭")
except Exception as e: logger.info("Adapter 已成功关闭")
logger.error(f"Adapter关闭中出现错误: {e}")
if __name__ == "__main__": if __name__ == "__main__":
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
shutdown_event = asyncio.Event()
def _shutdown_handler(sig: int):
logger.warning(f"收到信号 {signal.Signals(sig).name},开始关闭...")
shutdown_event.set()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, _shutdown_handler, sig)
except NotImplementedError:
pass # Windows
try: try:
loop.run_until_complete(main()) main_task = loop.create_task(main())
except KeyboardInterrupt: loop.run_until_complete(shutdown_event.wait())
logger.warning("收到中断信号,正在优雅关闭...")
loop.run_until_complete(graceful_shutdown())
except Exception as e:
logger.exception(f"主程序异常: {str(e)}")
sys.exit(1)
finally: finally:
logger.info("开始优雅关闭流程...")
# 执行新的关机流程
loop.run_until_complete(graceful_shutdown(loop=loop))
# 最终关闭循环
if loop and not loop.is_closed(): if loop and not loop.is_closed():
loop.close() loop.close()
sys.exit(0) logger.info("程序已退出")
sys.exit(0)