添加输出管理器功能

pull/177/head
八分橘子 2025-03-11 01:10:50 +08:00
parent 62523409d1
commit a96bc53178
5 changed files with 445 additions and 20 deletions

84
bot.py
View File

@ -8,18 +8,31 @@ from dotenv import load_dotenv
from loguru import logger from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from src.common.output_manager import OutputManager
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
# 初始化输出管理器
output_manager = OutputManager()
def easter_egg(): def easter_egg():
# 彩蛋 # 彩蛋
from colorama import init, Fore from colorama import init, Fore
init() init()
text = "多年以后面对AI行刑队张三将会回想起他2023年在会议上讨论人工智能的那个下午" text = (
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA] "多年以后面对AI行刑队张三将会回想起他2023年在会议上讨论人工智能的那个下午"
)
rainbow_colors = [
Fore.RED,
Fore.YELLOW,
Fore.GREEN,
Fore.CYAN,
Fore.BLUE,
Fore.MAGENTA,
]
rainbow_text = "" rainbow_text = ""
for i, char in enumerate(text): for i, char in enumerate(text):
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
@ -37,7 +50,9 @@ def init_config():
logger.info("创建config目录") logger.info("创建config目录")
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml") shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
logger.info("复制完成请修改config/bot_config.toml和.env.prod中的配置后重新启动") logger.info(
"复制完成请修改config/bot_config.toml和.env.prod中的配置后重新启动"
)
def init_env(): def init_env():
@ -66,16 +81,15 @@ def load_env():
# 使用闭包实现对加载器的横向扩展,避免大量重复判断 # 使用闭包实现对加载器的横向扩展,避免大量重复判断
def prod(): def prod():
logger.success("加载生产环境变量配置") logger.success("加载生产环境变量配置")
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(
".env.prod", override=True
) # override=True 允许覆盖已存在的环境变量
def dev(): def dev():
logger.success("加载开发环境变量配置") logger.success("加载开发环境变量配置")
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
fn_map = { fn_map = {"prod": prod, "dev": dev}
"prod": prod,
"dev": dev
}
env = os.getenv("ENVIRONMENT") env = os.getenv("ENVIRONMENT")
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}") logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
@ -85,11 +99,17 @@ def load_env():
elif os.path.exists(f".env.{env}"): elif os.path.exists(f".env.{env}"):
logger.success(f"加载{env}环境变量配置") logger.success(f"加载{env}环境变量配置")
load_dotenv(f".env.{env}", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(
f".env.{env}", override=True
) # override=True 允许覆盖已存在的环境变量
else: else:
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") logger.error(
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
)
RuntimeError(
f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
)
def load_logger(): def load_logger():
@ -97,10 +117,10 @@ def load_logger():
logger.add( logger.add(
sys.stderr, sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg " format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg " "#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>", "#777777>-</> <level>{message}</level>",
colorize=True, colorize=True,
level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别默认为INFO level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别默认为INFO
) )
@ -131,25 +151,30 @@ def scan_provider(env_config: dict):
# 检查每个 provider 是否同时存在 url 和 key # 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items(): for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None: if config["url"] is None or config["key"] is None:
logger.error( logger.error(f"provider 内容:{config}\n" f"env_config 内容:{env_config}")
f"provider 内容:{config}\n" raise ValueError(
f"env_config 内容:{env_config}" f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量"
) )
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
if __name__ == "__main__": if __name__ == "__main__":
# 利用 TZ 环境变量设定程序工作的时区 # 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != 'windows': if platform.system().lower() != "windows":
time.tzset() time.tzset()
# 启动输出捕获
output_manager.start_capture()
logger.info("输出管理器已启动")
easter_egg() easter_egg()
load_logger() load_logger()
init_config() init_config()
init_env() init_env()
load_env() load_env()
load_logger()
# 启动彩蛋
easter_egg()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config) scan_provider(env_config)
@ -171,4 +196,23 @@ if __name__ == "__main__":
# 加载插件 # 加载插件
nonebot.load_plugins("src/plugins") nonebot.load_plugins("src/plugins")
# 在启动时,注册清理函数
import atexit
def cleanup():
# 停止输出捕获
output_manager.stop_capture()
atexit.register(cleanup)
# 启动
nonebot.run() nonebot.run()
if __name__ == "__main__":
# 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != "windows":
time.tzset()
init()

Binary file not shown.

View File

@ -0,0 +1,276 @@
import sys
import threading
import time
import json
import asyncio
import websockets
from io import StringIO
from datetime import datetime
from typing import Dict, List, Any, Optional, Set, Callable
import logging
from loguru import logger
class OutputCapture:
"""捕获标准输出和错误输出的类"""
def __init__(self, original_stream):
self.original_stream = original_stream
self.buffer = StringIO()
self.listeners: List[Callable[[str], None]] = []
def write(self, data):
# 写入原始流
self.original_stream.write(data)
# 写入缓冲区
self.buffer.write(data)
# 通知所有监听器
for listener in self.listeners:
listener(data)
def flush(self):
self.original_stream.flush()
self.buffer.flush()
def add_listener(self, listener: Callable[[str], None]):
"""添加输出监听器"""
self.listeners.append(listener)
def remove_listener(self, listener: Callable[[str], None]):
"""移除输出监听器"""
if listener in self.listeners:
self.listeners.remove(listener)
class WebSocketServer:
"""WebSocket服务器用于将捕获的输出发送到前端"""
def __init__(self, host: str = "localhost", port: int = 8765):
self.host = host
self.port = port
self.server = None
self.clients: Set[websockets.WebSocketServerProtocol] = set()
self.running = False
self.loop = None
async def handler(self, websocket, path):
"""处理WebSocket连接"""
self.clients.add(websocket)
try:
await websocket.send(
json.dumps(
{
"type": "connection",
"message": "已连接到MaiMBot输出管理器",
"timestamp": datetime.now().isoformat(),
}
)
)
# 发送历史记录
history = OutputManager().get_message_history()
await websocket.send(
json.dumps(
{
"type": "history",
"messages": history,
"timestamp": datetime.now().isoformat(),
}
)
)
# 保持连接直到客户端断开
while True:
await websocket.recv()
except websockets.exceptions.ConnectionClosed:
pass
finally:
self.clients.remove(websocket)
async def broadcast(self, message: Dict[str, Any]):
"""广播消息到所有连接的客户端"""
if not self.clients:
return
message_json = json.dumps(message)
await asyncio.gather(*[client.send(message_json) for client in self.clients])
def start(self):
"""启动WebSocket服务器"""
if self.running:
return
self.running = True
# 在新线程中运行事件循环
def run_server():
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
start_server = websockets.serve(self.handler, self.host, self.port)
self.server = self.loop.run_until_complete(start_server)
logger.info(f"WebSocket服务器已启动运行在 ws://{self.host}:{self.port}")
try:
self.loop.run_forever()
except KeyboardInterrupt:
pass
finally:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.loop.close()
self.running = False
threading.Thread(target=run_server, daemon=True).start()
def stop(self):
"""停止WebSocket服务器"""
if not self.running or not self.loop:
return
self.loop.call_soon_threadsafe(self.loop.stop)
self.running = False
logger.info("WebSocket服务器已停止")
class OutputManager:
"""单例模式的输出管理器"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super(OutputManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# 初始化属性
self.stdout_capture = None
self.stderr_capture = None
self.websocket_server = WebSocketServer()
self.message_history: List[Dict[str, Any]] = []
self.max_history_size = 1000 # 最大历史记录数量
self._initialized = True
def start_capture(self):
"""开始捕获标准输出和错误输出"""
if self.stdout_capture is None:
self.stdout_capture = OutputCapture(sys.stdout)
sys.stdout = self.stdout_capture
self.stdout_capture.add_listener(
lambda data: self._process_output("stdout", data)
)
if self.stderr_capture is None:
self.stderr_capture = OutputCapture(sys.stderr)
sys.stderr = self.stderr_capture
self.stderr_capture.add_listener(
lambda data: self._process_output("stderr", data)
)
# 启动WebSocket服务器
self.websocket_server.start()
logger.info("输出捕获已启动")
def stop_capture(self):
"""停止捕获标准输出和错误输出"""
if self.stdout_capture is not None:
sys.stdout = self.stdout_capture.original_stream
self.stdout_capture = None
if self.stderr_capture is not None:
sys.stderr = self.stderr_capture.original_stream
self.stderr_capture = None
# 停止WebSocket服务器
self.websocket_server.stop()
logger.info("输出捕获已停止")
def _process_output(self, source: str, data: str):
"""处理捕获的输出"""
if not data.strip():
return
message = {
"type": "output",
"source": source,
"content": data,
"timestamp": datetime.now().isoformat(),
}
# 添加到历史记录
self.message_history.append(message)
# 限制历史记录大小
if len(self.message_history) > self.max_history_size:
self.message_history = self.message_history[-self.max_history_size :]
# 通过WebSocket发送
if self.websocket_server.running and self.websocket_server.loop:
asyncio.run_coroutine_threadsafe(
self.websocket_server.broadcast(message), self.websocket_server.loop
)
def get_message_history(self) -> List[Dict[str, Any]]:
"""获取消息历史记录"""
return self.message_history
def send_custom_message(
self, message_type: str, content: Any, source: str = "custom"
):
"""发送自定义消息"""
message = {
"type": message_type,
"source": source,
"content": content,
"timestamp": datetime.now().isoformat(),
}
# 添加到历史记录
self.message_history.append(message)
# 限制历史记录大小
if len(self.message_history) > self.max_history_size:
self.message_history = self.message_history[-self.max_history_size :]
# 通过WebSocket发送
if self.websocket_server.running and self.websocket_server.loop:
asyncio.run_coroutine_threadsafe(
self.websocket_server.broadcast(message), self.websocket_server.loop
)
# 示例用法
if __name__ == "__main__":
# 启动输出管理器
output_manager = OutputManager()
output_manager.start_capture()
# 模拟程序输出
print("这是一条标准输出信息")
print("这是另一条标准输出信息")
# 模拟错误输出
sys.stderr.write("这是一条错误信息\n")
# 发送自定义消息
output_manager.send_custom_message(
"status", {"progress": 50, "status": "processing"}, "task_manager"
)
# 保持程序运行
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
output_manager.stop_capture()
print("程序已退出")

View File

@ -0,0 +1,47 @@
import time
import sys
from src.common.output_manager import OutputManager
def main():
# 初始化输出管理器
output_manager = OutputManager()
# 启动输出捕获
output_manager.start_capture()
print("输出管理器已启动WebSocket服务器运行在 ws://localhost:8765")
print("您可以使用任何WebSocket客户端连接到此地址来接收输出")
# 发送一些测试消息
print("这是一条标准输出消息")
sys.stderr.write("这是一条错误输出消息\n")
# 发送自定义消息
output_manager.send_custom_message(
"status", {"progress": 50, "status": "processing"}, "task_manager"
)
# 每隔1秒发送一条消息
try:
count = 0
while True:
count += 1
print(f"测试消息 #{count}")
if count % 5 == 0:
sys.stderr.write(f"错误消息 #{count}\n")
# 每10条消息发送一次自定义消息
if count % 10 == 0:
output_manager.send_custom_message(
"progress", {"count": count, "percentage": count % 100}, "counter"
)
time.sleep(1)
except KeyboardInterrupt:
# 停止输出捕获
output_manager.stop_capture()
print("程序已退出")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,58 @@
import asyncio
import websockets
import json
from datetime import datetime
async def connect_to_output_manager():
"""连接到OutputManager的WebSocket服务器并接收消息"""
uri = "ws://localhost:8765"
print(f"正在连接到 {uri}...")
try:
async with websockets.connect(uri) as websocket:
print("已连接到服务器")
while True:
# 接收消息
message = await websocket.recv()
data = json.loads(message)
# 格式化时间戳
timestamp = datetime.fromisoformat(data["timestamp"]).strftime(
"%H:%M:%S"
)
# 根据消息类型和来源显示不同颜色
if data["type"] == "output":
if data["source"] == "stdout":
# 标准输出 - 白色
print(f"\033[37m[{timestamp}] {data['content']}\033[0m", end="")
elif data["source"] == "stderr":
# 错误输出 - 红色
print(f"\033[31m[{timestamp}] {data['content']}\033[0m", end="")
elif data["type"] == "connection":
# 连接消息 - 绿色
print(f"\033[32m[{timestamp}] {data['message']}\033[0m")
elif data["type"] == "history":
# 历史记录 - 蓝色
print(
f"\033[34m[{timestamp}] 收到历史记录: {len(data['messages'])} 条消息\033[0m"
)
else:
# 其他消息 - 黄色
content = data.get("content", "")
if isinstance(content, dict):
content = json.dumps(content, ensure_ascii=False)
print(f"\033[33m[{timestamp}] [{data['type']}] {content}\033[0m")
except websockets.exceptions.ConnectionClosed:
print("与服务器的连接已断开")
except Exception as e:
print(f"发生错误: {e}")
if __name__ == "__main__":
# 运行客户端
asyncio.run(connect_to_output_manager())