diff --git a/bot.py b/bot.py
index c2ed3dfd..f23b1dc7 100644
--- a/bot.py
+++ b/bot.py
@@ -8,18 +8,31 @@ from dotenv import load_dotenv
from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter
import platform
+from src.common.output_manager import OutputManager
# 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ}
+# 初始化输出管理器
+output_manager = OutputManager()
+
def easter_egg():
# 彩蛋
from colorama import init, Fore
init()
- text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
- rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
+ text = (
+ "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
+ )
+ rainbow_colors = [
+ Fore.RED,
+ Fore.YELLOW,
+ Fore.GREEN,
+ Fore.CYAN,
+ Fore.BLUE,
+ Fore.MAGENTA,
+ ]
rainbow_text = ""
for i, char in enumerate(text):
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
@@ -37,7 +50,9 @@ def init_config():
logger.info("创建config目录")
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
- logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动")
+ logger.info(
+ "复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动"
+ )
def init_env():
@@ -66,16 +81,15 @@ def load_env():
# 使用闭包实现对加载器的横向扩展,避免大量重复判断
def prod():
logger.success("加载生产环境变量配置")
- load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
+ load_dotenv(
+ ".env.prod", override=True
+ ) # override=True 允许覆盖已存在的环境变量
def dev():
logger.success("加载开发环境变量配置")
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
- fn_map = {
- "prod": prod,
- "dev": dev
- }
+ fn_map = {"prod": prod, "dev": dev}
env = os.getenv("ENVIRONMENT")
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
@@ -85,11 +99,17 @@ def load_env():
elif os.path.exists(f".env.{env}"):
logger.success(f"加载{env}环境变量配置")
- load_dotenv(f".env.{env}", override=True) # override=True 允许覆盖已存在的环境变量
+ load_dotenv(
+ f".env.{env}", override=True
+ ) # override=True 允许覆盖已存在的环境变量
else:
- logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
- RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
+ logger.error(
+ f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
+ )
+ RuntimeError(
+ f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
+ )
def load_logger():
@@ -97,10 +117,10 @@ def load_logger():
logger.add(
sys.stderr,
format="{time:YYYY-MM-DD HH:mm:ss.SSS} |> {level: <7} |> {name:.<8}:{function:.<8}:{line: >4} -> {message}",
+ "#777777>|> {name:.<8}:{function:.<8}:{line: >4} -> {message}",
colorize=True,
- level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别,默认为INFO
+ level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别,默认为INFO
)
@@ -131,25 +151,30 @@ def scan_provider(env_config: dict):
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
- logger.error(
- f"provider 内容:{config}\n"
- f"env_config 内容:{env_config}"
+ logger.error(f"provider 内容:{config}\n" f"env_config 内容:{env_config}")
+ raise ValueError(
+ f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量"
)
- raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
if __name__ == "__main__":
# 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
- if platform.system().lower() != 'windows':
+ if platform.system().lower() != "windows":
time.tzset()
+ # 启动输出捕获
+ output_manager.start_capture()
+ logger.info("输出管理器已启动")
+
easter_egg()
load_logger()
init_config()
init_env()
load_env()
- load_logger()
+
+ # 启动彩蛋
+ easter_egg()
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
@@ -171,4 +196,23 @@ if __name__ == "__main__":
# 加载插件
nonebot.load_plugins("src/plugins")
+ # 在启动时,注册清理函数
+ import atexit
+
+ def cleanup():
+ # 停止输出捕获
+ output_manager.stop_capture()
+
+ atexit.register(cleanup)
+
+ # 启动
nonebot.run()
+
+
+if __name__ == "__main__":
+ # 利用 TZ 环境变量设定程序工作的时区
+ # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
+ if platform.system().lower() != "windows":
+ time.tzset()
+
+ init()
diff --git a/requirements.txt b/requirements.txt
index 4f969682..d769712d 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/src/common/output_manager.py b/src/common/output_manager.py
new file mode 100644
index 00000000..0b91b154
--- /dev/null
+++ b/src/common/output_manager.py
@@ -0,0 +1,276 @@
+import sys
+import threading
+import time
+import json
+import asyncio
+import websockets
+from io import StringIO
+from datetime import datetime
+from typing import Dict, List, Any, Optional, Set, Callable
+import logging
+from loguru import logger
+
+
+class OutputCapture:
+ """捕获标准输出和错误输出的类"""
+
+ def __init__(self, original_stream):
+ self.original_stream = original_stream
+ self.buffer = StringIO()
+ self.listeners: List[Callable[[str], None]] = []
+
+ def write(self, data):
+ # 写入原始流
+ self.original_stream.write(data)
+ # 写入缓冲区
+ self.buffer.write(data)
+ # 通知所有监听器
+ for listener in self.listeners:
+ listener(data)
+
+ def flush(self):
+ self.original_stream.flush()
+ self.buffer.flush()
+
+ def add_listener(self, listener: Callable[[str], None]):
+ """添加输出监听器"""
+ self.listeners.append(listener)
+
+ def remove_listener(self, listener: Callable[[str], None]):
+ """移除输出监听器"""
+ if listener in self.listeners:
+ self.listeners.remove(listener)
+
+
+class WebSocketServer:
+ """WebSocket服务器,用于将捕获的输出发送到前端"""
+
+ def __init__(self, host: str = "localhost", port: int = 8765):
+ self.host = host
+ self.port = port
+ self.server = None
+ self.clients: Set[websockets.WebSocketServerProtocol] = set()
+ self.running = False
+ self.loop = None
+
+ async def handler(self, websocket, path):
+ """处理WebSocket连接"""
+ self.clients.add(websocket)
+ try:
+ await websocket.send(
+ json.dumps(
+ {
+ "type": "connection",
+ "message": "已连接到MaiMBot输出管理器",
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+ )
+
+ # 发送历史记录
+ history = OutputManager().get_message_history()
+ await websocket.send(
+ json.dumps(
+ {
+ "type": "history",
+ "messages": history,
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+ )
+
+ # 保持连接直到客户端断开
+ while True:
+ await websocket.recv()
+ except websockets.exceptions.ConnectionClosed:
+ pass
+ finally:
+ self.clients.remove(websocket)
+
+ async def broadcast(self, message: Dict[str, Any]):
+ """广播消息到所有连接的客户端"""
+ if not self.clients:
+ return
+
+ message_json = json.dumps(message)
+ await asyncio.gather(*[client.send(message_json) for client in self.clients])
+
+ def start(self):
+ """启动WebSocket服务器"""
+ if self.running:
+ return
+
+ self.running = True
+
+ # 在新线程中运行事件循环
+ def run_server():
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ start_server = websockets.serve(self.handler, self.host, self.port)
+
+ self.server = self.loop.run_until_complete(start_server)
+ logger.info(f"WebSocket服务器已启动,运行在 ws://{self.host}:{self.port}")
+
+ try:
+ self.loop.run_forever()
+ except KeyboardInterrupt:
+ pass
+ finally:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.loop.close()
+ self.running = False
+
+ threading.Thread(target=run_server, daemon=True).start()
+
+ def stop(self):
+ """停止WebSocket服务器"""
+ if not self.running or not self.loop:
+ return
+
+ self.loop.call_soon_threadsafe(self.loop.stop)
+ self.running = False
+ logger.info("WebSocket服务器已停止")
+
+
+class OutputManager:
+ """单例模式的输出管理器"""
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super(OutputManager, cls).__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ # 初始化属性
+ self.stdout_capture = None
+ self.stderr_capture = None
+ self.websocket_server = WebSocketServer()
+ self.message_history: List[Dict[str, Any]] = []
+ self.max_history_size = 1000 # 最大历史记录数量
+ self._initialized = True
+
+ def start_capture(self):
+ """开始捕获标准输出和错误输出"""
+ if self.stdout_capture is None:
+ self.stdout_capture = OutputCapture(sys.stdout)
+ sys.stdout = self.stdout_capture
+ self.stdout_capture.add_listener(
+ lambda data: self._process_output("stdout", data)
+ )
+
+ if self.stderr_capture is None:
+ self.stderr_capture = OutputCapture(sys.stderr)
+ sys.stderr = self.stderr_capture
+ self.stderr_capture.add_listener(
+ lambda data: self._process_output("stderr", data)
+ )
+
+ # 启动WebSocket服务器
+ self.websocket_server.start()
+
+ logger.info("输出捕获已启动")
+
+ def stop_capture(self):
+ """停止捕获标准输出和错误输出"""
+ if self.stdout_capture is not None:
+ sys.stdout = self.stdout_capture.original_stream
+ self.stdout_capture = None
+
+ if self.stderr_capture is not None:
+ sys.stderr = self.stderr_capture.original_stream
+ self.stderr_capture = None
+
+ # 停止WebSocket服务器
+ self.websocket_server.stop()
+
+ logger.info("输出捕获已停止")
+
+ def _process_output(self, source: str, data: str):
+ """处理捕获的输出"""
+ if not data.strip():
+ return
+
+ message = {
+ "type": "output",
+ "source": source,
+ "content": data,
+ "timestamp": datetime.now().isoformat(),
+ }
+
+ # 添加到历史记录
+ self.message_history.append(message)
+
+ # 限制历史记录大小
+ if len(self.message_history) > self.max_history_size:
+ self.message_history = self.message_history[-self.max_history_size :]
+
+ # 通过WebSocket发送
+ if self.websocket_server.running and self.websocket_server.loop:
+ asyncio.run_coroutine_threadsafe(
+ self.websocket_server.broadcast(message), self.websocket_server.loop
+ )
+
+ def get_message_history(self) -> List[Dict[str, Any]]:
+ """获取消息历史记录"""
+ return self.message_history
+
+ def send_custom_message(
+ self, message_type: str, content: Any, source: str = "custom"
+ ):
+ """发送自定义消息"""
+ message = {
+ "type": message_type,
+ "source": source,
+ "content": content,
+ "timestamp": datetime.now().isoformat(),
+ }
+
+ # 添加到历史记录
+ self.message_history.append(message)
+
+ # 限制历史记录大小
+ if len(self.message_history) > self.max_history_size:
+ self.message_history = self.message_history[-self.max_history_size :]
+
+ # 通过WebSocket发送
+ if self.websocket_server.running and self.websocket_server.loop:
+ asyncio.run_coroutine_threadsafe(
+ self.websocket_server.broadcast(message), self.websocket_server.loop
+ )
+
+
+# 示例用法
+if __name__ == "__main__":
+ # 启动输出管理器
+ output_manager = OutputManager()
+ output_manager.start_capture()
+
+ # 模拟程序输出
+ print("这是一条标准输出信息")
+ print("这是另一条标准输出信息")
+
+ # 模拟错误输出
+ sys.stderr.write("这是一条错误信息\n")
+
+ # 发送自定义消息
+ output_manager.send_custom_message(
+ "status", {"progress": 50, "status": "processing"}, "task_manager"
+ )
+
+ # 保持程序运行
+ try:
+ while True:
+ time.sleep(1)
+ except KeyboardInterrupt:
+ output_manager.stop_capture()
+ print("程序已退出")
diff --git a/test_output_manager.py b/test_output_manager.py
new file mode 100644
index 00000000..a2d42688
--- /dev/null
+++ b/test_output_manager.py
@@ -0,0 +1,47 @@
+import time
+import sys
+from src.common.output_manager import OutputManager
+
+
+def main():
+ # 初始化输出管理器
+ output_manager = OutputManager()
+
+ # 启动输出捕获
+ output_manager.start_capture()
+ print("输出管理器已启动,WebSocket服务器运行在 ws://localhost:8765")
+ print("您可以使用任何WebSocket客户端连接到此地址来接收输出")
+
+ # 发送一些测试消息
+ print("这是一条标准输出消息")
+ sys.stderr.write("这是一条错误输出消息\n")
+
+ # 发送自定义消息
+ output_manager.send_custom_message(
+ "status", {"progress": 50, "status": "processing"}, "task_manager"
+ )
+
+ # 每隔1秒发送一条消息
+ try:
+ count = 0
+ while True:
+ count += 1
+ print(f"测试消息 #{count}")
+ if count % 5 == 0:
+ sys.stderr.write(f"错误消息 #{count}\n")
+
+ # 每10条消息发送一次自定义消息
+ if count % 10 == 0:
+ output_manager.send_custom_message(
+ "progress", {"count": count, "percentage": count % 100}, "counter"
+ )
+
+ time.sleep(1)
+ except KeyboardInterrupt:
+ # 停止输出捕获
+ output_manager.stop_capture()
+ print("程序已退出")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/websocket_client_example.py b/websocket_client_example.py
new file mode 100644
index 00000000..74c7e65a
--- /dev/null
+++ b/websocket_client_example.py
@@ -0,0 +1,58 @@
+import asyncio
+import websockets
+import json
+from datetime import datetime
+
+
+async def connect_to_output_manager():
+ """连接到OutputManager的WebSocket服务器并接收消息"""
+
+ uri = "ws://localhost:8765"
+ print(f"正在连接到 {uri}...")
+
+ try:
+ async with websockets.connect(uri) as websocket:
+ print("已连接到服务器")
+
+ while True:
+ # 接收消息
+ message = await websocket.recv()
+ data = json.loads(message)
+
+ # 格式化时间戳
+ timestamp = datetime.fromisoformat(data["timestamp"]).strftime(
+ "%H:%M:%S"
+ )
+
+ # 根据消息类型和来源显示不同颜色
+ if data["type"] == "output":
+ if data["source"] == "stdout":
+ # 标准输出 - 白色
+ print(f"\033[37m[{timestamp}] {data['content']}\033[0m", end="")
+ elif data["source"] == "stderr":
+ # 错误输出 - 红色
+ print(f"\033[31m[{timestamp}] {data['content']}\033[0m", end="")
+ elif data["type"] == "connection":
+ # 连接消息 - 绿色
+ print(f"\033[32m[{timestamp}] {data['message']}\033[0m")
+ elif data["type"] == "history":
+ # 历史记录 - 蓝色
+ print(
+ f"\033[34m[{timestamp}] 收到历史记录: {len(data['messages'])} 条消息\033[0m"
+ )
+ else:
+ # 其他消息 - 黄色
+ content = data.get("content", "")
+ if isinstance(content, dict):
+ content = json.dumps(content, ensure_ascii=False)
+ print(f"\033[33m[{timestamp}] [{data['type']}] {content}\033[0m")
+
+ except websockets.exceptions.ConnectionClosed:
+ print("与服务器的连接已断开")
+ except Exception as e:
+ print(f"发生错误: {e}")
+
+
+if __name__ == "__main__":
+ # 运行客户端
+ asyncio.run(connect_to_output_manager())