diff --git a/bot.py b/bot.py index 3737279d..20299644 100644 --- a/bot.py +++ b/bot.py @@ -2,10 +2,13 @@ import asyncio import hashlib import os import sys + +# import shutil from pathlib import Path import time import platform import traceback +import signal from dotenv import load_dotenv from src.common.logger_manager import get_logger @@ -14,7 +17,8 @@ from src.common.crash_logger import install_crash_handler from src.main import MainSystem from rich.traceback import install -from src.manager.async_task_manager import async_task_manager +from asyncio import CancelledError +# from src.manager.async_task_manager import async_task_manager install(extra_lines=3) @@ -34,15 +38,19 @@ driver = None app = None loop = None -# shutdown_requested = False # 新增全局变量 +shutdown_requested = False # 新增全局变量 async def request_shutdown() -> bool: """请求关闭程序""" + global shutdown_requested + if shutdown_requested: + return True + shutdown_requested = True try: if loop and not loop.is_closed(): try: - loop.run_until_complete(graceful_shutdown()) + await graceful_shutdown() except Exception as ge: # 捕捉优雅关闭时可能发生的错误 logger.error(f"优雅关闭时发生错误: {ge}") return False @@ -65,6 +73,38 @@ def easter_egg(): print(rainbow_text) +# def init_config(): +# # 初次启动检测 +# if not os.path.exists("config/bot_config.toml"): +# logger.warning("检测到bot_config.toml不存在,正在从模板复制") + +# # 检查config目录是否存在 +# if not os.path.exists("config"): +# os.makedirs("config") +# logger.info("创建config目录") + +# shutil.copy("template/bot_config_template.toml", "config/bot_config.toml") +# logger.info("复制完成,请修改config/bot_config.toml和.env中的配置后重新启动") +# if not os.path.exists("config/lpmm_config.toml"): +# logger.warning("检测到lpmm_config.toml不存在,正在从模板复制") + +# # 检查config目录是否存在 +# if not os.path.exists("config"): +# os.makedirs("config") +# logger.info("创建config目录") + +# shutil.copy("template/lpmm_config_template.toml", "config/lpmm_config.toml") +# logger.info("复制完成,请修改config/lpmm_config.toml和.env中的配置后重新启动") + + +# def init_env(): +# # 检测.env文件是否存在 +# if not os.path.exists(".env"): +# logger.error("检测到.env文件不存在") +# shutil.copy("template/template.env", "./.env") +# logger.info("已从template/template.env复制创建.env,请修改配置后重新启动") + + def load_env(): # 直接加载生产环境变量配置 if os.path.exists(".env"): @@ -109,10 +149,7 @@ def scan_provider(env_config: dict): async def graceful_shutdown(): try: logger.info("正在优雅关闭麦麦...") - - # 停止所有异步任务 - await async_task_manager.stop_and_wait_all_tasks() - + # await async_task_manager.stop_and_wait_all_tasks() tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: task.cancel() @@ -208,9 +245,9 @@ def raw_main(): check_eula() print("检查EULA和隐私条款完成") - easter_egg() - + # init_config() + # init_env() load_env() env_config = {key: os.getenv(key) for key in os.environ} @@ -220,6 +257,11 @@ def raw_main(): return MainSystem() +def signal_handler(sig, frame): + """信号处理函数,捕获SIGINT和SIGTERM信号""" + loop.create_task(request_shutdown()) + + if __name__ == "__main__": exit_code = 0 # 用于记录程序最终的退出状态 try: @@ -230,18 +272,22 @@ if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + # 新增:设置信号处理 + if platform.system().lower() != "windows": + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + try: # 执行初始化和任务调度 loop.run_until_complete(main_system.initialize()) loop.run_until_complete(main_system.schedule_tasks()) + loop.run_forever() except KeyboardInterrupt: - # loop.run_until_complete(global_api.stop()) logger.warning("收到中断信号,正在优雅关闭...") - if loop and not loop.is_closed(): - try: - loop.run_until_complete(graceful_shutdown()) - except Exception as ge: # 捕捉优雅关闭时可能发生的错误 - logger.error(f"优雅关闭时发生错误: {ge}") + try: + loop.run_until_complete(request_shutdown()) + except CancelledError as e: + logger.error(f"优雅关闭时发生错误: {e}") # 新增:检测外部请求关闭 # except Exception as e: # 将主异常捕获移到外层 try...except diff --git a/requirements.txt b/requirements.txt index 0e60bc19..39de65a8 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/src/tools/tool_can_use/base_tool.py b/src/tools/tool_can_use/base_tool.py index 62697168..921af185 100644 --- a/src/tools/tool_can_use/base_tool.py +++ b/src/tools/tool_can_use/base_tool.py @@ -113,3 +113,31 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]: if not tool_class: return None return tool_class() + + +def run_lua_code(lua_code: str): + """兼容Lua代码运行(小工具) + + Args: + lua_code (str): Lua代码 + + Returns: + _LuaTable: Lua运行时的全局变量 + """ + try: + from lupa import LuaRuntime + except ImportError as e: + raise ImportError("无法导入lupa模块,请确保已安装lupa库(可通过pip安装)") from e + + lua = LuaRuntime(unpack_returned_tuples=True) + try: + lua.execute(lua_code) + return lua.globals() + except Exception as e: + # 返回包含错误信息的字典而不是直接抛出异常,保持函数接口的稳定性 + return { + "__error__": True, + "type": type(e).__name__, + "message": "Lua 代码执行出错,请检查代码是否正确。", + "lua_code": lua_code, # 可选:返回出错的代码片段便于调试 + } diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py index 72c7d7d1..08a50752 100644 --- a/src/tools/tool_can_use/compare_numbers_tool.py +++ b/src/tools/tool_can_use/compare_numbers_tool.py @@ -1,4 +1,4 @@ -from src.tools.tool_can_use.base_tool import BaseTool +from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code from src.common.logger import get_module_logger from typing import Any @@ -31,13 +31,15 @@ class CompareNumbersTool(BaseTool): try: num1 = function_args.get("num1") num2 = function_args.get("num2") - - if num1 > num2: - result = f"{num1} 大于 {num2}" - elif num1 < num2: - result = f"{num1} 小于 {num2}" - else: - result = f"{num1} 等于 {num2}" + if not (isinstance(num1, (int, float)) and isinstance(num2, (int, float))): + raise ValueError("参数'num1'和'num2'必须为数字") + lua_code = """ + function CompareNumbers(a, b) + return a .. (a > b and " 大于 " or a < b and " 小于 " or " 等于 ") .. b + end + """ + CompareNumbers = run_lua_code(lua_code).CompareNumbers + result = CompareNumbers(num1, num2) return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result} except Exception as e: diff --git a/src/tools/tool_can_use/get_time_date.py b/src/tools/tool_can_use/get_time_date.py new file mode 100644 index 00000000..f42c15d5 --- /dev/null +++ b/src/tools/tool_can_use/get_time_date.py @@ -0,0 +1,37 @@ +from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code +from src.common.logger_manager import get_logger +from typing import Dict, Any + +logger = get_logger("get_time_date") + + +class GetCurrentDateTimeTool(BaseTool): + """获取当前时间、日期、年份和星期的工具""" + + name = "get_current_date_time" + description = "当有人询问或者涉及到具体时间或者日期的时候,必须使用这个工具" + parameters = { + "type": "object", + "properties": {}, + "required": [], + } + + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + """执行获取当前时间、日期、年份和星期 + + Args: + function_args: 工具参数(此工具不使用) + + Returns: + Dict: 工具执行结果 + """ + lua_code = """ + GetCurrentDateTime = function() + return ("当前时间: %s, 日期: %s, 年份: %s, 星期: %s"):format(os.date("%H:%M:%S"), os.date("%Y-%m-%d"), os.date("%Y"), os.date("%A")) + end + """ + GetCurrentDateTime = run_lua_code(lua_code).GetCurrentDateTime + return { + "name": "get_current_date_time", + "content": GetCurrentDateTime(), + } diff --git a/src/tools/tool_can_use/letter_count_tool.py b/src/tools/tool_can_use/letter_count_tool.py new file mode 100644 index 00000000..021f2d9b --- /dev/null +++ b/src/tools/tool_can_use/letter_count_tool.py @@ -0,0 +1,54 @@ +import re +from src.tools.tool_can_use.base_tool import BaseTool, run_lua_code +from src.common.logger import get_module_logger +from typing import Dict, Any + +logger = get_module_logger("letter_count_tool") + + +class LetterCountTool(BaseTool): + """数单词内某字母个数的工具""" + + name = "word_letter_count" + description = "当有人询问你或者提到某个英文单词内有多少个某字母时,可以使用这个工具来数字母(如果传入的是中文,传入之前要将中文转为英文)" + parameters = { + "type": "object", + "properties": { + "word": {"type": "string", "description": "英文单词"}, + "letter": {"type": "string", "description": "英文字母"}, + }, + "required": ["word", "letter"], + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """ + 执行数数该单词的某字母个数的函数 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + word = function_args.get("word") + letter = function_args.get("letter") + if re.match(r"^[a-zA-Z]$", letter) is None: + raise ValueError("请输入单个英文字母") + lua_code = """ + function LetterCount(inputStr, targetLetter) + local lower = (inputStr:gsub("[^"..targetLetter:lower().."]", "")):len() + local upper = (inputStr:gsub("[^"..targetLetter:upper().."]", "")):len() + return string.format("字母 %s 在字符串 %s 中出现的次数:%d个(小写), %d个(大写)", targetLetter, inputStr, lower, upper) + end + """ + LetterCount = run_lua_code(lua_code).LetterCount + return {"name": self.name, "content": LetterCount(word, letter)} + except Exception as e: + logger.error(f"数字母失败: {str(e)}") + return {"name": self.name, "content": f"数字母失败: {str(e)}"} + + +# 注册工具 +# register_tool(LetterCountTool)