diff --git a/.gitignore b/.gitignore index 6d6652c..267374b 100644 --- a/.gitignore +++ b/.gitignore @@ -270,4 +270,5 @@ $RECYCLE.BIN/ *.lnk config.toml +config.toml.back test \ No newline at end of file diff --git a/assets/maimai.ico b/assets/maimai.ico new file mode 100644 index 0000000..578b11c Binary files /dev/null and b/assets/maimai.ico differ diff --git a/command_args.md b/command_args.md index 8a73651..cbbb582 100644 --- a/command_args.md +++ b/command_args.md @@ -5,7 +5,7 @@ Seg.type = "command" ## 群聊禁言 ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_BAN" + "name": "GROUP_BAN", "args": { "qq_id": "用户QQ号", "duration": "禁言时长(秒)" @@ -16,7 +16,7 @@ Seg.data: Dict[str, Any] = { ## 群聊全体禁言 ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_WHOLE_BAN" + "name": "GROUP_WHOLE_BAN", "args": { "enable": "是否开启全体禁言(True/False)" }, @@ -28,10 +28,20 @@ Seg.data: Dict[str, Any] = { ## 群聊踢人 ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_KICK" + "name": "GROUP_KICK", "args": { "qq_id": "用户QQ号", }, } ``` -其中,群聊ID将会通过Group_Info.group_id自动获取。 \ No newline at end of file +其中,群聊ID将会通过Group_Info.group_id自动获取。 + +## 戳一戳 +```python +Seg,.data: Dict[str, Any] = { + "name": "SEND_POKE", + "args": { + "qq_id": "目标QQ号" + } +} +``` \ No newline at end of file diff --git a/main.py b/main.py index 50cf968..2c71ef1 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,7 @@ async def message_recv(server_connection: Server.ServerConnection): async for raw_message in server_connection: logger.debug( f"{raw_message[:100]}..." - if (len(raw_message) > 100 and global_config.debug_level != "DEBUG") + if (len(raw_message) > 100 and global_config.debug.level != "DEBUG") else raw_message ) decoded_raw_message: dict = json.loads(raw_message) @@ -52,19 +52,23 @@ async def main(): async def napcat_server(): logger.info("正在启动adapter...") - async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server: - logger.info(f"Adapter已启动,监听地址: ws://{global_config.server_host}:{global_config.server_port}") + async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server: + logger.info( + f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" + ) await server.serve_forever() async def graceful_shutdown(): try: logger.info("正在关闭adapter...") - await mmc_stop_com() tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + if not task.done(): + task.cancel() + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15) + await mmc_stop_com() # 后置避免神秘exception + logger.info("Adapter已成功关闭") except Exception as e: logger.error(f"Adapter关闭中出现错误: {e}") diff --git a/requirements.txt b/requirements.txt index 41e8eb1..e5964ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ requests maim_message loguru pillow -tomli \ No newline at end of file +tomli +plyer \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index a35ff0e..bfae081 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,77 +1,78 @@ -from enum import Enum - - -class MetaEventType: - lifecycle = "lifecycle" # 生命周期 - - class Lifecycle: - connect = "connect" # 生命周期 - WebSocket 连接成功 - - heartbeat = "heartbeat" # 心跳 - - -class MessageType: # 接受消息大类 - private = "private" # 私聊消息 - - class Private: - friend = "friend" # 私聊消息 - 好友 - group = "group" # 私聊消息 - 群临时 - group_self = "group_self" # 私聊消息 - 群中自身发送 - other = "other" # 私聊消息 - 其他 - - group = "group" # 群聊消息 - - class Group: - normal = "normal" # 群聊消息 - 普通 - anonymous = "anonymous" # 群聊消息 - 匿名消息 - notice = "notice" # 群聊消息 - 系统提示 - - -class NoticeType: # 通知事件 - friend_recall = "friend_recall" # 私聊消息撤回 - group_recall = "group_recall" # 群聊消息撤回 - notify = "notify" - - class Notify: - poke = "poke" # 戳一戳 - - -class RealMessageType: # 实际消息分类 - text = "text" # 纯文本 - face = "face" # qq表情 - image = "image" # 图片 - record = "record" # 语音 - video = "video" # 视频 - at = "at" # @某人 - rps = "rps" # 猜拳魔法表情 - dice = "dice" # 骰子 - shake = "shake" # 私聊窗口抖动(只收) - poke = "poke" # 群聊戳一戳 - share = "share" # 链接分享(json形式) - reply = "reply" # 回复消息 - forward = "forward" # 转发消息 - node = "node" # 转发消息节点 - - -class MessageSentType: - private = "private" - - class Private: - friend = "friend" - group = "group" - - group = "group" - - class Group: - normal = "normal" - - -class CommandType(Enum): - """命令类型""" - - GROUP_BAN = "set_group_ban" # 禁言用户 - GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 - GROUP_KICK = "set_group_kick" # 踢出群聊 - - def __str__(self) -> str: - return self.value +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + + class Notify: + poke = "poke" # 戳一戳 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + + def __str__(self) -> str: + return self.value diff --git a/src/config.py b/src/config.py deleted file mode 100644 index ee13c98..0000000 --- a/src/config.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import sys -import tomli -import shutil -from .logger import logger -from typing import Optional - - -class Config: - platform: str = "qq" - nickname: Optional[str] = None - server_host: str = "localhost" - server_port: int = 8095 - napcat_heartbeat_interval: int = 30 - - def __init__(self): - self._get_config_path() - - def _get_config_path(self): - current_file_path = os.path.abspath(__file__) - src_path = os.path.dirname(current_file_path) - self.root_path = os.path.join(src_path, "..") - self.config_path = os.path.join(self.root_path, "config.toml") - - def load_config(self): # sourcery skip: extract-method, move-assign - include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"] - if not os.path.exists(self.config_path): - logger.error("配置文件不存在!") - logger.info("正在创建配置文件...") - shutil.copy( - os.path.join(self.root_path, "template", "template_config.toml"), - os.path.join(self.root_path, "config.toml"), - ) - logger.info("配置文件创建成功,请修改配置文件后重启程序。") - sys.exit(1) - with open(self.config_path, "rb") as f: - try: - raw_config = tomli.load(f) - except tomli.TOMLDecodeError as e: - logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") - sys.exit(1) - for key in include_configs: - if key not in raw_config: - logger.error(f"配置文件中缺少必需的字段: '{key}'") - logger.error("你的配置文件可能过时,请尝试手动更新配置文件。") - sys.exit(1) - - self.server_host = raw_config["Napcat_Server"].get("host", "localhost") - self.server_port = raw_config["Napcat_Server"].get("port", 8095) - self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30) - - self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost") - self.mai_port = raw_config["MaiBot_Server"].get("port", 8000) - self.platform = raw_config["MaiBot_Server"].get("platform_name") - if not self.platform: - logger.critical("请在配置文件中指定平台") - sys.exit(1) - - self.group_list_type: str = raw_config["Chat"].get("group_list_type") - self.group_list: list = raw_config["Chat"].get("group_list", []) - self.private_list_type: str = raw_config["Chat"].get("private_list_type") - self.private_list: list = raw_config["Chat"].get("private_list", []) - self.ban_user_id: list = raw_config["Chat"].get("ban_user_id", []) - self.enable_poke: bool = raw_config["Chat"].get("enable_poke", True) - if self.group_list_type not in ["whitelist", "blacklist"]: - logger.critical("请在配置文件中指定group_list_type或group_list_type填写错误") - sys.exit(1) - if self.private_list_type not in ["whitelist", "blacklist"]: - logger.critical("请在配置文件中指定private_list_type或private_list_type填写错误") - sys.exit(1) - - self.use_tts = raw_config["Voice"].get("use_tts", False) - - self.debug_level = raw_config["Debug"].get("level", "INFO") - if self.debug_level == "DEBUG": - logger.debug("原始配置文件内容:") - logger.debug(raw_config) - logger.debug("读取到的配置内容:") - logger.debug(f"平台: {self.platform}") - logger.debug(f"MaiBot服务器地址: {self.mai_host}:{self.mai_port}") - logger.debug(f"Napcat服务器地址: {self.server_host}:{self.server_port}") - logger.debug(f"心跳间隔: {self.napcat_heartbeat_interval}秒") - logger.debug(f"群聊列表类型: {self.group_list_type}") - logger.debug(f"群聊列表: {self.group_list}") - logger.debug(f"私聊列表类型: {self.private_list_type}") - logger.debug(f"私聊列表: {self.private_list}") - logger.debug(f"禁用用户ID列表: {self.ban_user_id}") - logger.debug(f"是否启用TTS: {self.use_tts}") - logger.debug(f"调试级别: {self.debug_level}") - - -global_config = Config() -global_config.load_config() diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..40ba89a --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,5 @@ +from .config import global_config + +__all__ = [ + "global_config", +] diff --git a/src/config/config.py b/src/config/config.py new file mode 100644 index 0000000..a219078 --- /dev/null +++ b/src/config/config.py @@ -0,0 +1,140 @@ +import os +from dataclasses import dataclass + +import tomlkit +import shutil + +from tomlkit import TOMLDocument +from tomlkit.items import Table +from ..logger import logger +from rich.traceback import install + +from src.config.config_base import ConfigBase +from src.config.official_configs import ( + ChatConfig, + DebugConfig, + MaiBotServerConfig, + NapcatServerConfig, + NicknameConfig, + VoiceConfig, +) + +install(extra_lines=3) + +TEMPLATE_DIR = "template" + + +def update_config(): + # 定义文件路径 + template_path = f"{TEMPLATE_DIR}/template_config.toml" + old_config_path = "config.toml" + new_config_path = "config.toml" + + # 检查配置文件是否存在 + if not os.path.exists(old_config_path): + logger.info("配置文件不存在,从模板创建新配置") + shutil.copy2(template_path, old_config_path) # 复制模板文件 + logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") + # 如果是新创建的配置文件,直接返回 + quit() + + # 读取旧配置文件和模板文件 + with open(old_config_path, "r", encoding="utf-8") as f: + old_config = tomlkit.load(f) + with open(template_path, "r", encoding="utf-8") as f: + new_config = tomlkit.load(f) + + # 检查version是否相同 + if old_config and "inner" in old_config and "inner" in new_config: + old_version = old_config["inner"].get("version") + new_version = new_config["inner"].get("version") + if old_version and new_version and old_version == new_version: + logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + return + else: + logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") + else: + logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + + # 备份文件名 + old_backup_path = "config.toml.back" + + # 备份旧配置文件 + shutil.move(old_config_path, old_backup_path) + logger.info(f"已备份旧配置文件到: {old_backup_path}") + + # 复制模板文件到配置目录 + shutil.copy2(template_path, new_config_path) + logger.info(f"已创建新配置文件: {new_config_path}") + + def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ + for key, value in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): + update_dict(target[key], value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + # 将旧配置的值更新到新配置中 + logger.info("开始合并新旧配置...") + update_dict(new_config, old_config) + + # 保存更新后的配置(保留注释和格式) + with open(new_config_path, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(new_config)) + logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + quit() + + +@dataclass +class Config(ConfigBase): + """总配置类""" + + nickname: NicknameConfig + napcat_server: NapcatServerConfig + maibot_server: MaiBotServerConfig + chat: ChatConfig + voice: VoiceConfig + debug: DebugConfig + + +def load_config(config_path: str) -> Config: + """ + 加载配置文件 + :param config_path: 配置文件路径 + :return: Config对象 + """ + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建Config对象 + try: + return Config.from_dict(config_data) + except Exception as e: + logger.critical("配置文件解析失败") + raise e + + +# 更新配置 +update_config() + +logger.info("正在品鉴配置文件...") +global_config = load_config(config_path="config.toml") +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/config_base.py b/src/config/config_base.py new file mode 100644 index 0000000..87cb079 --- /dev/null +++ b/src/config/config_base.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, fields, MISSING +from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union + +T = TypeVar("T", bound="ConfigBase") + +TOML_DICT_TYPE = { + int, + float, + str, + bool, + list, + dict, +} + + +@dataclass +class ConfigBase: + """配置类的基类""" + + @classmethod + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: + """从字典加载配置字段""" + if not isinstance(data, dict): + raise TypeError(f"Expected a dictionary, got {type(data).__name__}") + + init_args: Dict[str, Any] = {} + + for f in fields(cls): + field_name = f.name + field_type = f.type + if field_name.startswith("_"): + # 跳过以 _ 开头的字段 + continue + + if field_name not in data: + if f.default is not MISSING or f.default_factory is not MISSING: + # 跳过未提供且有默认值/默认构造方法的字段 + continue + else: + raise ValueError(f"Missing required field: '{field_name}'") + + value = data[field_name] + try: + init_args[field_name] = cls._convert_field(value, field_type) + except TypeError as e: + raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e + except Exception as e: + raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e + + return cls(**init_args) + + @classmethod + def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + """ + 转换字段值为指定类型 + + 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法 + 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素 + 3. 对于基础类型(int, str, float, bool),直接转换 + 4. 对于其他类型,尝试直接转换,如果失败则抛出异常 + """ + # 如果是嵌套的 dataclass,递归调用 from_dict 方法 + if isinstance(field_type, type) and issubclass(field_type, ConfigBase): + return field_type.from_dict(value) + + field_origin_type = get_origin(field_type) + field_args_type = get_args(field_type) + + # 处理泛型集合类型(list, set, tuple) + if field_origin_type in {list, set, tuple}: + # 检查提供的value是否为list + if not isinstance(value, list): + raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") + + if field_origin_type is list: + return [cls._convert_field(item, field_args_type[0]) for item in value] + if field_origin_type is set: + return {cls._convert_field(item, field_args_type[0]) for item in value} + if field_origin_type is tuple: + # 检查提供的value长度是否与类型参数一致 + if len(value) != len(field_args_type): + raise TypeError( + f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}" + ) + return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type)) + + if field_origin_type is dict: + # 检查提供的value是否为dict + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + + # 检查字典的键值类型 + if len(field_args_type) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") + key_type, value_type = field_args_type + + return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + + # 处理Optional类型 + if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union + if value is None: + return None + # 如果有数据,检查实际类型 + if type(value) not in field_args_type: + raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}") + return cls._convert_field(value, field_args_type[0]) + + # 处理int, str, float, bool等基础类型 + if field_origin_type is None: + if isinstance(value, field_type): + return field_type(value) + else: + raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}") + + # 处理Literal类型 + if field_origin_type is Literal: + # 获取Literal的允许值 + allowed_values = get_args(field_type) + if value in allowed_values: + return value + else: + raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") + + # 处理其他类型 + if field_type is Any: + return value + + # 其他类型直接转换 + try: + return field_type(value) + except (ValueError, TypeError) as e: + raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e + + def __str__(self): + """返回配置类的字符串表示""" + return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})" diff --git a/src/config/official_configs.py b/src/config/official_configs.py new file mode 100644 index 0000000..d8928a8 --- /dev/null +++ b/src/config/official_configs.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass, field +from typing import Literal + +from src.config.config_base import ConfigBase + +""" +须知: +1. 本文件中记录了所有的配置项 +2. 所有新增的class都需要继承自ConfigBase +3. 所有新增的class都应在config.py中的Config类中添加字段 +4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +""" + +ADAPTER_PLATFORM = "qq" + + +@dataclass +class NicknameConfig(ConfigBase): + nickname: str + """机器人昵称""" + + +@dataclass +class NapcatServerConfig(ConfigBase): + host: str = "localhost" + """Napcat服务端的主机地址""" + + port: int = 8095 + """Napcat服务端的端口号""" + + heartbeat_interval: int = 30 + """Napcat心跳间隔时间,单位为秒""" + + +@dataclass +class MaiBotServerConfig(ConfigBase): + platform_name: str = field(default=ADAPTER_PLATFORM, init=False) + """平台名称,“qq”""" + + host: str = "localhost" + """MaiMCore的主机地址""" + + port: int = 8000 + """MaiMCore的端口号""" + + +@dataclass +class ChatConfig(ConfigBase): + group_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """群聊列表类型 白名单/黑名单""" + + group_list: list[int] = field(default_factory=[]) + """群聊列表""" + + private_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """私聊列表类型 白名单/黑名单""" + + private_list: list[int] = field(default_factory=[]) + """私聊列表""" + + ban_user_id: list[int] = field(default_factory=[]) + """被封禁的用户ID列表,封禁后将无法与其进行交互""" + + enable_poke: bool = True + """是否启用戳一戳功能""" + + +@dataclass +class VoiceConfig(ConfigBase): + use_tts: bool = False + """是否启用TTS功能""" + + +@dataclass +class DebugConfig(ConfigBase): + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + """日志级别,默认为INFO""" diff --git a/src/logger.py b/src/logger.py index 3acba4f..8071ff7 100644 --- a/src/logger.py +++ b/src/logger.py @@ -5,6 +5,6 @@ import sys logger.remove() logger.add( sys.stderr, - level=global_config.debug_level, + level=global_config.debug.level, format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", ) diff --git a/src/mmc_com_layer.py b/src/mmc_com_layer.py index 174ef1f..ab50cca 100644 --- a/src/mmc_com_layer.py +++ b/src/mmc_com_layer.py @@ -5,8 +5,8 @@ from .send_handler import send_handler route_config = RouteConfig( route_config={ - global_config.platform: TargetConfig( - url=f"ws://{global_config.mai_host}:{global_config.mai_port}/ws", + global_config.maibot_server.platform_name: TargetConfig( + url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws", token=None, ) } diff --git a/src/recv_handler.py b/src/recv_handler.py index 7e031ec..7b11fbd 100644 --- a/src/recv_handler.py +++ b/src/recv_handler.py @@ -7,6 +7,8 @@ import json import websockets as Server from typing import List, Tuple, Optional, Dict, Any import uuid +from plyer import notification +import os from . import MetaEventType, RealMessageType, MessageType, NoticeType from maim_message import ( @@ -36,7 +38,8 @@ class RecvHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - self.interval = global_config.napcat_heartbeat_interval + self.interval = global_config.napcat_server.heartbeat_interval + self._interval_checking = False async def handle_meta_event(self, message: dict) -> None: event_type = message.get("meta_event_type") @@ -49,6 +52,8 @@ class RecvHandler: asyncio.create_task(self.check_heartbeat(self_id)) elif event_type == MetaEventType.heartbeat: if message["status"].get("online") and message["status"].get("good"): + if not self._interval_checking: + asyncio.create_task(self.check_heartbeat()) self.last_heart_beat = time.time() self.interval = message.get("interval") / 1000 else: @@ -56,10 +61,20 @@ class RecvHandler: logger.warning(f"Bot {self_id} Napcat 端异常!") async def check_heartbeat(self, id: int) -> None: + self._interval_checking = True while True: now_time = time.time() - if now_time - self.last_heart_beat > self.interval + 3: - logger.warning(f"Bot {id} 连接已断开") + if now_time - self.last_heart_beat > self.interval * 2: + logger.error(f"Bot {id} 连接已断开,被下线,或者Napcat卡死!") + current_dir = os.path.dirname(__file__) + icon_path = os.path.join(current_dir, "..", "assets", "maimai.ico") + notification.notify( + title="警告", + message=f"Bot {id} 连接已断开,被下线,或者Napcat卡死!", + app_name="MaiBot Napcat Adapter", + timeout=10, + app_icon=icon_path, + ) break else: logger.debug("心跳正常") @@ -77,20 +92,20 @@ class RecvHandler: """ logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") if group_id: - if global_config.group_list_type == "whitelist" and group_id not in global_config.group_list: + if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list: logger.warning("群聊不在聊天白名单中,消息被丢弃") return False - elif global_config.group_list_type == "blacklist" and group_id in global_config.group_list: + elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list: logger.warning("群聊在聊天黑名单中,消息被丢弃") return False else: - if global_config.private_list_type == "whitelist" and user_id not in global_config.private_list: + if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list: logger.warning("私聊不在聊天白名单中,消息被丢弃") return False - elif global_config.private_list_type == "blacklist" and user_id in global_config.private_list: + elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list: logger.warning("私聊在聊天黑名单中,消息被丢弃") return False - if user_id in global_config.ban_user_id: + if user_id in global_config.chat.ban_user_id: logger.warning("用户在全局黑名单中,消息被丢弃") return False return True @@ -123,7 +138,7 @@ class RecvHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -149,7 +164,7 @@ class RecvHandler: nickname = fetched_member_info.get("nickname") if fetched_member_info else None # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, user_id=sender_info.get("user_id"), user_nickname=nickname, user_cardname=None, @@ -164,7 +179,7 @@ class RecvHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -182,7 +197,7 @@ class RecvHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -195,7 +210,7 @@ class RecvHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -205,12 +220,12 @@ class RecvHandler: return None additional_config: dict = {} - if global_config.use_tts: + if global_config.voice.use_tts: additional_config["allow_tts"] = True # 消息信息 message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, message_id=message_id, time=message_time, user_info=user_info, @@ -500,7 +515,7 @@ class RecvHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.Notify.poke: - if global_config.enable_poke: + if global_config.chat.enable_poke: handled_message: Seg = await self.handle_poke_notify(raw_message) else: logger.warning("戳一戳消息被禁用,取消戳一戳处理") @@ -532,7 +547,7 @@ class RecvHandler: source_name = "QQ用户" user_info: UserInfo = UserInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, user_id=user_id, user_nickname=source_name, user_cardname=source_cardname, @@ -547,13 +562,13 @@ class RecvHandler: else: logger.warning("无法获取戳一戳消息所在群的名称") group_info = GroupInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, group_id=group_id, group_name=group_name, ) message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.platform, + platform=global_config.maibot_server.platform_name, message_id="notice", time=message_time, user_info=user_info, @@ -697,7 +712,7 @@ class RecvHandler: user_nickname: str = sender_info.get("nickname", "QQ用户") user_nickname_str = f"【{user_nickname}】:" break_seg = Seg(type="text", data="\n") - message_of_sub_message_list: dict = sub_message.get("message") + message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message") if not message_of_sub_message_list: logger.warning("转发消息内容为空") continue @@ -769,7 +784,9 @@ class RecvHandler: async def message_process(self, message_base: MessageBase) -> None: try: - await self.maibot_router.send_message(message_base) + send_status = await self.maibot_router.send_message(message_base) + if not send_status: + raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") except Exception as e: logger.error(f"发送消息失败: {str(e)}") logger.error("请检查与MaiBot之间的连接") diff --git a/src/response_pool.py b/src/response_pool.py index 66ded1d..c41ed7f 100644 --- a/src/response_pool.py +++ b/src/response_pool.py @@ -35,10 +35,10 @@ async def check_timeout_response() -> None: cleaned_message_count: int = 0 now_time = time.time() for echo_id, response_time in list(response_time_dict.items()): - if now_time - response_time > global_config.napcat_heartbeat_interval: + if now_time - response_time > global_config.napcat_server.heartbeat_interval: cleaned_message_count += 1 response_dict.pop(echo_id) response_time_dict.pop(echo_id) logger.warning(f"响应消息 {echo_id} 超时,已删除") logger.info(f"已删除 {cleaned_message_count} 条超时响应消息") - await asyncio.sleep(global_config.napcat_heartbeat_interval) + await asyncio.sleep(global_config.napcat_server.heartbeat_interval) diff --git a/template/template_config.toml b/template/template_config.toml index 1d0d830..b4cdce0 100644 --- a/template/template_config.toml +++ b/template/template_config.toml @@ -1,30 +1,33 @@ -[Nickname] # 现在没用 +[inner] +version = "0.1.0" # 版本号 +# 请勿修改版本号,除非你知道自己在做什么 + +[nickname] # 现在没用 nickname = "" -[Napcat_Server] # Napcat连接的ws服务设置 -host = "localhost" # Napcat设定的主机地址 -port = 8095 # Napcat设定的端口 -heartbeat = 30 # 与Napcat设置的心跳相同(按秒计) +[napcat_server] # Napcat连接的ws服务设置 +host = "localhost" # Napcat设定的主机地址 +port = 8095 # Napcat设定的端口 +heartbeat_interval = 30 # 与Napcat设置的心跳相同(按秒计) -[MaiBot_Server] # 连接麦麦的ws服务设置 -platform_name = "qq" # 标识adapter的名称(必填) -host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 -port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 +[maibot_server] # 连接麦麦的ws服务设置 +host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 +port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 -[Chat] # 黑白名单功能 +[chat] # 黑白名单功能 group_list_type = "whitelist" # 群组名单类型,可选为:whitelist, blacklist -group_list = [] # 群组名单 +group_list = [] # 群组名单 # 当group_list_type为whitelist时,只有群组名单中的群组可以聊天 # 当group_list_type为blacklist时,群组名单中的任何群组无法聊天 private_list_type = "whitelist" # 私聊名单类型,可选为:whitelist, blacklist -private_list = [] # 私聊名单 +private_list = [] # 私聊名单 # 当private_list_type为whitelist时,只有私聊名单中的用户可以聊天 # 当private_list_type为blacklist时,私聊名单中的任何用户无法聊天 -ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天) +ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天) enable_poke = true # 是否启用戳一戳功能 -[Voice] # 发送语音设置 +[voice] # 发送语音设置 use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) -[Debug] -level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR) +[debug] +level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)