diff --git a/.gitignore b/.gitignore index b2d679d..f0977e5 100644 --- a/.gitignore +++ b/.gitignore @@ -272,4 +272,6 @@ $RECYCLE.BIN/ config.toml config.toml.back test -data/qq_bot.json \ No newline at end of file +data/NapcatAdapter.db +data/NapcatAdapter.db-shm +data/NapcatAdapter.db-wal \ No newline at end of file diff --git a/README.md b/README.md index 4615f49..3e76e39 100644 --- a/README.md +++ b/README.md @@ -78,4 +78,6 @@ sequenceDiagram - [x] 群踢人功能 # 特别鸣谢 - 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 \ No newline at end of file + 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 + + 以及[@墨梓柒](https://github.com/DrSmoothl)对部分代码想法的支持 \ No newline at end of file diff --git a/command_args.md b/command_args.md index 8dff207..01390a7 100644 --- a/command_args.md +++ b/command_args.md @@ -13,6 +13,8 @@ Seg.data: Dict[str, Any] = { } ``` 其中,群聊ID将会通过Group_Info.group_id自动获取。 + +**当`duration`为 0 时相当于解除禁言。** ## 群聊全体禁言 ```python Seg.data: Dict[str, Any] = { diff --git a/main.py b/main.py index 2c71ef1..a928191 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,10 @@ import sys import json import websockets as Server from src.logger import logger -from src.recv_handler import recv_handler +from src.recv_handler.message_handler import message_handler +from src.recv_handler.meta_event_handler import meta_event_handler +from src.recv_handler.notice_handler import notice_handler +from src.recv_handler.message_sending import message_send_instance from src.send_handler import send_handler from src.config import global_config from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router @@ -13,8 +16,9 @@ message_queue = asyncio.Queue() async def message_recv(server_connection: Server.ServerConnection): - recv_handler.server_connection = server_connection - send_handler.server_connection = server_connection + await message_handler.set_server_connection(server_connection) + asyncio.create_task(notice_handler.set_server_connection(server_connection)) + await send_handler.set_server_connection(server_connection) async for raw_message in server_connection: logger.debug( f"{raw_message[:100]}..." @@ -34,11 +38,11 @@ async def message_process(): message = await message_queue.get() post_type = message.get("post_type") if post_type == "message": - await recv_handler.handle_raw_message(message) + await message_handler.handle_raw_message(message) elif post_type == "meta_event": - await recv_handler.handle_meta_event(message) + await meta_event_handler.handle_meta_event(message) elif post_type == "notice": - await recv_handler.handle_notice(message) + await notice_handler.handle_notice(message) else: logger.warning(f"未知的post_type: {post_type}") message_queue.task_done() @@ -46,7 +50,7 @@ async def message_process(): async def main(): - recv_handler.maibot_router = router + message_send_instance.maibot_router = router _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response()) diff --git a/notify_args.md b/notify_args.md new file mode 100644 index 0000000..31e1151 --- /dev/null +++ b/notify_args.md @@ -0,0 +1,40 @@ +# Notify Args +```python +Seg.type = "notify" +``` +## 群聊成员被禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "ban", + "duration": "对应的禁言时间,单位为秒", + "banned_user_info": "被禁言的用户的信息,为标准UserInfo对象" +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 +## 群聊开启全体禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_ban", + "duration": -1, + "banned_user_info": None +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 +## 群聊成员被解除禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_lift_ban", + "lifted_user_info": "被解除禁言的用户的信息,为标准UserInfo对象" +} +``` +**对于自然禁言解除的情况,此时`MessageBase.UserInfo`为`None`** + +对于手动解除禁言的情况,此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 +## 群聊关闭全体禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_lift_ban", + "lifted_user_info": None, +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0fedfb2..2f6423d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "MaiBotNapcatAdapter" -version = "0.2.6" +version = "0.3.0" description = "A MaiBot adapter for Napcat" [tool.ruff] diff --git a/requirements.txt b/requirements.txt index c0ad7e0..54d8729 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,5 @@ requests maim_message loguru pillow -tomli tomlkit rich \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index bfae081..b1ac77e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,69 +1,7 @@ from enum import Enum - - -class MetaEventType: - lifecycle = "lifecycle" # 生命周期 - - class Lifecycle: - connect = "connect" # 生命周期 - WebSocket 连接成功 - - heartbeat = "heartbeat" # 心跳 - - -class MessageType: # 接受消息大类 - private = "private" # 私聊消息 - - class Private: - friend = "friend" # 私聊消息 - 好友 - group = "group" # 私聊消息 - 群临时 - group_self = "group_self" # 私聊消息 - 群中自身发送 - other = "other" # 私聊消息 - 其他 - - group = "group" # 群聊消息 - - class Group: - normal = "normal" # 群聊消息 - 普通 - anonymous = "anonymous" # 群聊消息 - 匿名消息 - notice = "notice" # 群聊消息 - 系统提示 - - -class NoticeType: # 通知事件 - friend_recall = "friend_recall" # 私聊消息撤回 - group_recall = "group_recall" # 群聊消息撤回 - notify = "notify" - - class Notify: - poke = "poke" # 戳一戳 - - -class RealMessageType: # 实际消息分类 - text = "text" # 纯文本 - face = "face" # qq表情 - image = "image" # 图片 - record = "record" # 语音 - video = "video" # 视频 - at = "at" # @某人 - rps = "rps" # 猜拳魔法表情 - dice = "dice" # 骰子 - shake = "shake" # 私聊窗口抖动(只收) - poke = "poke" # 群聊戳一戳 - share = "share" # 链接分享(json形式) - reply = "reply" # 回复消息 - forward = "forward" # 转发消息 - node = "node" # 转发消息节点 - - -class MessageSentType: - private = "private" - - class Private: - friend = "friend" - group = "group" - - group = "group" - - class Group: - normal = "normal" +import tomlkit +import os +from .logger import logger class CommandType(Enum): @@ -76,3 +14,9 @@ class CommandType(Enum): def __str__(self) -> str: return self.value + + +pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml") +toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read()) +version = toml_data["project"]["version"] +logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n") diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..45100cb --- /dev/null +++ b/src/database.py @@ -0,0 +1,149 @@ +import os +from typing import Optional, List +from dataclasses import dataclass +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from src.logger import logger + +""" +表记录的方式: +| group_id | user_id | lift_time | +|----------|---------|-----------| + +其中使用 user_id == 0 表示群全体禁言 +""" + + +@dataclass +class BanUser: + """ + 程序处理使用的实例 + """ + + user_id: int + group_id: int + lift_time: Optional[int] = Field(default=-1) + + +class DB_BanUser(SQLModel, table=True): + """ + 表示数据库中的用户禁言记录。 + 使用双重主键 + """ + + user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID + group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID + lift_time: Optional[int] # 禁言解除的时间(时间戳) + + +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: + """ + 检查两个 BanUser 对象是否相同。 + """ + return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id + + +class DatabaseManager: + """ + 数据库管理类,负责与数据库交互。 + """ + + def __init__(self): + DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") + self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL + self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 + self._ensure_database() # 确保数据库和表已创建 + + def _ensure_database(self) -> None: + """ + 确保数据库和表已创建。 + """ + logger.info("确保数据库文件和表已创建...") + SQLModel.metadata.create_all(self.engine) + logger.success("数据库和表已创建或已存在") + + def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method + """ + 更新禁言列表到数据库。 + 支持在不存在时创建新记录,对于多余的项目自动删除。 + """ + with Session(self.engine) as session: + all_records = session.exec(select(DB_BanUser)).all() + for ban_user in ban_list: + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id + ) + if existing_record := session.exec(statement).first(): + if existing_record.lift_time == ban_user.lift_time: + logger.debug(f"禁言记录未变更: {existing_record}") + continue + # 更新现有记录的 lift_time + existing_record.lift_time = ban_user.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {existing_record}") + else: + # 创建新记录 + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_user}") + # 删除不在 ban_list 中的记录 + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) + if not any(is_identical(record, ban_user) for ban_user in ban_list): + statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id + ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") + + session.commit() + logger.info("禁言记录已更新") + + def get_ban_records(self) -> List[BanUser]: + """ + 读取所有禁言记录。 + """ + with Session(self.engine) as session: + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] + + def create_ban_record(self, ban_record: BanUser) -> None: + """ + 为特定群组中的用户创建禁言记录。 + 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 + 其同时还是简化版的更新方式。 + """ + with Session(self.engine) as session: + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) + session.commit() + logger.debug(f"创建/更新禁言记录: {ban_record}") + + def delete_ban_record(self, ban_record: BanUser) -> bool: + """ + 删除特定用户在特定群组中的禁言记录。 + 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 + """ + user_id = ban_record.user_id + group_id = ban_record.group_id + with Session(self.engine) as session: + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + + +db_manager = DatabaseManager() diff --git a/src/mmc_com_layer.py b/src/mmc_com_layer.py index ab50cca..f7fd1ad 100644 --- a/src/mmc_com_layer.py +++ b/src/mmc_com_layer.py @@ -11,7 +11,7 @@ route_config = RouteConfig( ) } ) -router = Router(route_config) +router = Router(route_config, logger) async def mmc_start_com(): diff --git a/src/recv_handler/__init__.py b/src/recv_handler/__init__.py new file mode 100644 index 0000000..40f0070 --- /dev/null +++ b/src/recv_handler/__init__.py @@ -0,0 +1,83 @@ +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + group_ban = "group_ban" # 群禁言 + + class Notify: + poke = "poke" # 戳一戳 + + class GroupBan: + ban = "ban" # 禁言 + lift_ban = "lift_ban" # 解除禁言 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + + def __str__(self) -> str: + return self.value diff --git a/src/recv_handler.py b/src/recv_handler/message_handler.py similarity index 70% rename from src/recv_handler.py rename to src/recv_handler/message_handler.py index ab5a7eb..0cec56f 100644 --- a/src/recv_handler.py +++ b/src/recv_handler/message_handler.py @@ -1,14 +1,22 @@ -from .logger import logger -from .config import global_config +from src.logger import logger +from src.config import global_config +from src.utils import ( + get_group_info, + get_member_info, + get_image_base64, + get_self_info, + get_message_detail, +) from .qq_emoji_list import qq_face +from .message_sending import message_send_instance +from . import RealMessageType, MessageType + import time -import asyncio import json import websockets as Server from typing import List, Tuple, Optional, Dict, Any import uuid -from . import MetaEventType, RealMessageType, MessageType, NoticeType from maim_message import ( UserInfo, GroupInfo, @@ -17,97 +25,54 @@ from maim_message import ( MessageBase, TemplateInfo, FormatInfo, - Router, ) -from .utils import ( - get_group_info, - get_member_info, - get_image_base64, - get_self_info, - get_stranger_info, - get_message_detail, - read_bot_id, - update_bot_id, -) -from .response_pool import get_response + +from src.response_pool import get_response -class RecvHandler: - maibot_router: Router = None - +class MessageHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - self.interval = global_config.napcat_server.heartbeat_interval - self._interval_checking = False self.bot_id_list: Dict[int, bool] = {} - async def handle_meta_event(self, message: dict) -> None: - event_type = message.get("meta_event_type") - if event_type == MetaEventType.lifecycle: - sub_type = message.get("sub_type") - if sub_type == MetaEventType.Lifecycle.connect: - self_id = message.get("self_id") - self.last_heart_beat = time.time() - logger.info(f"Bot {self_id} 连接成功") - asyncio.create_task(self.check_heartbeat(self_id)) - elif event_type == MetaEventType.heartbeat: - if message["status"].get("online") and message["status"].get("good"): - if not self._interval_checking: - asyncio.create_task(self.check_heartbeat()) - self.last_heart_beat = time.time() - self.interval = message.get("interval") / 1000 - else: - self_id = message.get("self_id") - logger.warning(f"Bot {self_id} Napcat 端异常!") + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection - async def check_heartbeat(self, id: int) -> None: - self._interval_checking = True - while True: - now_time = time.time() - if now_time - self.last_heart_beat > self.interval * 2: - logger.error(f"Bot {id} 连接已断开,被下线,或者Napcat卡死!") - break - else: - logger.debug("心跳正常") - await asyncio.sleep(self.interval) - - async def check_allow_to_chat(self, user_id: int, group_id: Optional[int]) -> bool: + async def check_allow_to_chat( + self, + user_id: int, + group_id: Optional[int] = None, + ignore_bot: Optional[bool] = False, + ignore_global_list: Optional[bool] = False, + ) -> bool: # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif """ 检查是否允许聊天 Parameters: user_id: int: 用户ID group_id: int: 群ID + ignore_bot: bool: 是否忽略机器人检查 + ignore_global_list: bool: 是否忽略全局黑名单检查 Returns: bool: 是否允许聊天 """ - user_id = str(user_id) logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") - if global_config.chat.ban_qq_bot and group_id: + if global_config.chat.ban_qq_bot and group_id and not ignore_bot: logger.debug("开始判断是否为机器人") - if not self.bot_id_list: - self.bot_id_list = read_bot_id() - if user_id in self.bot_id_list: - if self.bot_id_list[user_id]: - logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃") - return False - else: - member_info = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - is_bot = member_info.get("is_robot") - if is_bot is None: - logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + member_info = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + is_bot = member_info.get("is_robot") + if is_bot is None: + logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + else: + if is_bot: + logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") + self.bot_id_list[user_id] = True + return False else: - if is_bot: - logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") - self.bot_id_list[user_id] = True - update_bot_id(self.bot_id_list) - return False - else: - self.bot_id_list[user_id] = False - update_bot_id(self.bot_id_list) - user_id = int(user_id) + self.bot_id_list[user_id] = False logger.debug("开始检查聊天白名单/黑名单") if group_id: if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list: @@ -123,7 +88,7 @@ class RecvHandler: elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list: logger.warning("私聊在聊天黑名单中,消息被丢弃") return False - if user_id in global_config.chat.ban_user_id: + if user_id in global_config.chat.ban_user_id and not ignore_global_list: logger.warning("用户在全局黑名单中,消息被丢弃") return False return True @@ -275,7 +240,7 @@ class RecvHandler: ) logger.info("发送到Maibot处理信息") - await self.message_process(message_base) + await message_send_instance.message_send(message_base) async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None: # sourcery skip: low-code-quality @@ -343,7 +308,7 @@ class RecvHandler: case RealMessageType.share: logger.warning("暂时不支持链接解析") case RealMessageType.forward: - messages = await self.get_forward_message(sub_message) + messages = await self._get_forward_message(sub_message) if not messages: logger.warning("转发消息内容为空或获取失败") return None @@ -440,40 +405,6 @@ class RecvHandler: else: return None - async def get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: - forward_message_data: Dict = raw_message.get("data") - if not forward_message_data: - logger.warning("转发消息内容为空") - return None - forward_message_id = forward_message_data.get("id") - request_uuid = str(uuid.uuid4()) - payload = json.dumps( - { - "action": "get_forward_msg", - "params": {"message_id": forward_message_id}, - "echo": request_uuid, - } - ) - try: - await self.server_connection.send(payload) - response: dict = await get_response(request_uuid) - except TimeoutError: - logger.error("获取转发消息超时") - return None - except Exception as e: - logger.error(f"获取转发消息失败: {str(e)}") - return None - logger.debug( - f"转发消息原始格式:{json.dumps(response)[:80]}..." - if len(json.dumps(response)) > 80 - else json.dumps(response) - ) - response_data: Dict = response.get("data") - if not response_data: - logger.warning("转发消息内容为空或获取失败") - return None - return response_data.get("messages") - async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None: # sourcery skip: move-assign-in-block, use-named-expression """ @@ -506,156 +437,6 @@ class RecvHandler: seg_message.append(Seg(type="text", data="],说:")) return seg_message - async def handle_notice(self, raw_message: dict) -> None: - notice_type = raw_message.get("notice_type") - # message_time: int = raw_message.get("time") - message_time: float = time.time() # 应可乐要求,现在是float了 - - group_id = raw_message.get("group_id") - user_id = raw_message.get("user_id") - target_id = raw_message.get("target_id") - - if not await self.check_allow_to_chat(user_id, group_id): - logger.warning("notice消息被丢弃") - return None - - handled_message: Seg = None - - match notice_type: - case NoticeType.friend_recall: - logger.info("好友撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.group_recall: - logger.info("群内用户撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.notify: - sub_type = raw_message.get("sub_type") - match sub_type: - case NoticeType.Notify.poke: - if global_config.chat.enable_poke: - handled_message: Seg = await self.handle_poke_notify(raw_message) - else: - logger.warning("戳一戳消息被禁用,取消戳一戳处理") - case _: - logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case _: - logger.warning(f"不支持的notice类型: {notice_type}") - return None - if not handled_message: - logger.warning("notice处理失败或不支持") - return None - - source_name: str = None - source_cardname: str = None - if group_id: - member_info: dict = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - source_name = member_info.get("nickname") - source_cardname = member_info.get("card") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - else: - stranger_info = await get_stranger_info(self.server_connection, user_id) - if stranger_info: - source_name = stranger_info.get("nickname") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - - user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, - user_id=user_id, - user_nickname=source_name, - user_cardname=source_cardname, - ) - - group_info: GroupInfo = None - if group_id: - fetched_group_info = await get_group_info(self.server_connection, group_id) - group_name: str = None - if fetched_group_info: - group_name = fetched_group_info.get("group_name") - else: - logger.warning("无法获取戳一戳消息所在群的名称") - group_info = GroupInfo( - platform=global_config.maibot_server.platform_name, - group_id=group_id, - group_name=group_name, - ) - - message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.maibot_server.platform_name, - message_id="notice", - time=message_time, - user_info=user_info, - group_info=group_info, - template_info=None, - format_info=None, - additional_config = {"target_id": target_id}# 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁 - ) - - message_base: MessageBase = MessageBase( - message_info=message_info, - message_segment=handled_message, - raw_message=json.dumps(raw_message), - ) - - logger.info("发送到Maibot处理通知信息") - await self.message_process(message_base) - - async def handle_poke_notify(self, raw_message: dict) -> Seg | None: - self_info: dict = await get_self_info(self.server_connection) - - if not self_info: - logger.error("自身信息获取失败") - return None - self_id = raw_message.get("self_id") - target_id = raw_message.get("target_id") - group_id = raw_message.get("group_id") - user_id = raw_message.get("user_id") - target_name: str = None - raw_info: list = raw_message.get("raw_info") - # 计算Seg - if self_id == target_id: # 现在这里应当是专注于处理私聊戳一戳的,也就是说当私聊里,被戳的是另一方时,不会给这个消息。 - target_name = self_info.get("nickname") - user_name = "" # 这样的话应该能保证消息大概是“某某某:戳了戳麦麦”,而不是“某某某:某某某戳了戳麦麦” - - elif self_id == user_id: - return None # 这应当让ada不发送麦麦戳别人的消息,因为这个消息已经被mmc的命令记录了,没必要记第二次。 - - else: - if group_id: # 如果是群聊环境,老实说做这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳的,但是感觉可以有这个判定来强限制群聊环境 - user_info: dict = await get_member_info( - self.server_connection, group_id, user_id - ) - fetched_member_info: dict = await get_member_info( - self.server_connection, group_id, target_id - ) - if user_info: - user_name = user_info.get("nickname") - else: - user_name = "QQ用户" - if fetched_member_info: - target_name = fetched_member_info.get("nickname") - else: - target_name = "QQ用户" - else: - return None - try: - first_txt = raw_info[2].get("txt", "戳了戳") - except Exception as e: - logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") - first_txt = "戳了戳" - - seg_data: Seg = Seg( - type="text", - data=f"{user_name}{first_txt}{target_name}(这是QQ的一个功能,用于提及某人,但没那么明显)", - ) - return seg_data - async def handle_forward_message(self, message_list: list) -> Seg | None: """ 递归处理转发消息,并按照动态方式确定图片处理方式 @@ -814,15 +595,39 @@ class RecvHandler: seg_list.append(full_seg_data) return Seg(type="seglist", data=seg_list), image_count - async def message_process(self, message_base: MessageBase) -> None: - try: - send_status = await self.maibot_router.send_message(message_base) - if not send_status: - raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") - except Exception as e: - logger.error(f"发送消息失败: {str(e)}") - logger.error("请检查与MaiBot之间的连接") + async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: + forward_message_data: Dict = raw_message.get("data") + if not forward_message_data: + logger.warning("转发消息内容为空") return None + forward_message_id = forward_message_data.get("id") + request_uuid = str(uuid.uuid4()) + payload = json.dumps( + { + "action": "get_forward_msg", + "params": {"message_id": forward_message_id}, + "echo": request_uuid, + } + ) + try: + await self.server_connection.send(payload) + response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error("获取转发消息超时") + return None + except Exception as e: + logger.error(f"获取转发消息失败: {str(e)}") + return None + logger.debug( + f"转发消息原始格式:{json.dumps(response)[:80]}..." + if len(json.dumps(response)) > 80 + else json.dumps(response) + ) + response_data: Dict = response.get("data") + if not response_data: + logger.warning("转发消息内容为空或获取失败") + return None + return response_data.get("messages") -recv_handler = RecvHandler() +message_handler = MessageHandler() diff --git a/src/recv_handler/message_sending.py b/src/recv_handler/message_sending.py new file mode 100644 index 0000000..2e43bbd --- /dev/null +++ b/src/recv_handler/message_sending.py @@ -0,0 +1,31 @@ +from src.logger import logger +from maim_message import MessageBase, Router + + +class MessageSending: + """ + 负责把消息发送到麦麦 + """ + + maibot_router: Router = None + + def __init__(self): + pass + + async def message_send(self, message_base: MessageBase) -> bool: + """ + 发送消息 + Parameters: + message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息 + """ + try: + send_status = await self.maibot_router.send_message(message_base) + if not send_status: + raise RuntimeError("可能是路由未正确配置或连接异常") + except Exception as e: + logger.error(f"发送消息失败: {str(e)}") + logger.error("请检查与MaiBot之间的连接") + return send_status + + +message_send_instance = MessageSending() diff --git a/src/recv_handler/meta_event_handler.py b/src/recv_handler/meta_event_handler.py new file mode 100644 index 0000000..bb9efe6 --- /dev/null +++ b/src/recv_handler/meta_event_handler.py @@ -0,0 +1,49 @@ +from src.logger import logger +from src.config import global_config +import time +import asyncio + +from . import MetaEventType + + +class MetaEventHandler: + """ + 处理Meta事件 + """ + + def __init__(self): + self.interval = global_config.napcat_server.heartbeat_interval + self._interval_checking = False + + async def handle_meta_event(self, message: dict) -> None: + event_type = message.get("meta_event_type") + if event_type == MetaEventType.lifecycle: + sub_type = message.get("sub_type") + if sub_type == MetaEventType.Lifecycle.connect: + self_id = message.get("self_id") + self.last_heart_beat = time.time() + logger.info(f"Bot {self_id} 连接成功") + asyncio.create_task(self.check_heartbeat(self_id)) + elif event_type == MetaEventType.heartbeat: + if message["status"].get("online") and message["status"].get("good"): + if not self._interval_checking: + asyncio.create_task(self.check_heartbeat()) + self.last_heart_beat = time.time() + self.interval = message.get("interval") / 1000 + else: + self_id = message.get("self_id") + logger.warning(f"Bot {self_id} Napcat 端异常!") + + async def check_heartbeat(self, id: int) -> None: + self._interval_checking = True + while True: + now_time = time.time() + if now_time - self.last_heart_beat > self.interval * 2: + logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!") + break + else: + logger.debug("心跳正常") + await asyncio.sleep(self.interval) + + +meta_event_handler = MetaEventHandler() diff --git a/src/recv_handler/notice_handler.py b/src/recv_handler/notice_handler.py new file mode 100644 index 0000000..2d03a49 --- /dev/null +++ b/src/recv_handler/notice_handler.py @@ -0,0 +1,506 @@ +import time +import json +import asyncio +import websockets as Server +from typing import Tuple, Optional + +from src.logger import logger +from src.config import global_config +from src.database import BanUser, db_manager, is_identical +from . import NoticeType +from .message_sending import message_send_instance +from .message_handler import message_handler +from maim_message import UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase + +from src.utils import ( + get_group_info, + get_member_info, + get_self_info, + get_stranger_info, + read_ban_list, +) + +notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100) +unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3) + + +class NoticeHandler: + banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表 + lifted_list: list[BanUser] = [] # 已经自然解除禁言 + + def __init__(self): + self.server_connection: Server.ServerConnection = None + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + while self.server_connection.state != Server.State.OPEN: + await asyncio.sleep(0.5) + self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) + + asyncio.create_task(self.auto_lift_detect()) + asyncio.create_task(self.send_notice()) + asyncio.create_task(self.handle_natural_lift()) + + def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + """ + 将用户禁言记录添加到self.banned_list中 + 如果是全体禁言,则user_id为0 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + lift_time = -1 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) + for record in self.banned_list: + if is_identical(record, ban_record): + self.banned_list.remove(record) + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 作为更新 + return + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 添加到数据库 + + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + """ + 从self.lifted_group_list中移除已经解除全体禁言的群 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) + self.lifted_list.append(ban_record) + db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + + async def handle_notice(self, raw_message: dict) -> None: + notice_type = raw_message.get("notice_type") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + group_id = raw_message.get("group_id") + user_id = raw_message.get("user_id") + + handled_message: Seg = None + user_info: UserInfo = None + system_notice: bool = False + + match notice_type: + case NoticeType.friend_recall: + logger.info("好友撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.group_recall: + logger.info("群内用户撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.notify: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.Notify.poke: + if global_config.chat.enable_poke and await message_handler.check_allow_to_chat( + user_id, group_id, False, False + ): + logger.info("处理戳一戳消息") + handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) + else: + logger.warning("戳一戳消息被禁用,取消戳一戳处理") + case _: + logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") + case NoticeType.group_ban: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.GroupBan.ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群禁言") + handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + system_notice = True + case NoticeType.GroupBan.lift_ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理解除群禁言") + handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + system_notice = True + case _: + logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") + case _: + logger.warning(f"不支持的notice类型: {notice_type}") + return None + if not handled_message or not user_info: + logger.warning("notice处理失败或不支持") + return None + + group_info: GroupInfo = None + if group_id: + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=message_time, + user_info=user_info, + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=handled_message, + raw_message=json.dumps(raw_message), + ) + + if system_notice: + await self.put_notice(message_base) + else: + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) + + async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]: + self_info: dict = await get_self_info(self.server_connection) + if not self_info: + logger.error("自身信息获取失败") + return None + self_id = raw_message.get("self_id") + target_id = raw_message.get("target_id") + target_name: str = None + raw_info: list = raw_message.get("raw_info") + + # 计算user_info + source_name: str = None + source_cardname: str = None + if group_id: + member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + source_name = member_info.get("nickname") + source_cardname = member_info.get("card") + else: + logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") + source_name = "QQ用户" + else: + stranger_info = await get_stranger_info(self.server_connection, user_id) + if stranger_info: + source_name = stranger_info.get("nickname") + else: + logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") + source_name = "QQ用户" + + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=source_name, + user_cardname=source_cardname, + ) + + # 计算Seg + if self_id == target_id: + target_name = self_info.get("nickname") + else: + return None + try: + first_txt = raw_info[2].get("txt", "戳了戳") + second_txt = raw_info[4].get("txt", "") + except Exception as e: + logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") + first_txt = "戳了戳" + second_txt = "" + """ + # 不启用戳其他人的处理 + else: + # 由于Napcat不支持获取昵称,所以需要单独获取 + group_id = raw_message.get("group_id") + fetched_member_info: dict = await get_member_info( + self.server_connection, group_id, target_id + ) + if fetched_member_info: + target_name = fetched_member_info.get("nickname") + """ + seg_data: Seg = Seg( + type="text", + data=f"{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", + ) + return seg_data, user_info + + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + user_id = raw_message.get("user_id") + banned_user_info: UserInfo = None + user_nickname: str = "QQ用户" + user_cardname: str = None + sub_type: str = None + + duration = raw_message.get("duration") + if duration is None: + logger.error("禁言时长不能为空,无法处理禁言通知") + return None, None + + if user_id == 0: # 为全体禁言 + sub_type: str = "whole_ban" + self._ban_operation(group_id) + else: # 为单人禁言 + # 获取被禁言人的信息 + sub_type: str = "ban" + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + banned_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._ban_operation(group_id, user_id, int(time.time() + duration)) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "duration": duration, + "banned_user_info": banned_user_info, + }, + ) + + return seg_data, operator_info + + async def handle_lift_ban_notify( + self, raw_message: dict, group_id: int + ) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + sub_type: str = None + user_nickname: str = "QQ用户" + user_cardname: str = None + lifted_user_info: UserInfo = None + + user_id = raw_message.get("user_id") + if user_id == 0: # 全体禁言解除 + sub_type = "whole_lift_ban" + self._lift_operation(group_id) + else: # 单人禁言解除 + sub_type = "lift_ban" + # 获取被解除禁言人的信息 + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + else: + logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效") + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._lift_operation(group_id, user_id) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "lifted_user_info": lifted_user_info, + }, + ) + return seg_data, operator_info + + async def put_notice(self, message_base: MessageBase) -> None: + """ + 将处理后的通知消息放入通知队列 + """ + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + + async def handle_natural_lift(self) -> None: + while True: + if len(self.lifted_list) != 0: + lift_record = self.lifted_list.pop() + group_id = lift_record.group_id + user_id = lift_record.user_id + + db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + + seg_message: Seg = await self.natural_lift(group_id, user_id) + + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=time.time(), + user_info=None, # 自然解除禁言没有操作者 + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=seg_message, + raw_message=json.dumps( + { + "post_type": "notice", + "notice_type": "group_ban", + "sub_type": "lift_ban", + "group_id": group_id, + "user_id": user_id, + "operator_id": None, # 自然解除禁言没有操作者 + } + ), + ) + + await self.put_notice(message_base) + await asyncio.sleep(0.5) # 确保队列处理间隔 + else: + await asyncio.sleep(5) # 每5秒检查一次 + + async def natural_lift(self, group_id: int, user_id: int) -> Seg | None: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None + + if user_id == 0: # 理论上永远不会触发 + return Seg( + type="notify", + data={ + "sub_type": "whole_lift_ban", + "lifted_user_info": None, + }, + ) + + user_nickname: str = "QQ用户" + user_cardname: str = None + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + + return Seg( + type="notify", + data={ + "sub_type": "lift_ban", + "lifted_user_info": lifted_user_info, + }, + ) + + async def auto_lift_detect(self) -> None: + while True: + if len(self.banned_list) == 0: + await asyncio.sleep(5) + continue + for ban_record in self.banned_list: + if ban_record.user_id == 0 or ban_record.lift_time == -1: + continue + if ban_record.lift_time <= int(time.time()): + # 触发自然解除禁言 + logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") + self.lifted_list.append(ban_record) + self.banned_list.remove(ban_record) + await asyncio.sleep(5) + + async def send_notice(self) -> None: + """ + 发送通知消息到Napcat + """ + while True: + if not unsuccessful_notice_queue.empty(): + to_be_send: MessageBase = await unsuccessful_notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + unsuccessful_notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + continue + to_be_send: MessageBase = await notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + + +notice_handler = NoticeHandler() diff --git a/src/qq_emoji_list.py b/src/recv_handler/qq_emoji_list.py similarity index 100% rename from src/qq_emoji_list.py rename to src/recv_handler/qq_emoji_list.py diff --git a/src/send_handler.py b/src/send_handler.py index 2a6709e..6cb3094 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -21,6 +21,10 @@ class SendHandler: def __init__(self): self.server_connection: Server.ServerConnection = None + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + async def handle_message(self, raw_message_base_dict: dict) -> None: raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) message_segment: Seg = raw_message_base.message_segment @@ -254,8 +258,8 @@ class SendHandler: duration: int = int(args["duration"]) user_id: int = int(args["qq_id"]) group_id: int = int(group_info.group_id) - if duration <= 0: - raise ValueError("封禁时间必须大于0") + if duration < 0: + raise ValueError("封禁时间必须大于等于0") if not user_id or not group_id: raise ValueError("封禁命令缺少必要参数") if duration > 2592000: diff --git a/src/utils.py b/src/utils.py index 9e9941c..caa0b56 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,15 +2,16 @@ import websockets as Server import json import base64 import uuid +import urllib3 +import ssl +import io + +from src.database import BanUser, db_manager from .logger import logger from .response_pool import get_response -import urllib3 -import ssl -from pathlib import Path from PIL import Image -import io -import os +from typing import Union, List, Tuple class SSLAdapter(urllib3.PoolManager): @@ -22,7 +23,7 @@ class SSLAdapter(urllib3.PoolManager): super().__init__(*args, **kwargs) -async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict: +async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: """ 获取群相关信息 @@ -44,7 +45,29 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d return socket_response.get("data") -async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict: +async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: + """ + 获取群详细信息 + + 返回值需要处理可能为空的情况 + """ + logger.debug("获取群详细信息中") + request_uuid = str(uuid.uuid4()) + payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}) + try: + await websocket.send(payload) + socket_response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error(f"获取群详细信息超时,群号: {group_id}") + return None + except Exception as e: + logger.error(f"获取群详细信息失败: {e}") + return None + logger.debug(socket_response) + return socket_response.get("data") + + +async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None: """ 获取群成员信息 @@ -110,7 +133,7 @@ def convert_image_to_gif(image_base64: str) -> str: return image_base64 -async def get_self_info(websocket: Server.ServerConnection) -> dict: +async def get_self_info(websocket: Server.ServerConnection) -> dict | None: """ 获取自身信息 Parameters: @@ -146,7 +169,7 @@ def get_image_format(raw_data: str) -> str: return Image.open(io.BytesIO(image_bytes)).format.lower() -async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict: +async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None: """ 获取陌生人信息 Parameters: @@ -171,7 +194,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> return response.get("data") -async def get_message_detail(websocket: Server.ServerConnection, message_id: str) -> dict: +async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None: """ 获取消息详情,可能为空 Parameters: @@ -196,41 +219,58 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: str return response.get("data") -def update_bot_id(data: dict) -> None: +async def read_ban_list( + websocket: Server.ServerConnection, +) -> Tuple[List[BanUser], List[BanUser]]: """ - 更新用户是否为机器人的字典到根目录下的data文件夹中的qq_bot.json。 - Parameters: - data: dict: 包含需要更新的信息。 - """ - json_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "qq_bot.json") - try: - with open(json_path, "w", encoding="utf-8") as json_file: - json.dump(data, json_file, ensure_ascii=False, indent=4) - logger.info(f"ID字典已更新到文件: {json_path}") - except Exception as e: - logger.error(f"更新ID字典失败: {e}") - - -def read_bot_id() -> dict: - """ - 从根目录下的data文件夹中的文件读取机器人ID。 + 从根目录下的data文件夹中的文件读取禁言列表。 + 同时自动更新已经失效禁言 Returns: - list: 读取的机器人ID信息。 + Tuple[ + 一个仍在禁言中的用户的BanUser列表, + 一个已经自然解除禁言的用户的BanUser列表, + 一个仍在全体禁言中的群的BanUser列表, + 一个已经自然解除全体禁言的群的BanUser列表, + ] """ - json_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "qq_bot.json") try: - with open(json_path, "r", encoding="utf-8") as json_file: - data = json.load(json_file) - logger.info(f"已读取机器人ID信息: {data}") - return data - except FileNotFoundError: - logger.warning(f"文件未找到: {json_path},正在自动创建文件") - json_path = Path(os.path.dirname(os.path.dirname(__file__))) / "data" / "qq_bot.json" - # 确保父目录存在 - json_path.parent.mkdir(parents=True, exist_ok=True) - # 创建空文件 - json_path.touch(exist_ok=True) - return {} + ban_list = db_manager.get_ban_records() + lifted_list: List[BanUser] = [] + logger.info("已经读取禁言列表") + for ban_record in ban_list: + if ban_record.user_id == 0: + fetched_group_info = await get_group_info(websocket, ban_record.group_id) + if fetched_group_info is None: + logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除") + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + group_all_shut: int = fetched_group_info.get("group_all_shut") + if group_all_shut == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + else: + fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id) + if fetched_member_info is None: + logger.warning( + f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除" + ) + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + lift_ban_time: int = fetched_member_info.get("shut_up_timestamp") + if lift_ban_time == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + else: + ban_record.lift_time = lift_ban_time + db_manager.update_ban_record(ban_list) + return ban_list, lifted_list except Exception as e: - logger.error(f"读取机器人ID失败: {e}") - return {} + logger.error(f"读取禁言列表失败: {e}") + return [], [] + + +def save_ban_record(list: List[BanUser]): + return db_manager.update_ban_record(list)