Merge branch 'dev' into feat/voice
commit
e4620fb7db
|
|
@ -270,4 +270,5 @@ $RECYCLE.BIN/
|
|||
*.lnk
|
||||
|
||||
config.toml
|
||||
config.toml.back
|
||||
test
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 66 KiB |
|
|
@ -5,7 +5,7 @@ Seg.type = "command"
|
|||
## 群聊禁言
|
||||
```python
|
||||
Seg.data: Dict[str, Any] = {
|
||||
"name": "GROUP_BAN"
|
||||
"name": "GROUP_BAN",
|
||||
"args": {
|
||||
"qq_id": "用户QQ号",
|
||||
"duration": "禁言时长(秒)"
|
||||
|
|
@ -16,7 +16,7 @@ Seg.data: Dict[str, Any] = {
|
|||
## 群聊全体禁言
|
||||
```python
|
||||
Seg.data: Dict[str, Any] = {
|
||||
"name": "GROUP_WHOLE_BAN"
|
||||
"name": "GROUP_WHOLE_BAN",
|
||||
"args": {
|
||||
"enable": "是否开启全体禁言(True/False)"
|
||||
},
|
||||
|
|
@ -28,10 +28,20 @@ Seg.data: Dict[str, Any] = {
|
|||
## 群聊踢人
|
||||
```python
|
||||
Seg.data: Dict[str, Any] = {
|
||||
"name": "GROUP_KICK"
|
||||
"name": "GROUP_KICK",
|
||||
"args": {
|
||||
"qq_id": "用户QQ号",
|
||||
},
|
||||
}
|
||||
```
|
||||
其中,群聊ID将会通过Group_Info.group_id自动获取。
|
||||
其中,群聊ID将会通过Group_Info.group_id自动获取。
|
||||
|
||||
## 戳一戳
|
||||
```python
|
||||
Seg,.data: Dict[str, Any] = {
|
||||
"name": "SEND_POKE",
|
||||
"args": {
|
||||
"qq_id": "目标QQ号"
|
||||
}
|
||||
}
|
||||
```
|
||||
16
main.py
16
main.py
|
|
@ -18,7 +18,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||
async for raw_message in server_connection:
|
||||
logger.debug(
|
||||
f"{raw_message[:100]}..."
|
||||
if (len(raw_message) > 100 and global_config.debug_level != "DEBUG")
|
||||
if (len(raw_message) > 100 and global_config.debug.level != "DEBUG")
|
||||
else raw_message
|
||||
)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
|
|
@ -52,19 +52,23 @@ async def main():
|
|||
|
||||
async def napcat_server():
|
||||
logger.info("正在启动adapter...")
|
||||
async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server:
|
||||
logger.info(f"Adapter已启动,监听地址: ws://{global_config.server_host}:{global_config.server_port}")
|
||||
async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server:
|
||||
logger.info(
|
||||
f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
|
||||
)
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
async def graceful_shutdown():
|
||||
try:
|
||||
logger.info("正在关闭adapter...")
|
||||
await mmc_stop_com()
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15)
|
||||
await mmc_stop_com() # 后置避免神秘exception
|
||||
logger.info("Adapter已成功关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"Adapter关闭中出现错误: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,4 +5,5 @@ requests
|
|||
maim_message
|
||||
loguru
|
||||
pillow
|
||||
tomli
|
||||
tomli
|
||||
plyer
|
||||
155
src/__init__.py
155
src/__init__.py
|
|
@ -1,77 +1,78 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
|
||||
lifecycle = "lifecycle" # 生命周期
|
||||
|
||||
class Lifecycle:
|
||||
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||
|
||||
heartbeat = "heartbeat" # 心跳
|
||||
|
||||
|
||||
class MessageType: # 接受消息大类
|
||||
private = "private" # 私聊消息
|
||||
|
||||
class Private:
|
||||
friend = "friend" # 私聊消息 - 好友
|
||||
group = "group" # 私聊消息 - 群临时
|
||||
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||
other = "other" # 私聊消息 - 其他
|
||||
|
||||
group = "group" # 群聊消息
|
||||
|
||||
class Group:
|
||||
normal = "normal" # 群聊消息 - 普通
|
||||
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||
notice = "notice" # 群聊消息 - 系统提示
|
||||
|
||||
|
||||
class NoticeType: # 通知事件
|
||||
friend_recall = "friend_recall" # 私聊消息撤回
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
|
||||
class RealMessageType: # 实际消息分类
|
||||
text = "text" # 纯文本
|
||||
face = "face" # qq表情
|
||||
image = "image" # 图片
|
||||
record = "record" # 语音
|
||||
video = "video" # 视频
|
||||
at = "at" # @某人
|
||||
rps = "rps" # 猜拳魔法表情
|
||||
dice = "dice" # 骰子
|
||||
shake = "shake" # 私聊窗口抖动(只收)
|
||||
poke = "poke" # 群聊戳一戳
|
||||
share = "share" # 链接分享(json形式)
|
||||
reply = "reply" # 回复消息
|
||||
forward = "forward" # 转发消息
|
||||
node = "node" # 转发消息节点
|
||||
|
||||
|
||||
class MessageSentType:
|
||||
private = "private"
|
||||
|
||||
class Private:
|
||||
friend = "friend"
|
||||
group = "group"
|
||||
|
||||
group = "group"
|
||||
|
||||
class Group:
|
||||
normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
|
||||
lifecycle = "lifecycle" # 生命周期
|
||||
|
||||
class Lifecycle:
|
||||
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||
|
||||
heartbeat = "heartbeat" # 心跳
|
||||
|
||||
|
||||
class MessageType: # 接受消息大类
|
||||
private = "private" # 私聊消息
|
||||
|
||||
class Private:
|
||||
friend = "friend" # 私聊消息 - 好友
|
||||
group = "group" # 私聊消息 - 群临时
|
||||
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||
other = "other" # 私聊消息 - 其他
|
||||
|
||||
group = "group" # 群聊消息
|
||||
|
||||
class Group:
|
||||
normal = "normal" # 群聊消息 - 普通
|
||||
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||
notice = "notice" # 群聊消息 - 系统提示
|
||||
|
||||
|
||||
class NoticeType: # 通知事件
|
||||
friend_recall = "friend_recall" # 私聊消息撤回
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
|
||||
class RealMessageType: # 实际消息分类
|
||||
text = "text" # 纯文本
|
||||
face = "face" # qq表情
|
||||
image = "image" # 图片
|
||||
record = "record" # 语音
|
||||
video = "video" # 视频
|
||||
at = "at" # @某人
|
||||
rps = "rps" # 猜拳魔法表情
|
||||
dice = "dice" # 骰子
|
||||
shake = "shake" # 私聊窗口抖动(只收)
|
||||
poke = "poke" # 群聊戳一戳
|
||||
share = "share" # 链接分享(json形式)
|
||||
reply = "reply" # 回复消息
|
||||
forward = "forward" # 转发消息
|
||||
node = "node" # 转发消息节点
|
||||
|
||||
|
||||
class MessageSentType:
|
||||
private = "private"
|
||||
|
||||
class Private:
|
||||
friend = "friend"
|
||||
group = "group"
|
||||
|
||||
group = "group"
|
||||
|
||||
class Group:
|
||||
normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
|
|||
|
|
@ -1,93 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import tomli
|
||||
import shutil
|
||||
from .logger import logger
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Config:
|
||||
platform: str = "qq"
|
||||
nickname: Optional[str] = None
|
||||
server_host: str = "localhost"
|
||||
server_port: int = 8095
|
||||
napcat_heartbeat_interval: int = 30
|
||||
|
||||
def __init__(self):
|
||||
self._get_config_path()
|
||||
|
||||
def _get_config_path(self):
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
src_path = os.path.dirname(current_file_path)
|
||||
self.root_path = os.path.join(src_path, "..")
|
||||
self.config_path = os.path.join(self.root_path, "config.toml")
|
||||
|
||||
def load_config(self): # sourcery skip: extract-method, move-assign
|
||||
include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"]
|
||||
if not os.path.exists(self.config_path):
|
||||
logger.error("配置文件不存在!")
|
||||
logger.info("正在创建配置文件...")
|
||||
shutil.copy(
|
||||
os.path.join(self.root_path, "template", "template_config.toml"),
|
||||
os.path.join(self.root_path, "config.toml"),
|
||||
)
|
||||
logger.info("配置文件创建成功,请修改配置文件后重启程序。")
|
||||
sys.exit(1)
|
||||
with open(self.config_path, "rb") as f:
|
||||
try:
|
||||
raw_config = tomli.load(f)
|
||||
except tomli.TOMLDecodeError as e:
|
||||
logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}")
|
||||
sys.exit(1)
|
||||
for key in include_configs:
|
||||
if key not in raw_config:
|
||||
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
||||
logger.error("你的配置文件可能过时,请尝试手动更新配置文件。")
|
||||
sys.exit(1)
|
||||
|
||||
self.server_host = raw_config["Napcat_Server"].get("host", "localhost")
|
||||
self.server_port = raw_config["Napcat_Server"].get("port", 8095)
|
||||
self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
|
||||
|
||||
self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost")
|
||||
self.mai_port = raw_config["MaiBot_Server"].get("port", 8000)
|
||||
self.platform = raw_config["MaiBot_Server"].get("platform_name")
|
||||
if not self.platform:
|
||||
logger.critical("请在配置文件中指定平台")
|
||||
sys.exit(1)
|
||||
|
||||
self.group_list_type: str = raw_config["Chat"].get("group_list_type")
|
||||
self.group_list: list = raw_config["Chat"].get("group_list", [])
|
||||
self.private_list_type: str = raw_config["Chat"].get("private_list_type")
|
||||
self.private_list: list = raw_config["Chat"].get("private_list", [])
|
||||
self.ban_user_id: list = raw_config["Chat"].get("ban_user_id", [])
|
||||
self.enable_poke: bool = raw_config["Chat"].get("enable_poke", True)
|
||||
if self.group_list_type not in ["whitelist", "blacklist"]:
|
||||
logger.critical("请在配置文件中指定group_list_type或group_list_type填写错误")
|
||||
sys.exit(1)
|
||||
if self.private_list_type not in ["whitelist", "blacklist"]:
|
||||
logger.critical("请在配置文件中指定private_list_type或private_list_type填写错误")
|
||||
sys.exit(1)
|
||||
|
||||
self.use_tts = raw_config["Voice"].get("use_tts", False)
|
||||
|
||||
self.debug_level = raw_config["Debug"].get("level", "INFO")
|
||||
if self.debug_level == "DEBUG":
|
||||
logger.debug("原始配置文件内容:")
|
||||
logger.debug(raw_config)
|
||||
logger.debug("读取到的配置内容:")
|
||||
logger.debug(f"平台: {self.platform}")
|
||||
logger.debug(f"MaiBot服务器地址: {self.mai_host}:{self.mai_port}")
|
||||
logger.debug(f"Napcat服务器地址: {self.server_host}:{self.server_port}")
|
||||
logger.debug(f"心跳间隔: {self.napcat_heartbeat_interval}秒")
|
||||
logger.debug(f"群聊列表类型: {self.group_list_type}")
|
||||
logger.debug(f"群聊列表: {self.group_list}")
|
||||
logger.debug(f"私聊列表类型: {self.private_list_type}")
|
||||
logger.debug(f"私聊列表: {self.private_list}")
|
||||
logger.debug(f"禁用用户ID列表: {self.ban_user_id}")
|
||||
logger.debug(f"是否启用TTS: {self.use_tts}")
|
||||
logger.debug(f"调试级别: {self.debug_level}")
|
||||
|
||||
|
||||
global_config = Config()
|
||||
global_config.load_config()
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from .config import global_config
|
||||
|
||||
__all__ = [
|
||||
"global_config",
|
||||
]
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from ..logger import logger
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
ChatConfig,
|
||||
DebugConfig,
|
||||
MaiBotServerConfig,
|
||||
NapcatServerConfig,
|
||||
NicknameConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TEMPLATE_DIR = "template"
|
||||
|
||||
|
||||
def update_config():
|
||||
# 定义文件路径
|
||||
template_path = f"{TEMPLATE_DIR}/template_config.toml"
|
||||
old_config_path = "config.toml"
|
||||
new_config_path = "config.toml"
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.info("配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(template_path, old_config_path) # 复制模板文件
|
||||
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
|
||||
# 如果是新创建的配置文件,直接返回
|
||||
quit()
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 备份文件名
|
||||
old_backup_path = "config.toml.back"
|
||||
|
||||
# 备份旧配置文件
|
||||
shutil.move(old_config_path, old_backup_path)
|
||||
logger.info(f"已备份旧配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
quit()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
nickname: NicknameConfig
|
||||
napcat_server: NapcatServerConfig
|
||||
maibot_server: MaiBotServerConfig
|
||||
chat: ChatConfig
|
||||
voice: VoiceConfig
|
||||
debug: DebugConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 更新配置
|
||||
update_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path="config.toml")
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
list,
|
||||
dict,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: Dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
field_type = f.type
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||||
3. 对于基础类型(int, str, float, bool),直接转换
|
||||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||||
"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||||
return field_type.from_dict(value)
|
||||
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_args_type = get_args(field_type)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
# 检查提供的value是否为list
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
return [cls._convert_field(item, field_args_type[0]) for item in value]
|
||||
if field_origin_type is set:
|
||||
return {cls._convert_field(item, field_args_type[0]) for item in value}
|
||||
if field_origin_type is tuple:
|
||||
# 检查提供的value长度是否与类型参数一致
|
||||
if len(value) != len(field_args_type):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 检查字典的键值类型
|
||||
if len(field_args_type) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_args_type
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理Optional类型
|
||||
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
|
||||
if value is None:
|
||||
return None
|
||||
# 如果有数据,检查实际类型
|
||||
if type(value) not in field_args_type:
|
||||
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
|
||||
return cls._convert_field(value, field_args_type[0])
|
||||
|
||||
# 处理int, str, float, bool等基础类型
|
||||
if field_origin_type is None:
|
||||
if isinstance(value, field_type):
|
||||
return field_type(value)
|
||||
else:
|
||||
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal:
|
||||
# 获取Literal的允许值
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
# 处理其他类型
|
||||
if field_type is Any:
|
||||
return value
|
||||
|
||||
# 其他类型直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
|
||||
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 所有新增的class都需要继承自ConfigBase
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
ADAPTER_PLATFORM = "qq"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NicknameConfig(ConfigBase):
|
||||
nickname: str
|
||||
"""机器人昵称"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NapcatServerConfig(ConfigBase):
|
||||
host: str = "localhost"
|
||||
"""Napcat服务端的主机地址"""
|
||||
|
||||
port: int = 8095
|
||||
"""Napcat服务端的端口号"""
|
||||
|
||||
heartbeat_interval: int = 30
|
||||
"""Napcat心跳间隔时间,单位为秒"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiBotServerConfig(ConfigBase):
|
||||
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
|
||||
"""平台名称,“qq”"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""MaiMCore的主机地址"""
|
||||
|
||||
port: int = 8000
|
||||
"""MaiMCore的端口号"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatConfig(ConfigBase):
|
||||
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""群聊列表类型 白名单/黑名单"""
|
||||
|
||||
group_list: list[int] = field(default_factory=[])
|
||||
"""群聊列表"""
|
||||
|
||||
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""私聊列表类型 白名单/黑名单"""
|
||||
|
||||
private_list: list[int] = field(default_factory=[])
|
||||
"""私聊列表"""
|
||||
|
||||
ban_user_id: list[int] = field(default_factory=[])
|
||||
"""被封禁的用户ID列表,封禁后将无法与其进行交互"""
|
||||
|
||||
enable_poke: bool = True
|
||||
"""是否启用戳一戳功能"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
use_tts: bool = False
|
||||
"""是否启用TTS功能"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
"""日志级别,默认为INFO"""
|
||||
|
|
@ -5,6 +5,6 @@ import sys
|
|||
logger.remove()
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level=global_config.debug_level,
|
||||
level=global_config.debug.level,
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from .send_handler import send_handler
|
|||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
global_config.platform: TargetConfig(
|
||||
url=f"ws://{global_config.mai_host}:{global_config.mai_port}/ws",
|
||||
global_config.maibot_server.platform_name: TargetConfig(
|
||||
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
|
||||
token=None,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import json
|
|||
import websockets as Server
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import uuid
|
||||
from plyer import notification
|
||||
import os
|
||||
|
||||
from . import MetaEventType, RealMessageType, MessageType, NoticeType
|
||||
from maim_message import (
|
||||
|
|
@ -36,7 +38,8 @@ class RecvHandler:
|
|||
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
self.interval = global_config.napcat_heartbeat_interval
|
||||
self.interval = global_config.napcat_server.heartbeat_interval
|
||||
self._interval_checking = False
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
|
|
@ -49,6 +52,8 @@ class RecvHandler:
|
|||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
elif event_type == MetaEventType.heartbeat:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
if not self._interval_checking:
|
||||
asyncio.create_task(self.check_heartbeat())
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = message.get("interval") / 1000
|
||||
else:
|
||||
|
|
@ -56,10 +61,20 @@ class RecvHandler:
|
|||
logger.warning(f"Bot {self_id} Napcat 端异常!")
|
||||
|
||||
async def check_heartbeat(self, id: int) -> None:
|
||||
self._interval_checking = True
|
||||
while True:
|
||||
now_time = time.time()
|
||||
if now_time - self.last_heart_beat > self.interval + 3:
|
||||
logger.warning(f"Bot {id} 连接已断开")
|
||||
if now_time - self.last_heart_beat > self.interval * 2:
|
||||
logger.error(f"Bot {id} 连接已断开,被下线,或者Napcat卡死!")
|
||||
current_dir = os.path.dirname(__file__)
|
||||
icon_path = os.path.join(current_dir, "..", "assets", "maimai.ico")
|
||||
notification.notify(
|
||||
title="警告",
|
||||
message=f"Bot {id} 连接已断开,被下线,或者Napcat卡死!",
|
||||
app_name="MaiBot Napcat Adapter",
|
||||
timeout=10,
|
||||
app_icon=icon_path,
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug("心跳正常")
|
||||
|
|
@ -77,20 +92,20 @@ class RecvHandler:
|
|||
"""
|
||||
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
|
||||
if group_id:
|
||||
if global_config.group_list_type == "whitelist" and group_id not in global_config.group_list:
|
||||
if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list:
|
||||
logger.warning("群聊不在聊天白名单中,消息被丢弃")
|
||||
return False
|
||||
elif global_config.group_list_type == "blacklist" and group_id in global_config.group_list:
|
||||
elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list:
|
||||
logger.warning("群聊在聊天黑名单中,消息被丢弃")
|
||||
return False
|
||||
else:
|
||||
if global_config.private_list_type == "whitelist" and user_id not in global_config.private_list:
|
||||
if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list:
|
||||
logger.warning("私聊不在聊天白名单中,消息被丢弃")
|
||||
return False
|
||||
elif global_config.private_list_type == "blacklist" and user_id in global_config.private_list:
|
||||
elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list:
|
||||
logger.warning("私聊在聊天黑名单中,消息被丢弃")
|
||||
return False
|
||||
if user_id in global_config.ban_user_id:
|
||||
if user_id in global_config.chat.ban_user_id:
|
||||
logger.warning("用户在全局黑名单中,消息被丢弃")
|
||||
return False
|
||||
return True
|
||||
|
|
@ -123,7 +138,7 @@ class RecvHandler:
|
|||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
|
|
@ -149,7 +164,7 @@ class RecvHandler:
|
|||
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=nickname,
|
||||
user_cardname=None,
|
||||
|
|
@ -164,7 +179,7 @@ class RecvHandler:
|
|||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
|
@ -182,7 +197,7 @@ class RecvHandler:
|
|||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
|
|
@ -195,7 +210,7 @@ class RecvHandler:
|
|||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
|
@ -205,12 +220,12 @@ class RecvHandler:
|
|||
return None
|
||||
|
||||
additional_config: dict = {}
|
||||
if global_config.use_tts:
|
||||
if global_config.voice.use_tts:
|
||||
additional_config["allow_tts"] = True
|
||||
|
||||
# 消息信息
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id=message_id,
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
|
|
@ -500,7 +515,7 @@ class RecvHandler:
|
|||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if global_config.enable_poke:
|
||||
if global_config.chat.enable_poke:
|
||||
handled_message: Seg = await self.handle_poke_notify(raw_message)
|
||||
else:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
|
|
@ -532,7 +547,7 @@ class RecvHandler:
|
|||
source_name = "QQ用户"
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=source_name,
|
||||
user_cardname=source_cardname,
|
||||
|
|
@ -547,13 +562,13 @@ class RecvHandler:
|
|||
else:
|
||||
logger.warning("无法获取戳一戳消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.platform,
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
|
|
@ -697,7 +712,7 @@ class RecvHandler:
|
|||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg = Seg(type="text", data="\n")
|
||||
message_of_sub_message_list: dict = sub_message.get("message")
|
||||
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
|
||||
if not message_of_sub_message_list:
|
||||
logger.warning("转发消息内容为空")
|
||||
continue
|
||||
|
|
@ -769,7 +784,9 @@ class RecvHandler:
|
|||
|
||||
async def message_process(self, message_base: MessageBase) -> None:
|
||||
try:
|
||||
await self.maibot_router.send_message(message_base)
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ async def check_timeout_response() -> None:
|
|||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > global_config.napcat_heartbeat_interval:
|
||||
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
response_dict.pop(echo_id)
|
||||
response_time_dict.pop(echo_id)
|
||||
logger.warning(f"响应消息 {echo_id} 超时,已删除")
|
||||
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
|
||||
await asyncio.sleep(global_config.napcat_heartbeat_interval)
|
||||
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,33 @@
|
|||
[Nickname] # 现在没用
|
||||
[inner]
|
||||
version = "0.1.0" # 版本号
|
||||
# 请勿修改版本号,除非你知道自己在做什么
|
||||
|
||||
[nickname] # 现在没用
|
||||
nickname = ""
|
||||
|
||||
[Napcat_Server] # Napcat连接的ws服务设置
|
||||
host = "localhost" # Napcat设定的主机地址
|
||||
port = 8095 # Napcat设定的端口
|
||||
heartbeat = 30 # 与Napcat设置的心跳相同(按秒计)
|
||||
[napcat_server] # Napcat连接的ws服务设置
|
||||
host = "localhost" # Napcat设定的主机地址
|
||||
port = 8095 # Napcat设定的端口
|
||||
heartbeat_interval = 30 # 与Napcat设置的心跳相同(按秒计)
|
||||
|
||||
[MaiBot_Server] # 连接麦麦的ws服务设置
|
||||
platform_name = "qq" # 标识adapter的名称(必填)
|
||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||
[maibot_server] # 连接麦麦的ws服务设置
|
||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||
|
||||
[Chat] # 黑白名单功能
|
||||
[chat] # 黑白名单功能
|
||||
group_list_type = "whitelist" # 群组名单类型,可选为:whitelist, blacklist
|
||||
group_list = [] # 群组名单
|
||||
group_list = [] # 群组名单
|
||||
# 当group_list_type为whitelist时,只有群组名单中的群组可以聊天
|
||||
# 当group_list_type为blacklist时,群组名单中的任何群组无法聊天
|
||||
private_list_type = "whitelist" # 私聊名单类型,可选为:whitelist, blacklist
|
||||
private_list = [] # 私聊名单
|
||||
private_list = [] # 私聊名单
|
||||
# 当private_list_type为whitelist时,只有私聊名单中的用户可以聊天
|
||||
# 当private_list_type为blacklist时,私聊名单中的任何用户无法聊天
|
||||
ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
|
||||
ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
|
||||
enable_poke = true # 是否启用戳一戳功能
|
||||
|
||||
[Voice] # 发送语音设置
|
||||
[voice] # 发送语音设置
|
||||
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
||||
|
||||
[Debug]
|
||||
level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR)
|
||||
[debug]
|
||||
level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
|
|
|
|||
Loading…
Reference in New Issue