From a96bc531785bb6208a4d9b0f027cda653e665ff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AB=E5=88=86=E6=A9=98=E5=AD=90?= Date: Tue, 11 Mar 2025 01:10:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=BE=93=E5=87=BA=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=99=A8=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 84 ++++++++--- requirements.txt | Bin 600 -> 624 bytes src/common/output_manager.py | 276 +++++++++++++++++++++++++++++++++++ test_output_manager.py | 47 ++++++ websocket_client_example.py | 58 ++++++++ 5 files changed, 445 insertions(+), 20 deletions(-) create mode 100644 src/common/output_manager.py create mode 100644 test_output_manager.py create mode 100644 websocket_client_example.py diff --git a/bot.py b/bot.py index c2ed3dfd..f23b1dc7 100644 --- a/bot.py +++ b/bot.py @@ -8,18 +8,31 @@ from dotenv import load_dotenv from loguru import logger from nonebot.adapters.onebot.v11 import Adapter import platform +from src.common.output_manager import OutputManager # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} +# 初始化输出管理器 +output_manager = OutputManager() + def easter_egg(): # 彩蛋 from colorama import init, Fore init() - text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午" - rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA] + text = ( + "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午" + ) + rainbow_colors = [ + Fore.RED, + Fore.YELLOW, + Fore.GREEN, + Fore.CYAN, + Fore.BLUE, + Fore.MAGENTA, + ] rainbow_text = "" for i, char in enumerate(text): rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char @@ -37,7 +50,9 @@ def init_config(): logger.info("创建config目录") shutil.copy("template/bot_config_template.toml", "config/bot_config.toml") - logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动") + logger.info( + "复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动" + ) def init_env(): @@ -66,16 +81,15 @@ def load_env(): # 使用闭包实现对加载器的横向扩展,避免大量重复判断 def prod(): logger.success("加载生产环境变量配置") - load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量 + load_dotenv( + ".env.prod", override=True + ) # override=True 允许覆盖已存在的环境变量 def dev(): logger.success("加载开发环境变量配置") load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 - fn_map = { - "prod": prod, - "dev": dev - } + fn_map = {"prod": prod, "dev": dev} env = os.getenv("ENVIRONMENT") logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}") @@ -85,11 +99,17 @@ def load_env(): elif os.path.exists(f".env.{env}"): logger.success(f"加载{env}环境变量配置") - load_dotenv(f".env.{env}", override=True) # override=True 允许覆盖已存在的环境变量 + load_dotenv( + f".env.{env}", override=True + ) # override=True 允许覆盖已存在的环境变量 else: - logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") - RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") + logger.error( + f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在" + ) + RuntimeError( + f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在" + ) def load_logger(): @@ -97,10 +117,10 @@ def load_logger(): logger.add( sys.stderr, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <7} | {name:.<8}:{function:.<8}:{line: >4} - {message}", + "#777777>| {name:.<8}:{function:.<8}:{line: >4} - {message}", colorize=True, - level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别,默认为INFO + level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别,默认为INFO ) @@ -131,25 +151,30 @@ def scan_provider(env_config: dict): # 检查每个 provider 是否同时存在 url 和 key for provider_name, config in provider.items(): if config["url"] is None or config["key"] is None: - logger.error( - f"provider 内容:{config}\n" - f"env_config 内容:{env_config}" + logger.error(f"provider 内容:{config}\n" f"env_config 内容:{env_config}") + raise ValueError( + f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量" ) - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") if __name__ == "__main__": # 利用 TZ 环境变量设定程序工作的时区 # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 - if platform.system().lower() != 'windows': + if platform.system().lower() != "windows": time.tzset() + # 启动输出捕获 + output_manager.start_capture() + logger.info("输出管理器已启动") + easter_egg() load_logger() init_config() init_env() load_env() - load_logger() + + # 启动彩蛋 + easter_egg() env_config = {key: os.getenv(key) for key in os.environ} scan_provider(env_config) @@ -171,4 +196,23 @@ if __name__ == "__main__": # 加载插件 nonebot.load_plugins("src/plugins") + # 在启动时,注册清理函数 + import atexit + + def cleanup(): + # 停止输出捕获 + output_manager.stop_capture() + + atexit.register(cleanup) + + # 启动 nonebot.run() + + +if __name__ == "__main__": + # 利用 TZ 环境变量设定程序工作的时区 + # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 + if platform.system().lower() != "windows": + time.tzset() + + init() diff --git a/requirements.txt b/requirements.txt index 4f969682f9b0ebbbd9aebe4fe94758da052d23b4..d769712d7bfc90e68ebb0f10e0547f9c81d45d77 100644 GIT binary patch delta 32 lcmcb?@_}VT1d{|W0~bR%Ln=cOLoq`>Lo!1)kX-^~0|1An2Sfk> delta 7 Ocmeysa)V_<1QP%axdPh& diff --git a/src/common/output_manager.py b/src/common/output_manager.py new file mode 100644 index 00000000..0b91b154 --- /dev/null +++ b/src/common/output_manager.py @@ -0,0 +1,276 @@ +import sys +import threading +import time +import json +import asyncio +import websockets +from io import StringIO +from datetime import datetime +from typing import Dict, List, Any, Optional, Set, Callable +import logging +from loguru import logger + + +class OutputCapture: + """捕获标准输出和错误输出的类""" + + def __init__(self, original_stream): + self.original_stream = original_stream + self.buffer = StringIO() + self.listeners: List[Callable[[str], None]] = [] + + def write(self, data): + # 写入原始流 + self.original_stream.write(data) + # 写入缓冲区 + self.buffer.write(data) + # 通知所有监听器 + for listener in self.listeners: + listener(data) + + def flush(self): + self.original_stream.flush() + self.buffer.flush() + + def add_listener(self, listener: Callable[[str], None]): + """添加输出监听器""" + self.listeners.append(listener) + + def remove_listener(self, listener: Callable[[str], None]): + """移除输出监听器""" + if listener in self.listeners: + self.listeners.remove(listener) + + +class WebSocketServer: + """WebSocket服务器,用于将捕获的输出发送到前端""" + + def __init__(self, host: str = "localhost", port: int = 8765): + self.host = host + self.port = port + self.server = None + self.clients: Set[websockets.WebSocketServerProtocol] = set() + self.running = False + self.loop = None + + async def handler(self, websocket, path): + """处理WebSocket连接""" + self.clients.add(websocket) + try: + await websocket.send( + json.dumps( + { + "type": "connection", + "message": "已连接到MaiMBot输出管理器", + "timestamp": datetime.now().isoformat(), + } + ) + ) + + # 发送历史记录 + history = OutputManager().get_message_history() + await websocket.send( + json.dumps( + { + "type": "history", + "messages": history, + "timestamp": datetime.now().isoformat(), + } + ) + ) + + # 保持连接直到客户端断开 + while True: + await websocket.recv() + except websockets.exceptions.ConnectionClosed: + pass + finally: + self.clients.remove(websocket) + + async def broadcast(self, message: Dict[str, Any]): + """广播消息到所有连接的客户端""" + if not self.clients: + return + + message_json = json.dumps(message) + await asyncio.gather(*[client.send(message_json) for client in self.clients]) + + def start(self): + """启动WebSocket服务器""" + if self.running: + return + + self.running = True + + # 在新线程中运行事件循环 + def run_server(): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + start_server = websockets.serve(self.handler, self.host, self.port) + + self.server = self.loop.run_until_complete(start_server) + logger.info(f"WebSocket服务器已启动,运行在 ws://{self.host}:{self.port}") + + try: + self.loop.run_forever() + except KeyboardInterrupt: + pass + finally: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.loop.close() + self.running = False + + threading.Thread(target=run_server, daemon=True).start() + + def stop(self): + """停止WebSocket服务器""" + if not self.running or not self.loop: + return + + self.loop.call_soon_threadsafe(self.loop.stop) + self.running = False + logger.info("WebSocket服务器已停止") + + +class OutputManager: + """单例模式的输出管理器""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super(OutputManager, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # 初始化属性 + self.stdout_capture = None + self.stderr_capture = None + self.websocket_server = WebSocketServer() + self.message_history: List[Dict[str, Any]] = [] + self.max_history_size = 1000 # 最大历史记录数量 + self._initialized = True + + def start_capture(self): + """开始捕获标准输出和错误输出""" + if self.stdout_capture is None: + self.stdout_capture = OutputCapture(sys.stdout) + sys.stdout = self.stdout_capture + self.stdout_capture.add_listener( + lambda data: self._process_output("stdout", data) + ) + + if self.stderr_capture is None: + self.stderr_capture = OutputCapture(sys.stderr) + sys.stderr = self.stderr_capture + self.stderr_capture.add_listener( + lambda data: self._process_output("stderr", data) + ) + + # 启动WebSocket服务器 + self.websocket_server.start() + + logger.info("输出捕获已启动") + + def stop_capture(self): + """停止捕获标准输出和错误输出""" + if self.stdout_capture is not None: + sys.stdout = self.stdout_capture.original_stream + self.stdout_capture = None + + if self.stderr_capture is not None: + sys.stderr = self.stderr_capture.original_stream + self.stderr_capture = None + + # 停止WebSocket服务器 + self.websocket_server.stop() + + logger.info("输出捕获已停止") + + def _process_output(self, source: str, data: str): + """处理捕获的输出""" + if not data.strip(): + return + + message = { + "type": "output", + "source": source, + "content": data, + "timestamp": datetime.now().isoformat(), + } + + # 添加到历史记录 + self.message_history.append(message) + + # 限制历史记录大小 + if len(self.message_history) > self.max_history_size: + self.message_history = self.message_history[-self.max_history_size :] + + # 通过WebSocket发送 + if self.websocket_server.running and self.websocket_server.loop: + asyncio.run_coroutine_threadsafe( + self.websocket_server.broadcast(message), self.websocket_server.loop + ) + + def get_message_history(self) -> List[Dict[str, Any]]: + """获取消息历史记录""" + return self.message_history + + def send_custom_message( + self, message_type: str, content: Any, source: str = "custom" + ): + """发送自定义消息""" + message = { + "type": message_type, + "source": source, + "content": content, + "timestamp": datetime.now().isoformat(), + } + + # 添加到历史记录 + self.message_history.append(message) + + # 限制历史记录大小 + if len(self.message_history) > self.max_history_size: + self.message_history = self.message_history[-self.max_history_size :] + + # 通过WebSocket发送 + if self.websocket_server.running and self.websocket_server.loop: + asyncio.run_coroutine_threadsafe( + self.websocket_server.broadcast(message), self.websocket_server.loop + ) + + +# 示例用法 +if __name__ == "__main__": + # 启动输出管理器 + output_manager = OutputManager() + output_manager.start_capture() + + # 模拟程序输出 + print("这是一条标准输出信息") + print("这是另一条标准输出信息") + + # 模拟错误输出 + sys.stderr.write("这是一条错误信息\n") + + # 发送自定义消息 + output_manager.send_custom_message( + "status", {"progress": 50, "status": "processing"}, "task_manager" + ) + + # 保持程序运行 + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + output_manager.stop_capture() + print("程序已退出") diff --git a/test_output_manager.py b/test_output_manager.py new file mode 100644 index 00000000..a2d42688 --- /dev/null +++ b/test_output_manager.py @@ -0,0 +1,47 @@ +import time +import sys +from src.common.output_manager import OutputManager + + +def main(): + # 初始化输出管理器 + output_manager = OutputManager() + + # 启动输出捕获 + output_manager.start_capture() + print("输出管理器已启动,WebSocket服务器运行在 ws://localhost:8765") + print("您可以使用任何WebSocket客户端连接到此地址来接收输出") + + # 发送一些测试消息 + print("这是一条标准输出消息") + sys.stderr.write("这是一条错误输出消息\n") + + # 发送自定义消息 + output_manager.send_custom_message( + "status", {"progress": 50, "status": "processing"}, "task_manager" + ) + + # 每隔1秒发送一条消息 + try: + count = 0 + while True: + count += 1 + print(f"测试消息 #{count}") + if count % 5 == 0: + sys.stderr.write(f"错误消息 #{count}\n") + + # 每10条消息发送一次自定义消息 + if count % 10 == 0: + output_manager.send_custom_message( + "progress", {"count": count, "percentage": count % 100}, "counter" + ) + + time.sleep(1) + except KeyboardInterrupt: + # 停止输出捕获 + output_manager.stop_capture() + print("程序已退出") + + +if __name__ == "__main__": + main() diff --git a/websocket_client_example.py b/websocket_client_example.py new file mode 100644 index 00000000..74c7e65a --- /dev/null +++ b/websocket_client_example.py @@ -0,0 +1,58 @@ +import asyncio +import websockets +import json +from datetime import datetime + + +async def connect_to_output_manager(): + """连接到OutputManager的WebSocket服务器并接收消息""" + + uri = "ws://localhost:8765" + print(f"正在连接到 {uri}...") + + try: + async with websockets.connect(uri) as websocket: + print("已连接到服务器") + + while True: + # 接收消息 + message = await websocket.recv() + data = json.loads(message) + + # 格式化时间戳 + timestamp = datetime.fromisoformat(data["timestamp"]).strftime( + "%H:%M:%S" + ) + + # 根据消息类型和来源显示不同颜色 + if data["type"] == "output": + if data["source"] == "stdout": + # 标准输出 - 白色 + print(f"\033[37m[{timestamp}] {data['content']}\033[0m", end="") + elif data["source"] == "stderr": + # 错误输出 - 红色 + print(f"\033[31m[{timestamp}] {data['content']}\033[0m", end="") + elif data["type"] == "connection": + # 连接消息 - 绿色 + print(f"\033[32m[{timestamp}] {data['message']}\033[0m") + elif data["type"] == "history": + # 历史记录 - 蓝色 + print( + f"\033[34m[{timestamp}] 收到历史记录: {len(data['messages'])} 条消息\033[0m" + ) + else: + # 其他消息 - 黄色 + content = data.get("content", "") + if isinstance(content, dict): + content = json.dumps(content, ensure_ascii=False) + print(f"\033[33m[{timestamp}] [{data['type']}] {content}\033[0m") + + except websockets.exceptions.ConnectionClosed: + print("与服务器的连接已断开") + except Exception as e: + print(f"发生错误: {e}") + + +if __name__ == "__main__": + # 运行客户端 + asyncio.run(connect_to_output_manager())