mirror of https://github.com/Mai-with-u/MaiBot.git
添加输出管理器功能
parent
62523409d1
commit
a96bc53178
84
bot.py
84
bot.py
|
|
@ -8,18 +8,31 @@ from dotenv import load_dotenv
|
|||
from loguru import logger
|
||||
from nonebot.adapters.onebot.v11 import Adapter
|
||||
import platform
|
||||
from src.common.output_manager import OutputManager
|
||||
|
||||
# 获取没有加载env时的环境变量
|
||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||
|
||||
# 初始化输出管理器
|
||||
output_manager = OutputManager()
|
||||
|
||||
|
||||
def easter_egg():
|
||||
# 彩蛋
|
||||
from colorama import init, Fore
|
||||
|
||||
init()
|
||||
text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
||||
text = (
|
||||
"多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||
)
|
||||
rainbow_colors = [
|
||||
Fore.RED,
|
||||
Fore.YELLOW,
|
||||
Fore.GREEN,
|
||||
Fore.CYAN,
|
||||
Fore.BLUE,
|
||||
Fore.MAGENTA,
|
||||
]
|
||||
rainbow_text = ""
|
||||
for i, char in enumerate(text):
|
||||
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
|
||||
|
|
@ -37,7 +50,9 @@ def init_config():
|
|||
logger.info("创建config目录")
|
||||
|
||||
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
|
||||
logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动")
|
||||
logger.info(
|
||||
"复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动"
|
||||
)
|
||||
|
||||
|
||||
def init_env():
|
||||
|
|
@ -66,16 +81,15 @@ def load_env():
|
|||
# 使用闭包实现对加载器的横向扩展,避免大量重复判断
|
||||
def prod():
|
||||
logger.success("加载生产环境变量配置")
|
||||
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||
load_dotenv(
|
||||
".env.prod", override=True
|
||||
) # override=True 允许覆盖已存在的环境变量
|
||||
|
||||
def dev():
|
||||
logger.success("加载开发环境变量配置")
|
||||
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
|
||||
|
||||
fn_map = {
|
||||
"prod": prod,
|
||||
"dev": dev
|
||||
}
|
||||
fn_map = {"prod": prod, "dev": dev}
|
||||
|
||||
env = os.getenv("ENVIRONMENT")
|
||||
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
|
||||
|
|
@ -85,11 +99,17 @@ def load_env():
|
|||
|
||||
elif os.path.exists(f".env.{env}"):
|
||||
logger.success(f"加载{env}环境变量配置")
|
||||
load_dotenv(f".env.{env}", override=True) # override=True 允许覆盖已存在的环境变量
|
||||
load_dotenv(
|
||||
f".env.{env}", override=True
|
||||
) # override=True 允许覆盖已存在的环境变量
|
||||
|
||||
else:
|
||||
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
||||
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
||||
logger.error(
|
||||
f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
|
||||
)
|
||||
RuntimeError(
|
||||
f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在"
|
||||
)
|
||||
|
||||
|
||||
def load_logger():
|
||||
|
|
@ -97,10 +117,10 @@ def load_logger():
|
|||
logger.add(
|
||||
sys.stderr,
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
|
||||
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
|
||||
"#777777>-</> <level>{message}</level>",
|
||||
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
|
||||
"#777777>-</> <level>{message}</level>",
|
||||
colorize=True,
|
||||
level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别,默认为INFO
|
||||
level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别,默认为INFO
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -131,25 +151,30 @@ def scan_provider(env_config: dict):
|
|||
# 检查每个 provider 是否同时存在 url 和 key
|
||||
for provider_name, config in provider.items():
|
||||
if config["url"] is None or config["key"] is None:
|
||||
logger.error(
|
||||
f"provider 内容:{config}\n"
|
||||
f"env_config 内容:{env_config}"
|
||||
logger.error(f"provider 内容:{config}\n" f"env_config 内容:{env_config}")
|
||||
raise ValueError(
|
||||
f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量"
|
||||
)
|
||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 利用 TZ 环境变量设定程序工作的时区
|
||||
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
|
||||
if platform.system().lower() != 'windows':
|
||||
if platform.system().lower() != "windows":
|
||||
time.tzset()
|
||||
|
||||
# 启动输出捕获
|
||||
output_manager.start_capture()
|
||||
logger.info("输出管理器已启动")
|
||||
|
||||
easter_egg()
|
||||
load_logger()
|
||||
init_config()
|
||||
init_env()
|
||||
load_env()
|
||||
load_logger()
|
||||
|
||||
# 启动彩蛋
|
||||
easter_egg()
|
||||
|
||||
env_config = {key: os.getenv(key) for key in os.environ}
|
||||
scan_provider(env_config)
|
||||
|
|
@ -171,4 +196,23 @@ if __name__ == "__main__":
|
|||
# 加载插件
|
||||
nonebot.load_plugins("src/plugins")
|
||||
|
||||
# 在启动时,注册清理函数
|
||||
import atexit
|
||||
|
||||
def cleanup():
|
||||
# 停止输出捕获
|
||||
output_manager.stop_capture()
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
# 启动
|
||||
nonebot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 利用 TZ 环境变量设定程序工作的时区
|
||||
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
|
||||
if platform.system().lower() != "windows":
|
||||
time.tzset()
|
||||
|
||||
init()
|
||||
|
|
|
|||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
|
|
@ -0,0 +1,276 @@
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import websockets
|
||||
from io import StringIO
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional, Set, Callable
|
||||
import logging
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class OutputCapture:
|
||||
"""捕获标准输出和错误输出的类"""
|
||||
|
||||
def __init__(self, original_stream):
|
||||
self.original_stream = original_stream
|
||||
self.buffer = StringIO()
|
||||
self.listeners: List[Callable[[str], None]] = []
|
||||
|
||||
def write(self, data):
|
||||
# 写入原始流
|
||||
self.original_stream.write(data)
|
||||
# 写入缓冲区
|
||||
self.buffer.write(data)
|
||||
# 通知所有监听器
|
||||
for listener in self.listeners:
|
||||
listener(data)
|
||||
|
||||
def flush(self):
|
||||
self.original_stream.flush()
|
||||
self.buffer.flush()
|
||||
|
||||
def add_listener(self, listener: Callable[[str], None]):
|
||||
"""添加输出监听器"""
|
||||
self.listeners.append(listener)
|
||||
|
||||
def remove_listener(self, listener: Callable[[str], None]):
|
||||
"""移除输出监听器"""
|
||||
if listener in self.listeners:
|
||||
self.listeners.remove(listener)
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
"""WebSocket服务器,用于将捕获的输出发送到前端"""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8765):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.clients: Set[websockets.WebSocketServerProtocol] = set()
|
||||
self.running = False
|
||||
self.loop = None
|
||||
|
||||
async def handler(self, websocket, path):
|
||||
"""处理WebSocket连接"""
|
||||
self.clients.add(websocket)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "connection",
|
||||
"message": "已连接到MaiMBot输出管理器",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# 发送历史记录
|
||||
history = OutputManager().get_message_history()
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# 保持连接直到客户端断开
|
||||
while True:
|
||||
await websocket.recv()
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
self.clients.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""广播消息到所有连接的客户端"""
|
||||
if not self.clients:
|
||||
return
|
||||
|
||||
message_json = json.dumps(message)
|
||||
await asyncio.gather(*[client.send(message_json) for client in self.clients])
|
||||
|
||||
def start(self):
|
||||
"""启动WebSocket服务器"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
# 在新线程中运行事件循环
|
||||
def run_server():
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
start_server = websockets.serve(self.handler, self.host, self.port)
|
||||
|
||||
self.server = self.loop.run_until_complete(start_server)
|
||||
logger.info(f"WebSocket服务器已启动,运行在 ws://{self.host}:{self.port}")
|
||||
|
||||
try:
|
||||
self.loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
self.server.close()
|
||||
self.loop.run_until_complete(self.server.wait_closed())
|
||||
self.loop.close()
|
||||
self.running = False
|
||||
|
||||
threading.Thread(target=run_server, daemon=True).start()
|
||||
|
||||
def stop(self):
|
||||
"""停止WebSocket服务器"""
|
||||
if not self.running or not self.loop:
|
||||
return
|
||||
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
self.running = False
|
||||
logger.info("WebSocket服务器已停止")
|
||||
|
||||
|
||||
class OutputManager:
|
||||
"""单例模式的输出管理器"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(OutputManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 初始化属性
|
||||
self.stdout_capture = None
|
||||
self.stderr_capture = None
|
||||
self.websocket_server = WebSocketServer()
|
||||
self.message_history: List[Dict[str, Any]] = []
|
||||
self.max_history_size = 1000 # 最大历史记录数量
|
||||
self._initialized = True
|
||||
|
||||
def start_capture(self):
|
||||
"""开始捕获标准输出和错误输出"""
|
||||
if self.stdout_capture is None:
|
||||
self.stdout_capture = OutputCapture(sys.stdout)
|
||||
sys.stdout = self.stdout_capture
|
||||
self.stdout_capture.add_listener(
|
||||
lambda data: self._process_output("stdout", data)
|
||||
)
|
||||
|
||||
if self.stderr_capture is None:
|
||||
self.stderr_capture = OutputCapture(sys.stderr)
|
||||
sys.stderr = self.stderr_capture
|
||||
self.stderr_capture.add_listener(
|
||||
lambda data: self._process_output("stderr", data)
|
||||
)
|
||||
|
||||
# 启动WebSocket服务器
|
||||
self.websocket_server.start()
|
||||
|
||||
logger.info("输出捕获已启动")
|
||||
|
||||
def stop_capture(self):
|
||||
"""停止捕获标准输出和错误输出"""
|
||||
if self.stdout_capture is not None:
|
||||
sys.stdout = self.stdout_capture.original_stream
|
||||
self.stdout_capture = None
|
||||
|
||||
if self.stderr_capture is not None:
|
||||
sys.stderr = self.stderr_capture.original_stream
|
||||
self.stderr_capture = None
|
||||
|
||||
# 停止WebSocket服务器
|
||||
self.websocket_server.stop()
|
||||
|
||||
logger.info("输出捕获已停止")
|
||||
|
||||
def _process_output(self, source: str, data: str):
|
||||
"""处理捕获的输出"""
|
||||
if not data.strip():
|
||||
return
|
||||
|
||||
message = {
|
||||
"type": "output",
|
||||
"source": source,
|
||||
"content": data,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 添加到历史记录
|
||||
self.message_history.append(message)
|
||||
|
||||
# 限制历史记录大小
|
||||
if len(self.message_history) > self.max_history_size:
|
||||
self.message_history = self.message_history[-self.max_history_size :]
|
||||
|
||||
# 通过WebSocket发送
|
||||
if self.websocket_server.running and self.websocket_server.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket_server.broadcast(message), self.websocket_server.loop
|
||||
)
|
||||
|
||||
def get_message_history(self) -> List[Dict[str, Any]]:
|
||||
"""获取消息历史记录"""
|
||||
return self.message_history
|
||||
|
||||
def send_custom_message(
|
||||
self, message_type: str, content: Any, source: str = "custom"
|
||||
):
|
||||
"""发送自定义消息"""
|
||||
message = {
|
||||
"type": message_type,
|
||||
"source": source,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 添加到历史记录
|
||||
self.message_history.append(message)
|
||||
|
||||
# 限制历史记录大小
|
||||
if len(self.message_history) > self.max_history_size:
|
||||
self.message_history = self.message_history[-self.max_history_size :]
|
||||
|
||||
# 通过WebSocket发送
|
||||
if self.websocket_server.running and self.websocket_server.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket_server.broadcast(message), self.websocket_server.loop
|
||||
)
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
# 启动输出管理器
|
||||
output_manager = OutputManager()
|
||||
output_manager.start_capture()
|
||||
|
||||
# 模拟程序输出
|
||||
print("这是一条标准输出信息")
|
||||
print("这是另一条标准输出信息")
|
||||
|
||||
# 模拟错误输出
|
||||
sys.stderr.write("这是一条错误信息\n")
|
||||
|
||||
# 发送自定义消息
|
||||
output_manager.send_custom_message(
|
||||
"status", {"progress": 50, "status": "processing"}, "task_manager"
|
||||
)
|
||||
|
||||
# 保持程序运行
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
output_manager.stop_capture()
|
||||
print("程序已退出")
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import time
|
||||
import sys
|
||||
from src.common.output_manager import OutputManager
|
||||
|
||||
|
||||
def main():
|
||||
# 初始化输出管理器
|
||||
output_manager = OutputManager()
|
||||
|
||||
# 启动输出捕获
|
||||
output_manager.start_capture()
|
||||
print("输出管理器已启动,WebSocket服务器运行在 ws://localhost:8765")
|
||||
print("您可以使用任何WebSocket客户端连接到此地址来接收输出")
|
||||
|
||||
# 发送一些测试消息
|
||||
print("这是一条标准输出消息")
|
||||
sys.stderr.write("这是一条错误输出消息\n")
|
||||
|
||||
# 发送自定义消息
|
||||
output_manager.send_custom_message(
|
||||
"status", {"progress": 50, "status": "processing"}, "task_manager"
|
||||
)
|
||||
|
||||
# 每隔1秒发送一条消息
|
||||
try:
|
||||
count = 0
|
||||
while True:
|
||||
count += 1
|
||||
print(f"测试消息 #{count}")
|
||||
if count % 5 == 0:
|
||||
sys.stderr.write(f"错误消息 #{count}\n")
|
||||
|
||||
# 每10条消息发送一次自定义消息
|
||||
if count % 10 == 0:
|
||||
output_manager.send_custom_message(
|
||||
"progress", {"count": count, "percentage": count % 100}, "counter"
|
||||
)
|
||||
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
# 停止输出捕获
|
||||
output_manager.stop_capture()
|
||||
print("程序已退出")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
async def connect_to_output_manager():
|
||||
"""连接到OutputManager的WebSocket服务器并接收消息"""
|
||||
|
||||
uri = "ws://localhost:8765"
|
||||
print(f"正在连接到 {uri}...")
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("已连接到服务器")
|
||||
|
||||
while True:
|
||||
# 接收消息
|
||||
message = await websocket.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
# 格式化时间戳
|
||||
timestamp = datetime.fromisoformat(data["timestamp"]).strftime(
|
||||
"%H:%M:%S"
|
||||
)
|
||||
|
||||
# 根据消息类型和来源显示不同颜色
|
||||
if data["type"] == "output":
|
||||
if data["source"] == "stdout":
|
||||
# 标准输出 - 白色
|
||||
print(f"\033[37m[{timestamp}] {data['content']}\033[0m", end="")
|
||||
elif data["source"] == "stderr":
|
||||
# 错误输出 - 红色
|
||||
print(f"\033[31m[{timestamp}] {data['content']}\033[0m", end="")
|
||||
elif data["type"] == "connection":
|
||||
# 连接消息 - 绿色
|
||||
print(f"\033[32m[{timestamp}] {data['message']}\033[0m")
|
||||
elif data["type"] == "history":
|
||||
# 历史记录 - 蓝色
|
||||
print(
|
||||
f"\033[34m[{timestamp}] 收到历史记录: {len(data['messages'])} 条消息\033[0m"
|
||||
)
|
||||
else:
|
||||
# 其他消息 - 黄色
|
||||
content = data.get("content", "")
|
||||
if isinstance(content, dict):
|
||||
content = json.dumps(content, ensure_ascii=False)
|
||||
print(f"\033[33m[{timestamp}] [{data['type']}] {content}\033[0m")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print("与服务器的连接已断开")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行客户端
|
||||
asyncio.run(connect_to_output_manager())
|
||||
Loading…
Reference in New Issue