mirror of https://github.com/Mai-with-u/MaiBot.git
Revert "优雅关闭,复归!(感谢豆包的代码)"
parent
4ba8c621c5
commit
0819b74931
76
bot.py
76
bot.py
|
|
@ -2,13 +2,10 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# import shutil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
import platform
|
import platform
|
||||||
import traceback
|
import traceback
|
||||||
import signal
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
|
||||||
|
|
@ -17,8 +14,7 @@ from src.common.crash_logger import install_crash_handler
|
||||||
from src.main import MainSystem
|
from src.main import MainSystem
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from asyncio import CancelledError
|
from src.manager.async_task_manager import async_task_manager
|
||||||
# from src.manager.async_task_manager import async_task_manager
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
@ -38,19 +34,15 @@ driver = None
|
||||||
app = None
|
app = None
|
||||||
loop = None
|
loop = None
|
||||||
|
|
||||||
shutdown_requested = False # 新增全局变量
|
# shutdown_requested = False # 新增全局变量
|
||||||
|
|
||||||
|
|
||||||
async def request_shutdown() -> bool:
|
async def request_shutdown() -> bool:
|
||||||
"""请求关闭程序"""
|
"""请求关闭程序"""
|
||||||
global shutdown_requested
|
|
||||||
if shutdown_requested:
|
|
||||||
return True
|
|
||||||
shutdown_requested = True
|
|
||||||
try:
|
try:
|
||||||
if loop and not loop.is_closed():
|
if loop and not loop.is_closed():
|
||||||
try:
|
try:
|
||||||
await graceful_shutdown()
|
loop.run_until_complete(graceful_shutdown())
|
||||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||||
return False
|
return False
|
||||||
|
|
@ -73,38 +65,6 @@ def easter_egg():
|
||||||
print(rainbow_text)
|
print(rainbow_text)
|
||||||
|
|
||||||
|
|
||||||
# def init_config():
|
|
||||||
# # 初次启动检测
|
|
||||||
# if not os.path.exists("config/bot_config.toml"):
|
|
||||||
# logger.warning("检测到bot_config.toml不存在,正在从模板复制")
|
|
||||||
|
|
||||||
# # 检查config目录是否存在
|
|
||||||
# if not os.path.exists("config"):
|
|
||||||
# os.makedirs("config")
|
|
||||||
# logger.info("创建config目录")
|
|
||||||
|
|
||||||
# shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
|
|
||||||
# logger.info("复制完成,请修改config/bot_config.toml和.env中的配置后重新启动")
|
|
||||||
# if not os.path.exists("config/lpmm_config.toml"):
|
|
||||||
# logger.warning("检测到lpmm_config.toml不存在,正在从模板复制")
|
|
||||||
|
|
||||||
# # 检查config目录是否存在
|
|
||||||
# if not os.path.exists("config"):
|
|
||||||
# os.makedirs("config")
|
|
||||||
# logger.info("创建config目录")
|
|
||||||
|
|
||||||
# shutil.copy("template/lpmm_config_template.toml", "config/lpmm_config.toml")
|
|
||||||
# logger.info("复制完成,请修改config/lpmm_config.toml和.env中的配置后重新启动")
|
|
||||||
|
|
||||||
|
|
||||||
# def init_env():
|
|
||||||
# # 检测.env文件是否存在
|
|
||||||
# if not os.path.exists(".env"):
|
|
||||||
# logger.error("检测到.env文件不存在")
|
|
||||||
# shutil.copy("template/template.env", "./.env")
|
|
||||||
# logger.info("已从template/template.env复制创建.env,请修改配置后重新启动")
|
|
||||||
|
|
||||||
|
|
||||||
def load_env():
|
def load_env():
|
||||||
# 直接加载生产环境变量配置
|
# 直接加载生产环境变量配置
|
||||||
if os.path.exists(".env"):
|
if os.path.exists(".env"):
|
||||||
|
|
@ -149,7 +109,10 @@ def scan_provider(env_config: dict):
|
||||||
async def graceful_shutdown():
|
async def graceful_shutdown():
|
||||||
try:
|
try:
|
||||||
logger.info("正在优雅关闭麦麦...")
|
logger.info("正在优雅关闭麦麦...")
|
||||||
# await async_task_manager.stop_and_wait_all_tasks()
|
|
||||||
|
# 停止所有异步任务
|
||||||
|
await async_task_manager.stop_and_wait_all_tasks()
|
||||||
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
@ -245,9 +208,9 @@ def raw_main():
|
||||||
|
|
||||||
check_eula()
|
check_eula()
|
||||||
print("检查EULA和隐私条款完成")
|
print("检查EULA和隐私条款完成")
|
||||||
|
|
||||||
easter_egg()
|
easter_egg()
|
||||||
# init_config()
|
|
||||||
# init_env()
|
|
||||||
load_env()
|
load_env()
|
||||||
|
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
|
|
@ -257,11 +220,6 @@ def raw_main():
|
||||||
return MainSystem()
|
return MainSystem()
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
|
||||||
"""信号处理函数,捕获SIGINT和SIGTERM信号"""
|
|
||||||
loop.create_task(request_shutdown())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
exit_code = 0 # 用于记录程序最终的退出状态
|
exit_code = 0 # 用于记录程序最终的退出状态
|
||||||
try:
|
try:
|
||||||
|
|
@ -272,22 +230,18 @@ if __name__ == "__main__":
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
# 新增:设置信号处理
|
|
||||||
if platform.system().lower() != "windows":
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行初始化和任务调度
|
# 执行初始化和任务调度
|
||||||
loop.run_until_complete(main_system.initialize())
|
loop.run_until_complete(main_system.initialize())
|
||||||
loop.run_until_complete(main_system.schedule_tasks())
|
loop.run_until_complete(main_system.schedule_tasks())
|
||||||
loop.run_forever()
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
# loop.run_until_complete(global_api.stop())
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
try:
|
if loop and not loop.is_closed():
|
||||||
loop.run_until_complete(request_shutdown())
|
try:
|
||||||
except CancelledError as e:
|
loop.run_until_complete(graceful_shutdown())
|
||||||
logger.error(f"优雅关闭时发生错误: {e}")
|
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||||
|
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||||
# 新增:检测外部请求关闭
|
# 新增:检测外部请求关闭
|
||||||
|
|
||||||
# except Exception as e: # 将主异常捕获移到外层 try...except
|
# except Exception as e: # 将主异常捕获移到外层 try...except
|
||||||
|
|
|
||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
|
|
@ -113,31 +113,3 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||||
if not tool_class:
|
if not tool_class:
|
||||||
return None
|
return None
|
||||||
return tool_class()
|
return tool_class()
|
||||||
|
|
||||||
|
|
||||||
def run_lua_code(lua_code: str):
|
|
||||||
"""兼容Lua代码运行(小工具)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
lua_code (str): Lua代码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_LuaTable: Lua运行时的全局变量
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from lupa import LuaRuntime
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError("无法导入lupa模块,请确保已安装lupa库(可通过pip安装)") from e
|
|
||||||
|
|
||||||
lua = LuaRuntime(unpack_returned_tuples=True)
|
|
||||||
try:
|
|
||||||
lua.execute(lua_code)
|
|
||||||
return lua.globals()
|
|
||||||
except Exception as e:
|
|
||||||
# 返回包含错误信息的字典而不是直接抛出异常,保持函数接口的稳定性
|
|
||||||
return {
|
|
||||||
"__error__": True,
|
|
||||||
"type": type(e).__name__,
|
|
||||||
"message": "Lua 代码执行出错,请检查代码是否正确。",
|
|
||||||
"lua_code": lua_code, # 可选:返回出错的代码片段便于调试
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
from src.tools.tool_can_use.base_tool import BaseTool
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -31,15 +31,13 @@ class CompareNumbersTool(BaseTool):
|
||||||
try:
|
try:
|
||||||
num1 = function_args.get("num1")
|
num1 = function_args.get("num1")
|
||||||
num2 = function_args.get("num2")
|
num2 = function_args.get("num2")
|
||||||
if not (isinstance(num1, (int, float)) and isinstance(num2, (int, float))):
|
|
||||||
raise ValueError("参数'num1'和'num2'必须为数字")
|
if num1 > num2:
|
||||||
lua_code = """
|
result = f"{num1} 大于 {num2}"
|
||||||
function CompareNumbers(a, b)
|
elif num1 < num2:
|
||||||
return a .. (a > b and " 大于 " or a < b and " 小于 " or " 等于 ") .. b
|
result = f"{num1} 小于 {num2}"
|
||||||
end
|
else:
|
||||||
"""
|
result = f"{num1} 等于 {num2}"
|
||||||
CompareNumbers = run_lua_code(lua_code).CompareNumbers
|
|
||||||
result = CompareNumbers(num1, num2)
|
|
||||||
|
|
||||||
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
|
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
logger = get_logger("get_time_date")
|
|
||||||
|
|
||||||
|
|
||||||
class GetCurrentDateTimeTool(BaseTool):
|
|
||||||
"""获取当前时间、日期、年份和星期的工具"""
|
|
||||||
|
|
||||||
name = "get_current_date_time"
|
|
||||||
description = "当有人询问或者涉及到具体时间或者日期的时候,必须使用这个工具"
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行获取当前时间、日期、年份和星期
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数(此工具不使用)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
lua_code = """
|
|
||||||
GetCurrentDateTime = function()
|
|
||||||
return ("当前时间: %s, 日期: %s, 年份: %s, 星期: %s"):format(os.date("%H:%M:%S"), os.date("%Y-%m-%d"), os.date("%Y"), os.date("%A"))
|
|
||||||
end
|
|
||||||
"""
|
|
||||||
GetCurrentDateTime = run_lua_code(lua_code).GetCurrentDateTime
|
|
||||||
return {
|
|
||||||
"name": "get_current_date_time",
|
|
||||||
"content": GetCurrentDateTime(),
|
|
||||||
}
|
|
||||||
|
|
@ -1,54 +0,0 @@
|
||||||
import re
|
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
logger = get_module_logger("letter_count_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class LetterCountTool(BaseTool):
|
|
||||||
"""数单词内某字母个数的工具"""
|
|
||||||
|
|
||||||
name = "word_letter_count"
|
|
||||||
description = "当有人询问你或者提到某个英文单词内有多少个某字母时,可以使用这个工具来数字母(如果传入的是中文,传入之前要将中文转为英文)"
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"word": {"type": "string", "description": "英文单词"},
|
|
||||||
"letter": {"type": "string", "description": "英文字母"},
|
|
||||||
},
|
|
||||||
"required": ["word", "letter"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
执行数数该单词的某字母个数的函数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
message_txt: 原始消息文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
word = function_args.get("word")
|
|
||||||
letter = function_args.get("letter")
|
|
||||||
if re.match(r"^[a-zA-Z]$", letter) is None:
|
|
||||||
raise ValueError("请输入单个英文字母")
|
|
||||||
lua_code = """
|
|
||||||
function LetterCount(inputStr, targetLetter)
|
|
||||||
local lower = (inputStr:gsub("[^"..targetLetter:lower().."]", "")):len()
|
|
||||||
local upper = (inputStr:gsub("[^"..targetLetter:upper().."]", "")):len()
|
|
||||||
return string.format("字母 %s 在字符串 %s 中出现的次数:%d个(小写), %d个(大写)", targetLetter, inputStr, lower, upper)
|
|
||||||
end
|
|
||||||
"""
|
|
||||||
LetterCount = run_lua_code(lua_code).LetterCount
|
|
||||||
return {"name": self.name, "content": LetterCount(word, letter)}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数字母失败: {str(e)}")
|
|
||||||
return {"name": self.name, "content": f"数字母失败: {str(e)}"}
|
|
||||||
|
|
||||||
|
|
||||||
# 注册工具
|
|
||||||
# register_tool(LetterCountTool)
|
|
||||||
Loading…
Reference in New Issue