mirror of https://github.com/Mai-with-u/MaiBot.git
commit
7f2570a429
76
bot.py
76
bot.py
|
|
@ -2,10 +2,13 @@ import asyncio
|
|||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
# import shutil
|
||||
from pathlib import Path
|
||||
import time
|
||||
import platform
|
||||
import traceback
|
||||
import signal
|
||||
from dotenv import load_dotenv
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
|
|
@ -14,7 +17,8 @@ from src.common.crash_logger import install_crash_handler
|
|||
from src.main import MainSystem
|
||||
from rich.traceback import install
|
||||
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from asyncio import CancelledError
|
||||
# from src.manager.async_task_manager import async_task_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
|
@ -34,15 +38,19 @@ driver = None
|
|||
app = None
|
||||
loop = None
|
||||
|
||||
# shutdown_requested = False # 新增全局变量
|
||||
shutdown_requested = False # 新增全局变量
|
||||
|
||||
|
||||
async def request_shutdown() -> bool:
|
||||
"""请求关闭程序"""
|
||||
global shutdown_requested
|
||||
if shutdown_requested:
|
||||
return True
|
||||
shutdown_requested = True
|
||||
try:
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.run_until_complete(graceful_shutdown())
|
||||
await graceful_shutdown()
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
return False
|
||||
|
|
@ -65,6 +73,38 @@ def easter_egg():
|
|||
print(rainbow_text)
|
||||
|
||||
|
||||
# def init_config():
|
||||
# # 初次启动检测
|
||||
# if not os.path.exists("config/bot_config.toml"):
|
||||
# logger.warning("检测到bot_config.toml不存在,正在从模板复制")
|
||||
|
||||
# # 检查config目录是否存在
|
||||
# if not os.path.exists("config"):
|
||||
# os.makedirs("config")
|
||||
# logger.info("创建config目录")
|
||||
|
||||
# shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
|
||||
# logger.info("复制完成,请修改config/bot_config.toml和.env中的配置后重新启动")
|
||||
# if not os.path.exists("config/lpmm_config.toml"):
|
||||
# logger.warning("检测到lpmm_config.toml不存在,正在从模板复制")
|
||||
|
||||
# # 检查config目录是否存在
|
||||
# if not os.path.exists("config"):
|
||||
# os.makedirs("config")
|
||||
# logger.info("创建config目录")
|
||||
|
||||
# shutil.copy("template/lpmm_config_template.toml", "config/lpmm_config.toml")
|
||||
# logger.info("复制完成,请修改config/lpmm_config.toml和.env中的配置后重新启动")
|
||||
|
||||
|
||||
# def init_env():
|
||||
# # 检测.env文件是否存在
|
||||
# if not os.path.exists(".env"):
|
||||
# logger.error("检测到.env文件不存在")
|
||||
# shutil.copy("template/template.env", "./.env")
|
||||
# logger.info("已从template/template.env复制创建.env,请修改配置后重新启动")
|
||||
|
||||
|
||||
def load_env():
|
||||
# 直接加载生产环境变量配置
|
||||
if os.path.exists(".env"):
|
||||
|
|
@ -109,10 +149,7 @@ def scan_provider(env_config: dict):
|
|||
async def graceful_shutdown():
|
||||
try:
|
||||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
# 停止所有异步任务
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
|
||||
# await async_task_manager.stop_and_wait_all_tasks()
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
|
@ -208,9 +245,9 @@ def raw_main():
|
|||
|
||||
check_eula()
|
||||
print("检查EULA和隐私条款完成")
|
||||
|
||||
easter_egg()
|
||||
|
||||
# init_config()
|
||||
# init_env()
|
||||
load_env()
|
||||
|
||||
env_config = {key: os.getenv(key) for key in os.environ}
|
||||
|
|
@ -220,6 +257,11 @@ def raw_main():
|
|||
return MainSystem()
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""信号处理函数,捕获SIGINT和SIGTERM信号"""
|
||||
loop.create_task(request_shutdown())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = 0 # 用于记录程序最终的退出状态
|
||||
try:
|
||||
|
|
@ -230,18 +272,22 @@ if __name__ == "__main__":
|
|||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 新增:设置信号处理
|
||||
if platform.system().lower() != "windows":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
# 执行初始化和任务调度
|
||||
loop.run_until_complete(main_system.initialize())
|
||||
loop.run_until_complete(main_system.schedule_tasks())
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
# loop.run_until_complete(global_api.stop())
|
||||
logger.warning("收到中断信号,正在优雅关闭...")
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.run_until_complete(graceful_shutdown())
|
||||
except Exception as ge: # 捕捉优雅关闭时可能发生的错误
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
try:
|
||||
loop.run_until_complete(request_shutdown())
|
||||
except CancelledError as e:
|
||||
logger.error(f"优雅关闭时发生错误: {e}")
|
||||
# 新增:检测外部请求关闭
|
||||
|
||||
# except Exception as e: # 将主异常捕获移到外层 try...except
|
||||
|
|
|
|||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
|
|
@ -113,3 +113,31 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
|||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
|
||||
|
||||
def run_lua_code(lua_code: str):
|
||||
"""兼容Lua代码运行(小工具)
|
||||
|
||||
Args:
|
||||
lua_code (str): Lua代码
|
||||
|
||||
Returns:
|
||||
_LuaTable: Lua运行时的全局变量
|
||||
"""
|
||||
try:
|
||||
from lupa import LuaRuntime
|
||||
except ImportError as e:
|
||||
raise ImportError("无法导入lupa模块,请确保已安装lupa库(可通过pip安装)") from e
|
||||
|
||||
lua = LuaRuntime(unpack_returned_tuples=True)
|
||||
try:
|
||||
lua.execute(lua_code)
|
||||
return lua.globals()
|
||||
except Exception as e:
|
||||
# 返回包含错误信息的字典而不是直接抛出异常,保持函数接口的稳定性
|
||||
return {
|
||||
"__error__": True,
|
||||
"type": type(e).__name__,
|
||||
"message": "Lua 代码执行出错,请检查代码是否正确。",
|
||||
"lua_code": lua_code, # 可选:返回出错的代码片段便于调试
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -31,13 +31,15 @@ class CompareNumbersTool(BaseTool):
|
|||
try:
|
||||
num1 = function_args.get("num1")
|
||||
num2 = function_args.get("num2")
|
||||
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
if not (isinstance(num1, (int, float)) and isinstance(num2, (int, float))):
|
||||
raise ValueError("参数'num1'和'num2'必须为数字")
|
||||
lua_code = """
|
||||
function CompareNumbers(a, b)
|
||||
return a .. (a > b and " 大于 " or a < b and " 小于 " or " 等于 ") .. b
|
||||
end
|
||||
"""
|
||||
CompareNumbers = run_lua_code(lua_code).CompareNumbers
|
||||
result = CompareNumbers(num1, num2)
|
||||
|
||||
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
||||
from src.common.logger_manager import get_logger
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = get_logger("get_time_date")
|
||||
|
||||
|
||||
class GetCurrentDateTimeTool(BaseTool):
|
||||
"""获取当前时间、日期、年份和星期的工具"""
|
||||
|
||||
name = "get_current_date_time"
|
||||
description = "当有人询问或者涉及到具体时间或者日期的时候,必须使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行获取当前时间、日期、年份和星期
|
||||
|
||||
Args:
|
||||
function_args: 工具参数(此工具不使用)
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
lua_code = """
|
||||
GetCurrentDateTime = function()
|
||||
return ("当前时间: %s, 日期: %s, 年份: %s, 星期: %s"):format(os.date("%H:%M:%S"), os.date("%Y-%m-%d"), os.date("%Y"), os.date("%A"))
|
||||
end
|
||||
"""
|
||||
GetCurrentDateTime = run_lua_code(lua_code).GetCurrentDateTime
|
||||
return {
|
||||
"name": "get_current_date_time",
|
||||
"content": GetCurrentDateTime(),
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import re
|
||||
from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code
|
||||
from src.common.logger import get_module_logger
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = get_module_logger("letter_count_tool")
|
||||
|
||||
|
||||
class LetterCountTool(BaseTool):
|
||||
"""数单词内某字母个数的工具"""
|
||||
|
||||
name = "word_letter_count"
|
||||
description = "当有人询问你或者提到某个英文单词内有多少个某字母时,可以使用这个工具来数字母(如果传入的是中文,传入之前要将中文转为英文)"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"word": {"type": "string", "description": "英文单词"},
|
||||
"letter": {"type": "string", "description": "英文字母"},
|
||||
},
|
||||
"required": ["word", "letter"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
执行数数该单词的某字母个数的函数
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
word = function_args.get("word")
|
||||
letter = function_args.get("letter")
|
||||
if re.match(r"^[a-zA-Z]$", letter) is None:
|
||||
raise ValueError("请输入单个英文字母")
|
||||
lua_code = """
|
||||
function LetterCount(inputStr, targetLetter)
|
||||
local lower = (inputStr:gsub("[^"..targetLetter:lower().."]", "")):len()
|
||||
local upper = (inputStr:gsub("[^"..targetLetter:upper().."]", "")):len()
|
||||
return string.format("字母 %s 在字符串 %s 中出现的次数:%d个(小写), %d个(大写)", targetLetter, inputStr, lower, upper)
|
||||
end
|
||||
"""
|
||||
LetterCount = run_lua_code(lua_code).LetterCount
|
||||
return {"name": self.name, "content": LetterCount(word, letter)}
|
||||
except Exception as e:
|
||||
logger.error(f"数字母失败: {str(e)}")
|
||||
return {"name": self.name, "content": f"数字母失败: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(LetterCountTool)
|
||||
Loading…
Reference in New Issue