diff --git a/.gitignore b/.gitignore
index 6d6652c..267374b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -270,4 +270,5 @@ $RECYCLE.BIN/
*.lnk
config.toml
+config.toml.back
test
\ No newline at end of file
diff --git a/assets/maimai.ico b/assets/maimai.ico
new file mode 100644
index 0000000..578b11c
Binary files /dev/null and b/assets/maimai.ico differ
diff --git a/command_args.md b/command_args.md
index 8a73651..cbbb582 100644
--- a/command_args.md
+++ b/command_args.md
@@ -5,7 +5,7 @@ Seg.type = "command"
## 群聊禁言
```python
Seg.data: Dict[str, Any] = {
- "name": "GROUP_BAN"
+ "name": "GROUP_BAN",
"args": {
"qq_id": "用户QQ号",
"duration": "禁言时长(秒)"
@@ -16,7 +16,7 @@ Seg.data: Dict[str, Any] = {
## 群聊全体禁言
```python
Seg.data: Dict[str, Any] = {
- "name": "GROUP_WHOLE_BAN"
+ "name": "GROUP_WHOLE_BAN",
"args": {
"enable": "是否开启全体禁言(True/False)"
},
@@ -28,10 +28,20 @@ Seg.data: Dict[str, Any] = {
## 群聊踢人
```python
Seg.data: Dict[str, Any] = {
- "name": "GROUP_KICK"
+ "name": "GROUP_KICK",
"args": {
"qq_id": "用户QQ号",
},
}
```
-其中,群聊ID将会通过Group_Info.group_id自动获取。
\ No newline at end of file
+其中,群聊ID将会通过Group_Info.group_id自动获取。
+
+## 戳一戳
+```python
+Seg,.data: Dict[str, Any] = {
+ "name": "SEND_POKE",
+ "args": {
+ "qq_id": "目标QQ号"
+ }
+}
+```
\ No newline at end of file
diff --git a/main.py b/main.py
index 50cf968..2c71ef1 100644
--- a/main.py
+++ b/main.py
@@ -18,7 +18,7 @@ async def message_recv(server_connection: Server.ServerConnection):
async for raw_message in server_connection:
logger.debug(
f"{raw_message[:100]}..."
- if (len(raw_message) > 100 and global_config.debug_level != "DEBUG")
+ if (len(raw_message) > 100 and global_config.debug.level != "DEBUG")
else raw_message
)
decoded_raw_message: dict = json.loads(raw_message)
@@ -52,19 +52,23 @@ async def main():
async def napcat_server():
logger.info("正在启动adapter...")
- async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server:
- logger.info(f"Adapter已启动,监听地址: ws://{global_config.server_host}:{global_config.server_port}")
+ async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server:
+ logger.info(
+ f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
+ )
await server.serve_forever()
async def graceful_shutdown():
try:
logger.info("正在关闭adapter...")
- await mmc_stop_com()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
- task.cancel()
- await asyncio.gather(*tasks, return_exceptions=True)
+ if not task.done():
+ task.cancel()
+ await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15)
+ await mmc_stop_com() # 后置避免神秘exception
+ logger.info("Adapter已成功关闭")
except Exception as e:
logger.error(f"Adapter关闭中出现错误: {e}")
diff --git a/requirements.txt b/requirements.txt
index 41e8eb1..e5964ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ requests
maim_message
loguru
pillow
-tomli
\ No newline at end of file
+tomli
+plyer
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
index a35ff0e..bfae081 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1,77 +1,78 @@
-from enum import Enum
-
-
-class MetaEventType:
- lifecycle = "lifecycle" # 生命周期
-
- class Lifecycle:
- connect = "connect" # 生命周期 - WebSocket 连接成功
-
- heartbeat = "heartbeat" # 心跳
-
-
-class MessageType: # 接受消息大类
- private = "private" # 私聊消息
-
- class Private:
- friend = "friend" # 私聊消息 - 好友
- group = "group" # 私聊消息 - 群临时
- group_self = "group_self" # 私聊消息 - 群中自身发送
- other = "other" # 私聊消息 - 其他
-
- group = "group" # 群聊消息
-
- class Group:
- normal = "normal" # 群聊消息 - 普通
- anonymous = "anonymous" # 群聊消息 - 匿名消息
- notice = "notice" # 群聊消息 - 系统提示
-
-
-class NoticeType: # 通知事件
- friend_recall = "friend_recall" # 私聊消息撤回
- group_recall = "group_recall" # 群聊消息撤回
- notify = "notify"
-
- class Notify:
- poke = "poke" # 戳一戳
-
-
-class RealMessageType: # 实际消息分类
- text = "text" # 纯文本
- face = "face" # qq表情
- image = "image" # 图片
- record = "record" # 语音
- video = "video" # 视频
- at = "at" # @某人
- rps = "rps" # 猜拳魔法表情
- dice = "dice" # 骰子
- shake = "shake" # 私聊窗口抖动(只收)
- poke = "poke" # 群聊戳一戳
- share = "share" # 链接分享(json形式)
- reply = "reply" # 回复消息
- forward = "forward" # 转发消息
- node = "node" # 转发消息节点
-
-
-class MessageSentType:
- private = "private"
-
- class Private:
- friend = "friend"
- group = "group"
-
- group = "group"
-
- class Group:
- normal = "normal"
-
-
-class CommandType(Enum):
- """命令类型"""
-
- GROUP_BAN = "set_group_ban" # 禁言用户
- GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
- GROUP_KICK = "set_group_kick" # 踢出群聊
-
- def __str__(self) -> str:
- return self.value
+from enum import Enum
+
+
+class MetaEventType:
+ lifecycle = "lifecycle" # 生命周期
+
+ class Lifecycle:
+ connect = "connect" # 生命周期 - WebSocket 连接成功
+
+ heartbeat = "heartbeat" # 心跳
+
+
+class MessageType: # 接受消息大类
+ private = "private" # 私聊消息
+
+ class Private:
+ friend = "friend" # 私聊消息 - 好友
+ group = "group" # 私聊消息 - 群临时
+ group_self = "group_self" # 私聊消息 - 群中自身发送
+ other = "other" # 私聊消息 - 其他
+
+ group = "group" # 群聊消息
+
+ class Group:
+ normal = "normal" # 群聊消息 - 普通
+ anonymous = "anonymous" # 群聊消息 - 匿名消息
+ notice = "notice" # 群聊消息 - 系统提示
+
+
+class NoticeType: # 通知事件
+ friend_recall = "friend_recall" # 私聊消息撤回
+ group_recall = "group_recall" # 群聊消息撤回
+ notify = "notify"
+
+ class Notify:
+ poke = "poke" # 戳一戳
+
+
+class RealMessageType: # 实际消息分类
+ text = "text" # 纯文本
+ face = "face" # qq表情
+ image = "image" # 图片
+ record = "record" # 语音
+ video = "video" # 视频
+ at = "at" # @某人
+ rps = "rps" # 猜拳魔法表情
+ dice = "dice" # 骰子
+ shake = "shake" # 私聊窗口抖动(只收)
+ poke = "poke" # 群聊戳一戳
+ share = "share" # 链接分享(json形式)
+ reply = "reply" # 回复消息
+ forward = "forward" # 转发消息
+ node = "node" # 转发消息节点
+
+
+class MessageSentType:
+ private = "private"
+
+ class Private:
+ friend = "friend"
+ group = "group"
+
+ group = "group"
+
+ class Group:
+ normal = "normal"
+
+
+class CommandType(Enum):
+ """命令类型"""
+
+ GROUP_BAN = "set_group_ban" # 禁言用户
+ GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
+ GROUP_KICK = "set_group_kick" # 踢出群聊
+ SEND_POKE = "send_poke" # 戳一戳
+
+ def __str__(self) -> str:
+ return self.value
diff --git a/src/config.py b/src/config.py
deleted file mode 100644
index ee13c98..0000000
--- a/src/config.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import os
-import sys
-import tomli
-import shutil
-from .logger import logger
-from typing import Optional
-
-
-class Config:
- platform: str = "qq"
- nickname: Optional[str] = None
- server_host: str = "localhost"
- server_port: int = 8095
- napcat_heartbeat_interval: int = 30
-
- def __init__(self):
- self._get_config_path()
-
- def _get_config_path(self):
- current_file_path = os.path.abspath(__file__)
- src_path = os.path.dirname(current_file_path)
- self.root_path = os.path.join(src_path, "..")
- self.config_path = os.path.join(self.root_path, "config.toml")
-
- def load_config(self): # sourcery skip: extract-method, move-assign
- include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"]
- if not os.path.exists(self.config_path):
- logger.error("配置文件不存在!")
- logger.info("正在创建配置文件...")
- shutil.copy(
- os.path.join(self.root_path, "template", "template_config.toml"),
- os.path.join(self.root_path, "config.toml"),
- )
- logger.info("配置文件创建成功,请修改配置文件后重启程序。")
- sys.exit(1)
- with open(self.config_path, "rb") as f:
- try:
- raw_config = tomli.load(f)
- except tomli.TOMLDecodeError as e:
- logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}")
- sys.exit(1)
- for key in include_configs:
- if key not in raw_config:
- logger.error(f"配置文件中缺少必需的字段: '{key}'")
- logger.error("你的配置文件可能过时,请尝试手动更新配置文件。")
- sys.exit(1)
-
- self.server_host = raw_config["Napcat_Server"].get("host", "localhost")
- self.server_port = raw_config["Napcat_Server"].get("port", 8095)
- self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
-
- self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost")
- self.mai_port = raw_config["MaiBot_Server"].get("port", 8000)
- self.platform = raw_config["MaiBot_Server"].get("platform_name")
- if not self.platform:
- logger.critical("请在配置文件中指定平台")
- sys.exit(1)
-
- self.group_list_type: str = raw_config["Chat"].get("group_list_type")
- self.group_list: list = raw_config["Chat"].get("group_list", [])
- self.private_list_type: str = raw_config["Chat"].get("private_list_type")
- self.private_list: list = raw_config["Chat"].get("private_list", [])
- self.ban_user_id: list = raw_config["Chat"].get("ban_user_id", [])
- self.enable_poke: bool = raw_config["Chat"].get("enable_poke", True)
- if self.group_list_type not in ["whitelist", "blacklist"]:
- logger.critical("请在配置文件中指定group_list_type或group_list_type填写错误")
- sys.exit(1)
- if self.private_list_type not in ["whitelist", "blacklist"]:
- logger.critical("请在配置文件中指定private_list_type或private_list_type填写错误")
- sys.exit(1)
-
- self.use_tts = raw_config["Voice"].get("use_tts", False)
-
- self.debug_level = raw_config["Debug"].get("level", "INFO")
- if self.debug_level == "DEBUG":
- logger.debug("原始配置文件内容:")
- logger.debug(raw_config)
- logger.debug("读取到的配置内容:")
- logger.debug(f"平台: {self.platform}")
- logger.debug(f"MaiBot服务器地址: {self.mai_host}:{self.mai_port}")
- logger.debug(f"Napcat服务器地址: {self.server_host}:{self.server_port}")
- logger.debug(f"心跳间隔: {self.napcat_heartbeat_interval}秒")
- logger.debug(f"群聊列表类型: {self.group_list_type}")
- logger.debug(f"群聊列表: {self.group_list}")
- logger.debug(f"私聊列表类型: {self.private_list_type}")
- logger.debug(f"私聊列表: {self.private_list}")
- logger.debug(f"禁用用户ID列表: {self.ban_user_id}")
- logger.debug(f"是否启用TTS: {self.use_tts}")
- logger.debug(f"调试级别: {self.debug_level}")
-
-
-global_config = Config()
-global_config.load_config()
diff --git a/src/config/__init__.py b/src/config/__init__.py
new file mode 100644
index 0000000..40ba89a
--- /dev/null
+++ b/src/config/__init__.py
@@ -0,0 +1,5 @@
+from .config import global_config
+
+__all__ = [
+ "global_config",
+]
diff --git a/src/config/config.py b/src/config/config.py
new file mode 100644
index 0000000..a219078
--- /dev/null
+++ b/src/config/config.py
@@ -0,0 +1,140 @@
+import os
+from dataclasses import dataclass
+
+import tomlkit
+import shutil
+
+from tomlkit import TOMLDocument
+from tomlkit.items import Table
+from ..logger import logger
+from rich.traceback import install
+
+from src.config.config_base import ConfigBase
+from src.config.official_configs import (
+ ChatConfig,
+ DebugConfig,
+ MaiBotServerConfig,
+ NapcatServerConfig,
+ NicknameConfig,
+ VoiceConfig,
+)
+
+install(extra_lines=3)
+
+TEMPLATE_DIR = "template"
+
+
+def update_config():
+ # 定义文件路径
+ template_path = f"{TEMPLATE_DIR}/template_config.toml"
+ old_config_path = "config.toml"
+ new_config_path = "config.toml"
+
+ # 检查配置文件是否存在
+ if not os.path.exists(old_config_path):
+ logger.info("配置文件不存在,从模板创建新配置")
+ shutil.copy2(template_path, old_config_path) # 复制模板文件
+ logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
+ # 如果是新创建的配置文件,直接返回
+ quit()
+
+ # 读取旧配置文件和模板文件
+ with open(old_config_path, "r", encoding="utf-8") as f:
+ old_config = tomlkit.load(f)
+ with open(template_path, "r", encoding="utf-8") as f:
+ new_config = tomlkit.load(f)
+
+ # 检查version是否相同
+ if old_config and "inner" in old_config and "inner" in new_config:
+ old_version = old_config["inner"].get("version")
+ new_version = new_config["inner"].get("version")
+ if old_version and new_version and old_version == new_version:
+ logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
+ return
+ else:
+ logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
+ else:
+ logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
+
+ # 备份文件名
+ old_backup_path = "config.toml.back"
+
+ # 备份旧配置文件
+ shutil.move(old_config_path, old_backup_path)
+ logger.info(f"已备份旧配置文件到: {old_backup_path}")
+
+ # 复制模板文件到配置目录
+ shutil.copy2(template_path, new_config_path)
+ logger.info(f"已创建新配置文件: {new_config_path}")
+
+ def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
+ """
+ 将source字典的值更新到target字典中(如果target中存在相同的键)
+ """
+ for key, value in source.items():
+ # 跳过version字段的更新
+ if key == "version":
+ continue
+ if key in target:
+ if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
+ update_dict(target[key], value)
+ else:
+ try:
+ # 对数组类型进行特殊处理
+ if isinstance(value, list):
+ # 如果是空数组,确保它保持为空数组
+ target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
+ else:
+ # 其他类型使用item方法创建新值
+ target[key] = tomlkit.item(value)
+ except (TypeError, ValueError):
+ # 如果转换失败,直接赋值
+ target[key] = value
+
+ # 将旧配置的值更新到新配置中
+ logger.info("开始合并新旧配置...")
+ update_dict(new_config, old_config)
+
+ # 保存更新后的配置(保留注释和格式)
+ with open(new_config_path, "w", encoding="utf-8") as f:
+ f.write(tomlkit.dumps(new_config))
+ logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
+ quit()
+
+
+@dataclass
+class Config(ConfigBase):
+ """总配置类"""
+
+ nickname: NicknameConfig
+ napcat_server: NapcatServerConfig
+ maibot_server: MaiBotServerConfig
+ chat: ChatConfig
+ voice: VoiceConfig
+ debug: DebugConfig
+
+
+def load_config(config_path: str) -> Config:
+ """
+ 加载配置文件
+ :param config_path: 配置文件路径
+ :return: Config对象
+ """
+ # 读取配置文件
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_data = tomlkit.load(f)
+
+ # 创建Config对象
+ try:
+ return Config.from_dict(config_data)
+ except Exception as e:
+ logger.critical("配置文件解析失败")
+ raise e
+
+
+# 更新配置
+update_config()
+
+logger.info("正在品鉴配置文件...")
+global_config = load_config(config_path="config.toml")
+logger.info("非常的新鲜,非常的美味!")
diff --git a/src/config/config_base.py b/src/config/config_base.py
new file mode 100644
index 0000000..87cb079
--- /dev/null
+++ b/src/config/config_base.py
@@ -0,0 +1,136 @@
+from dataclasses import dataclass, fields, MISSING
+from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
+
+T = TypeVar("T", bound="ConfigBase")
+
+TOML_DICT_TYPE = {
+ int,
+ float,
+ str,
+ bool,
+ list,
+ dict,
+}
+
+
+@dataclass
+class ConfigBase:
+ """配置类的基类"""
+
+ @classmethod
+ def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
+ """从字典加载配置字段"""
+ if not isinstance(data, dict):
+ raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
+
+ init_args: Dict[str, Any] = {}
+
+ for f in fields(cls):
+ field_name = f.name
+ field_type = f.type
+ if field_name.startswith("_"):
+ # 跳过以 _ 开头的字段
+ continue
+
+ if field_name not in data:
+ if f.default is not MISSING or f.default_factory is not MISSING:
+ # 跳过未提供且有默认值/默认构造方法的字段
+ continue
+ else:
+ raise ValueError(f"Missing required field: '{field_name}'")
+
+ value = data[field_name]
+ try:
+ init_args[field_name] = cls._convert_field(value, field_type)
+ except TypeError as e:
+ raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
+ except Exception as e:
+ raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
+
+ return cls(**init_args)
+
+ @classmethod
+ def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
+ """
+ 转换字段值为指定类型
+
+ 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
+ 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
+ 3. 对于基础类型(int, str, float, bool),直接转换
+ 4. 对于其他类型,尝试直接转换,如果失败则抛出异常
+ """
+ # 如果是嵌套的 dataclass,递归调用 from_dict 方法
+ if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
+ return field_type.from_dict(value)
+
+ field_origin_type = get_origin(field_type)
+ field_args_type = get_args(field_type)
+
+ # 处理泛型集合类型(list, set, tuple)
+ if field_origin_type in {list, set, tuple}:
+ # 检查提供的value是否为list
+ if not isinstance(value, list):
+ raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
+
+ if field_origin_type is list:
+ return [cls._convert_field(item, field_args_type[0]) for item in value]
+ if field_origin_type is set:
+ return {cls._convert_field(item, field_args_type[0]) for item in value}
+ if field_origin_type is tuple:
+ # 检查提供的value长度是否与类型参数一致
+ if len(value) != len(field_args_type):
+ raise TypeError(
+ f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
+ )
+ return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
+
+ if field_origin_type is dict:
+ # 检查提供的value是否为dict
+ if not isinstance(value, dict):
+ raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
+
+ # 检查字典的键值类型
+ if len(field_args_type) != 2:
+ raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
+ key_type, value_type = field_args_type
+
+ return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
+
+ # 处理Optional类型
+ if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
+ if value is None:
+ return None
+ # 如果有数据,检查实际类型
+ if type(value) not in field_args_type:
+ raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
+ return cls._convert_field(value, field_args_type[0])
+
+ # 处理int, str, float, bool等基础类型
+ if field_origin_type is None:
+ if isinstance(value, field_type):
+ return field_type(value)
+ else:
+ raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
+
+ # 处理Literal类型
+ if field_origin_type is Literal:
+ # 获取Literal的允许值
+ allowed_values = get_args(field_type)
+ if value in allowed_values:
+ return value
+ else:
+ raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
+
+ # 处理其他类型
+ if field_type is Any:
+ return value
+
+ # 其他类型直接转换
+ try:
+ return field_type(value)
+ except (ValueError, TypeError) as e:
+ raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
+
+ def __str__(self):
+ """返回配置类的字符串表示"""
+ return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
new file mode 100644
index 0000000..d8928a8
--- /dev/null
+++ b/src/config/official_configs.py
@@ -0,0 +1,77 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+from src.config.config_base import ConfigBase
+
+"""
+须知:
+1. 本文件中记录了所有的配置项
+2. 所有新增的class都需要继承自ConfigBase
+3. 所有新增的class都应在config.py中的Config类中添加字段
+4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
+"""
+
+ADAPTER_PLATFORM = "qq"
+
+
+@dataclass
+class NicknameConfig(ConfigBase):
+ nickname: str
+ """机器人昵称"""
+
+
+@dataclass
+class NapcatServerConfig(ConfigBase):
+ host: str = "localhost"
+ """Napcat服务端的主机地址"""
+
+ port: int = 8095
+ """Napcat服务端的端口号"""
+
+ heartbeat_interval: int = 30
+ """Napcat心跳间隔时间,单位为秒"""
+
+
+@dataclass
+class MaiBotServerConfig(ConfigBase):
+ platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
+ """平台名称,“qq”"""
+
+ host: str = "localhost"
+ """MaiMCore的主机地址"""
+
+ port: int = 8000
+ """MaiMCore的端口号"""
+
+
+@dataclass
+class ChatConfig(ConfigBase):
+ group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
+ """群聊列表类型 白名单/黑名单"""
+
+ group_list: list[int] = field(default_factory=[])
+ """群聊列表"""
+
+ private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
+ """私聊列表类型 白名单/黑名单"""
+
+ private_list: list[int] = field(default_factory=[])
+ """私聊列表"""
+
+ ban_user_id: list[int] = field(default_factory=[])
+ """被封禁的用户ID列表,封禁后将无法与其进行交互"""
+
+ enable_poke: bool = True
+ """是否启用戳一戳功能"""
+
+
+@dataclass
+class VoiceConfig(ConfigBase):
+ use_tts: bool = False
+ """是否启用TTS功能"""
+
+
+@dataclass
+class DebugConfig(ConfigBase):
+ level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
+ """日志级别,默认为INFO"""
diff --git a/src/logger.py b/src/logger.py
index 3acba4f..8071ff7 100644
--- a/src/logger.py
+++ b/src/logger.py
@@ -5,6 +5,6 @@ import sys
logger.remove()
logger.add(
sys.stderr,
- level=global_config.debug_level,
+ level=global_config.debug.level,
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
)
diff --git a/src/mmc_com_layer.py b/src/mmc_com_layer.py
index 174ef1f..ab50cca 100644
--- a/src/mmc_com_layer.py
+++ b/src/mmc_com_layer.py
@@ -5,8 +5,8 @@ from .send_handler import send_handler
route_config = RouteConfig(
route_config={
- global_config.platform: TargetConfig(
- url=f"ws://{global_config.mai_host}:{global_config.mai_port}/ws",
+ global_config.maibot_server.platform_name: TargetConfig(
+ url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
token=None,
)
}
diff --git a/src/recv_handler.py b/src/recv_handler.py
index 7e031ec..7b11fbd 100644
--- a/src/recv_handler.py
+++ b/src/recv_handler.py
@@ -7,6 +7,8 @@ import json
import websockets as Server
from typing import List, Tuple, Optional, Dict, Any
import uuid
+from plyer import notification
+import os
from . import MetaEventType, RealMessageType, MessageType, NoticeType
from maim_message import (
@@ -36,7 +38,8 @@ class RecvHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
- self.interval = global_config.napcat_heartbeat_interval
+ self.interval = global_config.napcat_server.heartbeat_interval
+ self._interval_checking = False
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")
@@ -49,6 +52,8 @@ class RecvHandler:
asyncio.create_task(self.check_heartbeat(self_id))
elif event_type == MetaEventType.heartbeat:
if message["status"].get("online") and message["status"].get("good"):
+ if not self._interval_checking:
+ asyncio.create_task(self.check_heartbeat())
self.last_heart_beat = time.time()
self.interval = message.get("interval") / 1000
else:
@@ -56,10 +61,20 @@ class RecvHandler:
logger.warning(f"Bot {self_id} Napcat 端异常!")
async def check_heartbeat(self, id: int) -> None:
+ self._interval_checking = True
while True:
now_time = time.time()
- if now_time - self.last_heart_beat > self.interval + 3:
- logger.warning(f"Bot {id} 连接已断开")
+ if now_time - self.last_heart_beat > self.interval * 2:
+ logger.error(f"Bot {id} 连接已断开,被下线,或者Napcat卡死!")
+ current_dir = os.path.dirname(__file__)
+ icon_path = os.path.join(current_dir, "..", "assets", "maimai.ico")
+ notification.notify(
+ title="警告",
+ message=f"Bot {id} 连接已断开,被下线,或者Napcat卡死!",
+ app_name="MaiBot Napcat Adapter",
+ timeout=10,
+ app_icon=icon_path,
+ )
break
else:
logger.debug("心跳正常")
@@ -77,20 +92,20 @@ class RecvHandler:
"""
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
if group_id:
- if global_config.group_list_type == "whitelist" and group_id not in global_config.group_list:
+ if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list:
logger.warning("群聊不在聊天白名单中,消息被丢弃")
return False
- elif global_config.group_list_type == "blacklist" and group_id in global_config.group_list:
+ elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list:
logger.warning("群聊在聊天黑名单中,消息被丢弃")
return False
else:
- if global_config.private_list_type == "whitelist" and user_id not in global_config.private_list:
+ if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list:
logger.warning("私聊不在聊天白名单中,消息被丢弃")
return False
- elif global_config.private_list_type == "blacklist" and user_id in global_config.private_list:
+ elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list:
logger.warning("私聊在聊天黑名单中,消息被丢弃")
return False
- if user_id in global_config.ban_user_id:
+ if user_id in global_config.chat.ban_user_id:
logger.warning("用户在全局黑名单中,消息被丢弃")
return False
return True
@@ -123,7 +138,7 @@ class RecvHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -149,7 +164,7 @@ class RecvHandler:
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
# 发送者用户信息
user_info: UserInfo = UserInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=nickname,
user_cardname=None,
@@ -164,7 +179,7 @@ class RecvHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -182,7 +197,7 @@ class RecvHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -195,7 +210,7 @@ class RecvHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -205,12 +220,12 @@ class RecvHandler:
return None
additional_config: dict = {}
- if global_config.use_tts:
+ if global_config.voice.use_tts:
additional_config["allow_tts"] = True
# 消息信息
message_info: BaseMessageInfo = BaseMessageInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
message_id=message_id,
time=message_time,
user_info=user_info,
@@ -500,7 +515,7 @@ class RecvHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
- if global_config.enable_poke:
+ if global_config.chat.enable_poke:
handled_message: Seg = await self.handle_poke_notify(raw_message)
else:
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
@@ -532,7 +547,7 @@ class RecvHandler:
source_name = "QQ用户"
user_info: UserInfo = UserInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=source_name,
user_cardname=source_cardname,
@@ -547,13 +562,13 @@ class RecvHandler:
else:
logger.warning("无法获取戳一戳消息所在群的名称")
group_info = GroupInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
- platform=global_config.platform,
+ platform=global_config.maibot_server.platform_name,
message_id="notice",
time=message_time,
user_info=user_info,
@@ -697,7 +712,7 @@ class RecvHandler:
user_nickname: str = sender_info.get("nickname", "QQ用户")
user_nickname_str = f"【{user_nickname}】:"
break_seg = Seg(type="text", data="\n")
- message_of_sub_message_list: dict = sub_message.get("message")
+ message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
if not message_of_sub_message_list:
logger.warning("转发消息内容为空")
continue
@@ -769,7 +784,9 @@ class RecvHandler:
async def message_process(self, message_base: MessageBase) -> None:
try:
- await self.maibot_router.send_message(message_base)
+ send_status = await self.maibot_router.send_message(message_base)
+ if not send_status:
+ raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常")
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")
diff --git a/src/response_pool.py b/src/response_pool.py
index 66ded1d..c41ed7f 100644
--- a/src/response_pool.py
+++ b/src/response_pool.py
@@ -35,10 +35,10 @@ async def check_timeout_response() -> None:
cleaned_message_count: int = 0
now_time = time.time()
for echo_id, response_time in list(response_time_dict.items()):
- if now_time - response_time > global_config.napcat_heartbeat_interval:
+ if now_time - response_time > global_config.napcat_server.heartbeat_interval:
cleaned_message_count += 1
response_dict.pop(echo_id)
response_time_dict.pop(echo_id)
logger.warning(f"响应消息 {echo_id} 超时,已删除")
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
- await asyncio.sleep(global_config.napcat_heartbeat_interval)
+ await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
diff --git a/template/template_config.toml b/template/template_config.toml
index 1d0d830..b4cdce0 100644
--- a/template/template_config.toml
+++ b/template/template_config.toml
@@ -1,30 +1,33 @@
-[Nickname] # 现在没用
+[inner]
+version = "0.1.0" # 版本号
+# 请勿修改版本号,除非你知道自己在做什么
+
+[nickname] # 现在没用
nickname = ""
-[Napcat_Server] # Napcat连接的ws服务设置
-host = "localhost" # Napcat设定的主机地址
-port = 8095 # Napcat设定的端口
-heartbeat = 30 # 与Napcat设置的心跳相同(按秒计)
+[napcat_server] # Napcat连接的ws服务设置
+host = "localhost" # Napcat设定的主机地址
+port = 8095 # Napcat设定的端口
+heartbeat_interval = 30 # 与Napcat设置的心跳相同(按秒计)
-[MaiBot_Server] # 连接麦麦的ws服务设置
-platform_name = "qq" # 标识adapter的名称(必填)
-host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
-port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
+[maibot_server] # 连接麦麦的ws服务设置
+host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
+port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
-[Chat] # 黑白名单功能
+[chat] # 黑白名单功能
group_list_type = "whitelist" # 群组名单类型,可选为:whitelist, blacklist
-group_list = [] # 群组名单
+group_list = [] # 群组名单
# 当group_list_type为whitelist时,只有群组名单中的群组可以聊天
# 当group_list_type为blacklist时,群组名单中的任何群组无法聊天
private_list_type = "whitelist" # 私聊名单类型,可选为:whitelist, blacklist
-private_list = [] # 私聊名单
+private_list = [] # 私聊名单
# 当private_list_type为whitelist时,只有私聊名单中的用户可以聊天
# 当private_list_type为blacklist时,私聊名单中的任何用户无法聊天
-ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
+ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
enable_poke = true # 是否启用戳一戳功能
-[Voice] # 发送语音设置
+[voice] # 发送语音设置
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
-[Debug]
-level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR)
+[debug]
+level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)