From ca0fc4db1123cc2cbc9ad74cfa0f5308ab718a7c Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 28 Jun 2025 01:54:03 +0800 Subject: [PATCH 1/5] =?UTF-8?q?REFACTOR=20=E4=B8=8E=E7=A6=81=E8=A8=80?= =?UTF-8?q?=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- README.md | 4 +- command_args.md | 2 + data/NapcatAdapter.db | Bin 0 -> 20480 bytes main.py | 18 +- notify_args.md | 40 ++ src/__init__.py | 65 --- src/database.py | 121 +++++ src/recv_handler/__init__.py | 83 +++ .../message_handler.py} | 333 +++--------- src/recv_handler/message_sending.py | 31 ++ src/recv_handler/meta_event_handler.py | 49 ++ src/recv_handler/notice_handler.py | 493 ++++++++++++++++++ src/{ => recv_handler}/qq_emoji_list.py | 0 src/send_handler.py | 8 +- src/utils.py | 118 +++-- 16 files changed, 996 insertions(+), 372 deletions(-) create mode 100644 data/NapcatAdapter.db create mode 100644 notify_args.md create mode 100644 src/database.py create mode 100644 src/recv_handler/__init__.py rename src/{recv_handler.py => recv_handler/message_handler.py} (72%) create mode 100644 src/recv_handler/message_sending.py create mode 100644 src/recv_handler/meta_event_handler.py create mode 100644 src/recv_handler/notice_handler.py rename src/{ => recv_handler}/qq_emoji_list.py (100%) diff --git a/.gitignore b/.gitignore index b2d679d..60f4dc6 100644 --- a/.gitignore +++ b/.gitignore @@ -272,4 +272,5 @@ $RECYCLE.BIN/ config.toml config.toml.back test -data/qq_bot.json \ No newline at end of file +data/qq_bot.json +data/ban_list.json \ No newline at end of file diff --git a/README.md b/README.md index 4615f49..3e76e39 100644 --- a/README.md +++ b/README.md @@ -78,4 +78,6 @@ sequenceDiagram - [x] 群踢人功能 # 特别鸣谢 - 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 \ No newline at end of file + 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 + + 以及[@墨梓柒](https://github.com/DrSmoothl)对部分代码想法的支持 \ No newline at end of file diff --git a/command_args.md b/command_args.md index 8dff207..01390a7 100644 --- a/command_args.md +++ b/command_args.md @@ -13,6 +13,8 @@ Seg.data: Dict[str, Any] = { } ``` 其中,群聊ID将会通过Group_Info.group_id自动获取。 + +**当`duration`为 0 时相当于解除禁言。** ## 群聊全体禁言 ```python Seg.data: Dict[str, Any] = { diff --git a/data/NapcatAdapter.db b/data/NapcatAdapter.db new file mode 100644 index 0000000000000000000000000000000000000000..53f80c298d672ee9b080a67a4994e31a45414f0b GIT binary patch literal 20480 zcmeI#%TB^T6oBC=H>C-e9gA)z#sws3LU;jD8$-mRh>0X74TNGExhS*-m%8=Icq5m# zWQtc_p!p|hJJXq-&iUHSxxQ+-o+C%I`K0g3x+n@w6Gu`CA(Fw74vwgW<5V;VuG+W$ zwr)}!KELM*A0m~15QY81D!&RkGz1Vp009ILKmY**5I_KdI1AjRGNrnqy|~k%vvimC zpg;8&&fLDA&-_Q*9jbBqq+>R^rfgcL=B@l^ooriDt(E2I;Yu%=Db)j@^-Gd+x-EX2T~gJI#wmrzg+No-C`-RT;&p=#_&+rqnPrvCe<-G!CkI zyYG9m^>|}lQ@ajp`Q7km%Y~<6c%mVI00IagfB*srAb bool: + """ + 检查两个 BanUser 对象是否相同。 + """ + return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id + + +class DatabaseManager: + """ + 数据库管理类,负责与数据库交互。 + """ + + def __init__(self): + DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") + self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL + self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 + self._ensure_database() # 确保数据库和表已创建 + + def _ensure_database(self) -> None: + """ + 确保数据库和表已创建。 + """ + logger.info("确保数据库文件和表已创建...") + SQLModel.metadata.create_all(self.engine) + logger.success("数据库和表已创建或已存在") + + def update_ban_record(self, ban_list: List[BanUser]) -> None: + """ + 更新禁言列表到数据库。 + 支持在不存在时创建新记录,对于多余的项目自动删除。 + """ + with Session(self.engine) as session: + all_records = session.exec(select(BanUser)).all() + for ban_user in ban_list: + statement = select(BanUser).where( + BanUser.user_id == ban_user.user_id, BanUser.group_id == ban_user.group_id + ) + if existing_record := session.exec(statement).first(): + if existing_record.lift_time == ban_user.lift_time: + logger.debug(f"禁言记录未变更: {existing_record}") + continue + # 更新现有记录的 lift_time + existing_record.lift_time = ban_user.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {existing_record}") + else: + # 创建新记录 + session.add(ban_user) + logger.debug(f"创建新禁言记录: {ban_user}") + # 删除不在 ban_list 中的记录 + for record in all_records: + if not any(is_identical(record, ban_user) for ban_user in ban_list): + session.delete(record) + logger.debug(f"删除禁言记录: {record}") + + session.commit() + logger.info("禁言记录已更新") + + def get_ban_records(self) -> List[BanUser]: + """ + 读取所有禁言记录。 + """ + with Session(self.engine) as session: + statement = select(BanUser) + return session.exec(statement).all() + + def create_ban_record(self, ban_record: BanUser) -> None: + """ + 为特定群组中的用户创建禁言记录。 + 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 + 其同时还是简化版的更新方式。 + """ + with Session(self.engine) as session: + session.add(ban_record) + session.commit() + logger.debug(f"创建/更新禁言记录: {ban_record}") + + def delete_ban_record(self, ban_record: BanUser) -> bool: + """ + 删除特定用户在特定群组中的禁言记录。 + 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 + """ + user_id = ban_record.user_id + group_id = ban_record.group_id + with Session(self.engine) as session: + statement = select(BanUser).where(BanUser.user_id == user_id, BanUser.group_id == group_id) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + + +db_manager = DatabaseManager() diff --git a/src/recv_handler/__init__.py b/src/recv_handler/__init__.py new file mode 100644 index 0000000..40f0070 --- /dev/null +++ b/src/recv_handler/__init__.py @@ -0,0 +1,83 @@ +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + group_ban = "group_ban" # 群禁言 + + class Notify: + poke = "poke" # 戳一戳 + + class GroupBan: + ban = "ban" # 禁言 + lift_ban = "lift_ban" # 解除禁言 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + + def __str__(self) -> str: + return self.value diff --git a/src/recv_handler.py b/src/recv_handler/message_handler.py similarity index 72% rename from src/recv_handler.py rename to src/recv_handler/message_handler.py index 4d45052..608e486 100644 --- a/src/recv_handler.py +++ b/src/recv_handler/message_handler.py @@ -1,14 +1,22 @@ -from .logger import logger -from .config import global_config +from src.logger import logger +from src.config import global_config +from src.utils import ( + get_group_info, + get_member_info, + get_image_base64, + get_self_info, + get_message_detail, +) from .qq_emoji_list import qq_face +from .message_sending import message_send_instance +from . import RealMessageType, MessageType + import time -import asyncio import json import websockets as Server from typing import List, Tuple, Optional, Dict, Any import uuid -from . import MetaEventType, RealMessageType, MessageType, NoticeType from maim_message import ( UserInfo, GroupInfo, @@ -17,97 +25,54 @@ from maim_message import ( MessageBase, TemplateInfo, FormatInfo, - Router, ) -from .utils import ( - get_group_info, - get_member_info, - get_image_base64, - get_self_info, - get_stranger_info, - get_message_detail, - read_bot_id, - update_bot_id, -) -from .response_pool import get_response + +from src.response_pool import get_response -class RecvHandler: - maibot_router: Router = None - +class MessageHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - self.interval = global_config.napcat_server.heartbeat_interval - self._interval_checking = False - self.bot_id_list: Dict[int, bool] = {} + self.bot_id_list: Dict[str, bool] = {} - async def handle_meta_event(self, message: dict) -> None: - event_type = message.get("meta_event_type") - if event_type == MetaEventType.lifecycle: - sub_type = message.get("sub_type") - if sub_type == MetaEventType.Lifecycle.connect: - self_id = message.get("self_id") - self.last_heart_beat = time.time() - logger.info(f"Bot {self_id} 连接成功") - asyncio.create_task(self.check_heartbeat(self_id)) - elif event_type == MetaEventType.heartbeat: - if message["status"].get("online") and message["status"].get("good"): - if not self._interval_checking: - asyncio.create_task(self.check_heartbeat()) - self.last_heart_beat = time.time() - self.interval = message.get("interval") / 1000 - else: - self_id = message.get("self_id") - logger.warning(f"Bot {self_id} Napcat 端异常!") + def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection - async def check_heartbeat(self, id: int) -> None: - self._interval_checking = True - while True: - now_time = time.time() - if now_time - self.last_heart_beat > self.interval * 2: - logger.error(f"Bot {id} 连接已断开,被下线,或者Napcat卡死!") - break - else: - logger.debug("心跳正常") - await asyncio.sleep(self.interval) - - async def check_allow_to_chat(self, user_id: int, group_id: Optional[int]) -> bool: + async def check_allow_to_chat( + self, + user_id: int, + group_id: Optional[int], + ignore_bot: Optional[bool] = False, + ignore_global_list: Optional[bool] = False, + ) -> bool: # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif """ 检查是否允许聊天 Parameters: user_id: int: 用户ID group_id: int: 群ID + ignore_bot: bool: 是否忽略机器人检查 + ignore_global_list: bool: 是否忽略全局黑名单检查 Returns: bool: 是否允许聊天 """ - user_id = str(user_id) logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") - if global_config.chat.ban_qq_bot and group_id: + if global_config.chat.ban_qq_bot and group_id and not ignore_bot: logger.debug("开始判断是否为机器人") - if not self.bot_id_list: - self.bot_id_list = read_bot_id() - if user_id in self.bot_id_list: - if self.bot_id_list[user_id]: - logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃") - return False - else: - member_info = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - is_bot = member_info.get("is_robot") - if is_bot is None: - logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + member_info = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + is_bot = member_info.get("is_robot") + if is_bot is None: + logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + else: + if is_bot: + logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") + self.bot_id_list[user_id] = True + return False else: - if is_bot: - logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") - self.bot_id_list[user_id] = True - update_bot_id(self.bot_id_list) - return False - else: - self.bot_id_list[user_id] = False - update_bot_id(self.bot_id_list) - user_id = int(user_id) + self.bot_id_list[user_id] = False logger.debug("开始检查聊天白名单/黑名单") if group_id: if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list: @@ -123,7 +88,7 @@ class RecvHandler: elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list: logger.warning("私聊在聊天黑名单中,消息被丢弃") return False - if user_id in global_config.chat.ban_user_id: + if user_id in global_config.chat.ban_user_id and not ignore_global_list: logger.warning("用户在全局黑名单中,消息被丢弃") return False return True @@ -275,7 +240,7 @@ class RecvHandler: ) logger.info("发送到Maibot处理信息") - await self.message_process(message_base) + await message_send_instance.message_send(message_base) async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None: # sourcery skip: low-code-quality @@ -343,7 +308,7 @@ class RecvHandler: case RealMessageType.share: logger.warning("暂时不支持链接解析") case RealMessageType.forward: - messages = await self.get_forward_message(sub_message) + messages = await self._get_forward_message(sub_message) if not messages: logger.warning("转发消息内容为空或获取失败") return None @@ -440,40 +405,6 @@ class RecvHandler: else: return None - async def get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: - forward_message_data: Dict = raw_message.get("data") - if not forward_message_data: - logger.warning("转发消息内容为空") - return None - forward_message_id = forward_message_data.get("id") - request_uuid = str(uuid.uuid4()) - payload = json.dumps( - { - "action": "get_forward_msg", - "params": {"message_id": forward_message_id}, - "echo": request_uuid, - } - ) - try: - await self.server_connection.send(payload) - response: dict = await get_response(request_uuid) - except TimeoutError: - logger.error("获取转发消息超时") - return None - except Exception as e: - logger.error(f"获取转发消息失败: {str(e)}") - return None - logger.debug( - f"转发消息原始格式:{json.dumps(response)[:80]}..." - if len(json.dumps(response)) > 80 - else json.dumps(response) - ) - response_data: Dict = response.get("data") - if not response_data: - logger.warning("转发消息内容为空或获取失败") - return None - return response_data.get("messages") - async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None: # sourcery skip: move-assign-in-block, use-named-expression """ @@ -506,142 +437,6 @@ class RecvHandler: seg_message.append(Seg(type="text", data="],说:")) return seg_message - async def handle_notice(self, raw_message: dict) -> None: - notice_type = raw_message.get("notice_type") - # message_time: int = raw_message.get("time") - message_time: float = time.time() # 应可乐要求,现在是float了 - - group_id = raw_message.get("group_id") - user_id = raw_message.get("user_id") - - if not await self.check_allow_to_chat(user_id, group_id): - logger.warning("notice消息被丢弃") - return None - - handled_message: Seg = None - - match notice_type: - case NoticeType.friend_recall: - logger.info("好友撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.group_recall: - logger.info("群内用户撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.notify: - sub_type = raw_message.get("sub_type") - match sub_type: - case NoticeType.Notify.poke: - if global_config.chat.enable_poke: - handled_message: Seg = await self.handle_poke_notify(raw_message) - else: - logger.warning("戳一戳消息被禁用,取消戳一戳处理") - case _: - logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case _: - logger.warning(f"不支持的notice类型: {notice_type}") - return None - if not handled_message: - logger.warning("notice处理失败或不支持") - return None - - source_name: str = None - source_cardname: str = None - if group_id: - member_info: dict = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - source_name = member_info.get("nickname") - source_cardname = member_info.get("card") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - else: - stranger_info = await get_stranger_info(self.server_connection, user_id) - if stranger_info: - source_name = stranger_info.get("nickname") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - - user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, - user_id=user_id, - user_nickname=source_name, - user_cardname=source_cardname, - ) - - group_info: GroupInfo = None - if group_id: - fetched_group_info = await get_group_info(self.server_connection, group_id) - group_name: str = None - if fetched_group_info: - group_name = fetched_group_info.get("group_name") - else: - logger.warning("无法获取戳一戳消息所在群的名称") - group_info = GroupInfo( - platform=global_config.maibot_server.platform_name, - group_id=group_id, - group_name=group_name, - ) - - message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.maibot_server.platform_name, - message_id="notice", - time=message_time, - user_info=user_info, - group_info=group_info, - template_info=None, - format_info=None, - ) - - message_base: MessageBase = MessageBase( - message_info=message_info, - message_segment=handled_message, - raw_message=json.dumps(raw_message), - ) - - logger.info("发送到Maibot处理通知信息") - await self.message_process(message_base) - - async def handle_poke_notify(self, raw_message: dict) -> Seg | None: - self_info: dict = await get_self_info(self.server_connection) - if not self_info: - logger.error("自身信息获取失败") - return None - self_id = raw_message.get("self_id") - target_id = raw_message.get("target_id") - target_name: str = None - raw_info: list = raw_message.get("raw_info") - # 计算Seg - if self_id == target_id: - target_name = self_info.get("nickname") - else: - return None - try: - first_txt = raw_info[2].get("txt", "戳了戳") - second_txt = raw_info[4].get("txt", "") - except Exception as e: - logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") - first_txt = "戳了戳" - second_txt = "" - """ - # 不启用戳其他人的处理 - else: - # 由于Napcat不支持获取昵称,所以需要单独获取 - group_id = raw_message.get("group_id") - fetched_member_info: dict = await get_member_info( - self.server_connection, group_id, target_id - ) - if fetched_member_info: - target_name = fetched_member_info.get("nickname") - """ - seg_data: Seg = Seg( - type="text", - data=f"{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", - ) - return seg_data - async def handle_forward_message(self, message_list: list) -> Seg | None: """ 递归处理转发消息,并按照动态方式确定图片处理方式 @@ -800,15 +595,39 @@ class RecvHandler: seg_list.append(full_seg_data) return Seg(type="seglist", data=seg_list), image_count - async def message_process(self, message_base: MessageBase) -> None: - try: - send_status = await self.maibot_router.send_message(message_base) - if not send_status: - raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") - except Exception as e: - logger.error(f"发送消息失败: {str(e)}") - logger.error("请检查与MaiBot之间的连接") + async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: + forward_message_data: Dict = raw_message.get("data") + if not forward_message_data: + logger.warning("转发消息内容为空") return None + forward_message_id = forward_message_data.get("id") + request_uuid = str(uuid.uuid4()) + payload = json.dumps( + { + "action": "get_forward_msg", + "params": {"message_id": forward_message_id}, + "echo": request_uuid, + } + ) + try: + await self.server_connection.send(payload) + response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error("获取转发消息超时") + return None + except Exception as e: + logger.error(f"获取转发消息失败: {str(e)}") + return None + logger.debug( + f"转发消息原始格式:{json.dumps(response)[:80]}..." + if len(json.dumps(response)) > 80 + else json.dumps(response) + ) + response_data: Dict = response.get("data") + if not response_data: + logger.warning("转发消息内容为空或获取失败") + return None + return response_data.get("messages") -recv_handler = RecvHandler() +message_handler = MessageHandler() diff --git a/src/recv_handler/message_sending.py b/src/recv_handler/message_sending.py new file mode 100644 index 0000000..de35399 --- /dev/null +++ b/src/recv_handler/message_sending.py @@ -0,0 +1,31 @@ +from src.logger import logger +from maim_message import MessageBase, Router + + +class MessageSending: + """ + 负责把消息发送到麦麦 + """ + + maibot_router: Router = None + + def __init__(self): + pass + + async def message_send(self, message_base: MessageBase) -> bool: + """ + 发送消息 + Parameters: + message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息 + """ + try: + send_status = await self.maibot_router.send_message(message_base) + if not send_status: + raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") + except Exception as e: + logger.error(f"发送消息失败: {str(e)}") + logger.error("请检查与MaiBot之间的连接") + return send_status + + +message_send_instance = MessageSending() diff --git a/src/recv_handler/meta_event_handler.py b/src/recv_handler/meta_event_handler.py new file mode 100644 index 0000000..bb9efe6 --- /dev/null +++ b/src/recv_handler/meta_event_handler.py @@ -0,0 +1,49 @@ +from src.logger import logger +from src.config import global_config +import time +import asyncio + +from . import MetaEventType + + +class MetaEventHandler: + """ + 处理Meta事件 + """ + + def __init__(self): + self.interval = global_config.napcat_server.heartbeat_interval + self._interval_checking = False + + async def handle_meta_event(self, message: dict) -> None: + event_type = message.get("meta_event_type") + if event_type == MetaEventType.lifecycle: + sub_type = message.get("sub_type") + if sub_type == MetaEventType.Lifecycle.connect: + self_id = message.get("self_id") + self.last_heart_beat = time.time() + logger.info(f"Bot {self_id} 连接成功") + asyncio.create_task(self.check_heartbeat(self_id)) + elif event_type == MetaEventType.heartbeat: + if message["status"].get("online") and message["status"].get("good"): + if not self._interval_checking: + asyncio.create_task(self.check_heartbeat()) + self.last_heart_beat = time.time() + self.interval = message.get("interval") / 1000 + else: + self_id = message.get("self_id") + logger.warning(f"Bot {self_id} Napcat 端异常!") + + async def check_heartbeat(self, id: int) -> None: + self._interval_checking = True + while True: + now_time = time.time() + if now_time - self.last_heart_beat > self.interval * 2: + logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!") + break + else: + logger.debug("心跳正常") + await asyncio.sleep(self.interval) + + +meta_event_handler = MetaEventHandler() diff --git a/src/recv_handler/notice_handler.py b/src/recv_handler/notice_handler.py new file mode 100644 index 0000000..e4bd468 --- /dev/null +++ b/src/recv_handler/notice_handler.py @@ -0,0 +1,493 @@ +import time +import json +import asyncio +import websockets as Server +from typing import Tuple, Optional + +from src.logger import logger +from src.config import global_config +from src.database import BanUser, db_manager, is_identical +from . import NoticeType +from .message_sending import message_send_instance +from .message_handler import message_handler +from maim_message import UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase + +from src.utils import ( + get_group_info, + get_member_info, + get_self_info, + get_stranger_info, + read_ban_list, +) + +notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100) +unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3) + + +class NoticeHandler: + banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表 + lifted_list: list[BanUser] = [] # 已经自然解除禁言 + + def __init__(self): + self.server_connection: Server.ServerConnection = None + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) + + asyncio.create_task(self.auto_lift_detect()) + asyncio.create_task(self.send_notice()) + asyncio.create_task(self.handle_natural_lift()) + + def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + """ + 将用户禁言记录添加到self.banned_list中 + 如果是全体禁言,则user_id为0 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + lift_time = -1 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) + for record in self.banned_list: + if is_identical(record, ban_record): + self.banned_list.remove(record) + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 作为更新 + return + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 添加到数据库 + + def _lift_operation(self, group_id: int, user_id: Optional[int]) -> None: + """ + 从self.lifted_group_list中移除已经解除全体禁言的群 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) + self.lifted_list.append(ban_record) + db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + + async def handle_notice(self, raw_message: dict) -> None: + notice_type = raw_message.get("notice_type") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + group_id = raw_message.get("group_id") + user_id = raw_message.get("user_id") + + # if not await self.check_allow_to_chat(user_id, group_id): + # logger.warning("notice消息被丢弃") + # return None + + handled_message: Seg = None + user_info: UserInfo = None + + match notice_type: + case NoticeType.friend_recall: + logger.info("好友撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.group_recall: + logger.info("群内用户撤回一条消息") + logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") + logger.warning("暂时不支持撤回消息处理") + case NoticeType.notify: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.Notify.poke: + if global_config.chat.enable_poke and await message_handler.check_allow_to_chat( + user_id, group_id, False, False + ): + logger.info("处理戳一戳消息") + handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) + else: + logger.warning("戳一戳消息被禁用,取消戳一戳处理") + case _: + logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") + case NoticeType.group_ban: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.GroupBan.ban: + if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群禁言") + handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + case NoticeType.GroupBan.lift_ban: + if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理解除群禁言") + handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + case _: + logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") + case _: + logger.warning(f"不支持的notice类型: {notice_type}") + return None + if not handled_message or not user_info: + logger.warning("notice处理失败或不支持") + return None + + group_info: GroupInfo = None + if group_id: + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=message_time, + user_info=user_info, + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=handled_message, + raw_message=json.dumps(raw_message), + ) + + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) + + async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]: + self_info: dict = await get_self_info(self.server_connection) + if not self_info: + logger.error("自身信息获取失败") + return None + self_id = raw_message.get("self_id") + target_id = raw_message.get("target_id") + target_name: str = None + raw_info: list = raw_message.get("raw_info") + + # 计算user_info + source_name: str = None + source_cardname: str = None + if group_id: + member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + source_name = member_info.get("nickname") + source_cardname = member_info.get("card") + else: + logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") + source_name = "QQ用户" + else: + stranger_info = await get_stranger_info(self.server_connection, user_id) + if stranger_info: + source_name = stranger_info.get("nickname") + else: + logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") + source_name = "QQ用户" + + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=source_name, + user_cardname=source_cardname, + ) + + # 计算Seg + if self_id == target_id: + target_name = self_info.get("nickname") + else: + return None + try: + first_txt = raw_info[2].get("txt", "戳了戳") + second_txt = raw_info[4].get("txt", "") + except Exception as e: + logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") + first_txt = "戳了戳" + second_txt = "" + """ + # 不启用戳其他人的处理 + else: + # 由于Napcat不支持获取昵称,所以需要单独获取 + group_id = raw_message.get("group_id") + fetched_member_info: dict = await get_member_info( + self.server_connection, group_id, target_id + ) + if fetched_member_info: + target_name = fetched_member_info.get("nickname") + """ + seg_data: Seg = Seg( + type="text", + data=f"{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", + ) + return seg_data, user_info + + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + user_id = raw_message.get("user_id") + banned_user_info: UserInfo = None + user_nickname: str = "QQ用户" + user_cardname: str = None + sub_type: str = None + + duration = raw_message.get("duration") + if duration is None: + logger.error("禁言时长不能为空,无法处理禁言通知") + return None, None + + if user_id == 0: # 为全体禁言 + sub_type: str = "whole_ban" + self._ban_operation(group_id) + else: # 为单人禁言 + # 获取被禁言人的信息 + sub_type: str = "ban" + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + banned_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._ban_operation(group_id, user_id, int(time.time() + duration)) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "duration": duration, + "banned_user_info": banned_user_info, + }, + ) + + return seg_data, operator_info + + async def handle_lift_ban_notify( + self, raw_message: dict, group_id: int + ) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + sub_type: str = None + user_nickname: str = "QQ用户" + user_cardname: str = None + lifted_user_info: UserInfo = None + + user_id = raw_message.get("user_id") + if user_id == 0: # 全体禁言解除 + sub_type = "whole_lift_ban" + self._lift_operation(group_id) + else: # 单人禁言解除 + sub_type = "lift_ban" + # 获取被解除禁言人的信息 + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + else: + logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效") + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._lift_operation(group_id, user_id) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "lifted_user_info": lifted_user_info, + }, + ) + return seg_data, operator_info + + async def handle_natural_lift(self) -> None: + while True: + if len(self.lifted_list) != 0: + lift_record = self.lifted_list.pop() + group_id = lift_record.group_id + user_id = lift_record.user_id + + db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + + seg_message: Seg = await self.natural_lift(group_id, user_id) + + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=time.time(), + user_info=None, # 自然解除禁言没有操作者 + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=seg_message, + raw_message=json.dumps( + { + "post_type": "notice", + "notice_type": "group_ban", + "sub_type": "lift_ban", + "group_id": group_id, + "user_id": user_id, + "operator_id": None, # 自然解除禁言没有操作者 + } + ), + ) + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + + await asyncio.sleep(0.5) # 确保队列处理间隔 + else: + await asyncio.sleep(5) # 每5秒检查一次 + + async def natural_lift(self, group_id: int, user_id: int) -> Seg | None: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None + + if user_id == 0: # 理论上永远不会触发 + return Seg( + type="notify", + data={ + "sub_type": "whole_lift_ban", + "lifted_user_info": None, + }, + ) + + user_nickname: str = "QQ用户" + user_cardname: str = None + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + + return Seg( + type="notify", + data={ + "sub_type": "lift_ban", + "lifted_user_info": lifted_user_info, + }, + ) + + async def auto_lift_detect(self) -> None: + while True: + for ban_record in self.banned_list: + if ban_record.user_id == 0 or ban_record.lift_time == -1: + continue + if ban_record.lift_time <= int(time.time()): + # 触发自然解除禁言 + logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") + self.lifted_list.append(ban_record) + self.banned_list.remove(ban_record) + asyncio.sleep(5) + + async def send_notice(self) -> None: + """ + 发送通知消息到Napcat + """ + while True: + if not unsuccessful_notice_queue.empty(): + to_be_send: MessageBase = await unsuccessful_notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + unsuccessful_notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + asyncio.sleep(0.2) + continue + to_be_send: MessageBase = await notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + + +notice_handler = NoticeHandler() diff --git a/src/qq_emoji_list.py b/src/recv_handler/qq_emoji_list.py similarity index 100% rename from src/qq_emoji_list.py rename to src/recv_handler/qq_emoji_list.py diff --git a/src/send_handler.py b/src/send_handler.py index 2a6709e..c375679 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -21,6 +21,10 @@ class SendHandler: def __init__(self): self.server_connection: Server.ServerConnection = None + def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + async def handle_message(self, raw_message_base_dict: dict) -> None: raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) message_segment: Seg = raw_message_base.message_segment @@ -254,8 +258,8 @@ class SendHandler: duration: int = int(args["duration"]) user_id: int = int(args["qq_id"]) group_id: int = int(group_info.group_id) - if duration <= 0: - raise ValueError("封禁时间必须大于0") + if duration < 0: + raise ValueError("封禁时间必须大于等于0") if not user_id or not group_id: raise ValueError("封禁命令缺少必要参数") if duration > 2592000: diff --git a/src/utils.py b/src/utils.py index 9e9941c..c23ee9f 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,15 +2,16 @@ import websockets as Server import json import base64 import uuid +import urllib3 +import ssl +import io + +from src.database import BanUser, db_manager from .logger import logger from .response_pool import get_response -import urllib3 -import ssl -from pathlib import Path from PIL import Image -import io -import os +from typing import Union, List, Tuple class SSLAdapter(urllib3.PoolManager): @@ -44,6 +45,28 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d return socket_response.get("data") +async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict: + """ + 获取群详细信息 + + 返回值需要处理可能为空的情况 + """ + logger.debug("获取群详细信息中") + request_uuid = str(uuid.uuid4()) + payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}) + try: + await websocket.send(payload) + socket_response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error(f"获取群详细信息超时,群号: {group_id}") + return None + except Exception as e: + logger.error(f"获取群详细信息失败: {e}") + return None + logger.debug(socket_response) + return socket_response.get("data") + + async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict: """ 获取群成员信息 @@ -171,7 +194,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> return response.get("data") -async def get_message_detail(websocket: Server.ServerConnection, message_id: str) -> dict: +async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict: """ 获取消息详情,可能为空 Parameters: @@ -196,41 +219,58 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: str return response.get("data") -def update_bot_id(data: dict) -> None: +async def read_ban_list( + websocket: Server.ServerConnection, +) -> Tuple[List[BanUser], List[BanUser]]: """ - 更新用户是否为机器人的字典到根目录下的data文件夹中的qq_bot.json。 - Parameters: - data: dict: 包含需要更新的信息。 - """ - json_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "qq_bot.json") - try: - with open(json_path, "w", encoding="utf-8") as json_file: - json.dump(data, json_file, ensure_ascii=False, indent=4) - logger.info(f"ID字典已更新到文件: {json_path}") - except Exception as e: - logger.error(f"更新ID字典失败: {e}") - - -def read_bot_id() -> dict: - """ - 从根目录下的data文件夹中的文件读取机器人ID。 + 从根目录下的data文件夹中的文件读取禁言列表。 + 同时自动更新已经失效禁言 Returns: - list: 读取的机器人ID信息。 + Tuple[ + 一个仍在禁言中的用户的BanUser列表, + 一个已经自然解除禁言的用户的BanUser列表, + 一个仍在全体禁言中的群的BanUser列表, + 一个已经自然解除全体禁言的群的BanUser列表, + ] """ - json_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "qq_bot.json") try: - with open(json_path, "r", encoding="utf-8") as json_file: - data = json.load(json_file) - logger.info(f"已读取机器人ID信息: {data}") - return data - except FileNotFoundError: - logger.warning(f"文件未找到: {json_path},正在自动创建文件") - json_path = Path(os.path.dirname(os.path.dirname(__file__))) / "data" / "qq_bot.json" - # 确保父目录存在 - json_path.parent.mkdir(parents=True, exist_ok=True) - # 创建空文件 - json_path.touch(exist_ok=True) - return {} + ban_list = db_manager.get_ban_records() + lifted_list: List[BanUser] = [] + logger.info("已经读取禁言列表") + for ban_record in ban_list: + if ban_record.user_id == 0: + fetched_group_info = await get_group_info(websocket, ban_record.group_id) + if fetched_group_info is None: + logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除") + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + group_all_shut: int = fetched_group_info.get("group_all_shut") + if group_all_shut == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + else: + fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id) + if fetched_member_info is None: + logger.warning( + f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除" + ) + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + lift_ban_time: int = fetched_member_info.get("shut_up_timestamp") + if lift_ban_time == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + else: + ban_record.lift_time = lift_ban_time + db_manager.update_ban_record(ban_list) + return ban_list, lifted_list except Exception as e: - logger.error(f"读取机器人ID失败: {e}") - return {} + logger.error(f"读取禁言列表失败: {e}") + return [], [] + + +def save_ban_record(list: List[BanUser]): + return db_manager.update_ban_record(list) From ee873d8cbb50694b91e1c76a83ed2074bbff3c9d Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 28 Jun 2025 01:57:46 +0800 Subject: [PATCH 2/5] =?UTF-8?q?fix=20=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/recv_handler/message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/recv_handler/message_handler.py b/src/recv_handler/message_handler.py index 608e486..1f65cf6 100644 --- a/src/recv_handler/message_handler.py +++ b/src/recv_handler/message_handler.py @@ -34,7 +34,7 @@ from src.response_pool import get_response class MessageHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - self.bot_id_list: Dict[str, bool] = {} + self.bot_id_list: Dict[int, bool] = {} def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" From 11969095219a87c72174ee64a97dc1207f90b479 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 28 Jun 2025 10:18:24 +0800 Subject: [PATCH 3/5] remove database and update gitignore --- .gitignore | 3 +-- data/NapcatAdapter.db | Bin 20480 -> 0 bytes 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 100644 data/NapcatAdapter.db diff --git a/.gitignore b/.gitignore index 60f4dc6..6c9e2d2 100644 --- a/.gitignore +++ b/.gitignore @@ -272,5 +272,4 @@ $RECYCLE.BIN/ config.toml config.toml.back test -data/qq_bot.json -data/ban_list.json \ No newline at end of file +data/NapcatAdapter.db \ No newline at end of file diff --git a/data/NapcatAdapter.db b/data/NapcatAdapter.db deleted file mode 100644 index 53f80c298d672ee9b080a67a4994e31a45414f0b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeI#%TB^T6oBC=H>C-e9gA)z#sws3LU;jD8$-mRh>0X74TNGExhS*-m%8=Icq5m# zWQtc_p!p|hJJXq-&iUHSxxQ+-o+C%I`K0g3x+n@w6Gu`CA(Fw74vwgW<5V;VuG+W$ zwr)}!KELM*A0m~15QY81D!&RkGz1Vp009ILKmY**5I_KdI1AjRGNrnqy|~k%vvimC zpg;8&&fLDA&-_Q*9jbBqq+>R^rfgcL=B@l^ooriDt(E2I;Yu%=Db)j@^-Gd+x-EX2T~gJI#wmrzg+No-C`-RT;&p=#_&+rqnPrvCe<-G!CkI zyYG9m^>|}lQ@ajp`Q7km%Y~<6c%mVI00IagfB*srAb Date: Sat, 28 Jun 2025 11:44:35 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AEbug=EF=BC=8C=E6=94=B9=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 ++- main.py | 6 ++-- pyproject.toml | 2 +- src/database.py | 54 ++++++++++++++++++++++------- src/recv_handler/message_handler.py | 4 +-- src/recv_handler/message_sending.py | 2 +- src/recv_handler/notice_handler.py | 43 +++++++++++++++-------- src/send_handler.py | 2 +- 8 files changed, 80 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 6c9e2d2..f0977e5 100644 --- a/.gitignore +++ b/.gitignore @@ -272,4 +272,6 @@ $RECYCLE.BIN/ config.toml config.toml.back test -data/NapcatAdapter.db \ No newline at end of file +data/NapcatAdapter.db +data/NapcatAdapter.db-shm +data/NapcatAdapter.db-wal \ No newline at end of file diff --git a/main.py b/main.py index 12657af..a928191 100644 --- a/main.py +++ b/main.py @@ -16,9 +16,9 @@ message_queue = asyncio.Queue() async def message_recv(server_connection: Server.ServerConnection): - message_handler.set_server_connection(server_connection) - notice_handler.set_server_connection(server_connection) - send_handler.set_server_connection(server_connection) + await message_handler.set_server_connection(server_connection) + asyncio.create_task(notice_handler.set_server_connection(server_connection)) + await send_handler.set_server_connection(server_connection) async for raw_message in server_connection: logger.debug( f"{raw_message[:100]}..." diff --git a/pyproject.toml b/pyproject.toml index 0fedfb2..2f6423d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "MaiBotNapcatAdapter" -version = "0.2.6" +version = "0.3.0" description = "A MaiBot adapter for Napcat" [tool.ruff] diff --git a/src/database.py b/src/database.py index 2f58b50..45100cb 100644 --- a/src/database.py +++ b/src/database.py @@ -1,5 +1,6 @@ import os from typing import Optional, List +from dataclasses import dataclass from sqlmodel import Field, Session, SQLModel, create_engine, select from src.logger import logger @@ -13,7 +14,18 @@ from src.logger import logger """ -class BanUser(SQLModel, table=True): +@dataclass +class BanUser: + """ + 程序处理使用的实例 + """ + + user_id: int + group_id: int + lift_time: Optional[int] = Field(default=-1) + + +class DB_BanUser(SQLModel, table=True): """ 表示数据库中的用户禁言记录。 使用双重主键 @@ -24,7 +36,7 @@ class BanUser(SQLModel, table=True): lift_time: Optional[int] # 禁言解除的时间(时间戳) -def is_identical(self, obj1: BanUser, obj2: BanUser) -> bool: +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: """ 检查两个 BanUser 对象是否相同。 """ @@ -51,15 +63,16 @@ class DatabaseManager: logger.success("数据库和表已创建或已存在") def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method """ 更新禁言列表到数据库。 支持在不存在时创建新记录,对于多余的项目自动删除。 """ with Session(self.engine) as session: - all_records = session.exec(select(BanUser)).all() + all_records = session.exec(select(DB_BanUser)).all() for ban_user in ban_list: - statement = select(BanUser).where( - BanUser.user_id == ban_user.user_id, BanUser.group_id == ban_user.group_id + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id ) if existing_record := session.exec(statement).first(): if existing_record.lift_time == ban_user.lift_time: @@ -71,13 +84,24 @@ class DatabaseManager: logger.debug(f"更新禁言记录: {existing_record}") else: # 创建新记录 - session.add(ban_user) + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + ) + session.add(db_record) logger.debug(f"创建新禁言记录: {ban_user}") # 删除不在 ban_list 中的记录 - for record in all_records: + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) if not any(is_identical(record, ban_user) for ban_user in ban_list): - session.delete(record) - logger.debug(f"删除禁言记录: {record}") + statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id + ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") session.commit() logger.info("禁言记录已更新") @@ -87,8 +111,9 @@ class DatabaseManager: 读取所有禁言记录。 """ with Session(self.engine) as session: - statement = select(BanUser) - return session.exec(statement).all() + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] def create_ban_record(self, ban_record: BanUser) -> None: """ @@ -97,7 +122,10 @@ class DatabaseManager: 其同时还是简化版的更新方式。 """ with Session(self.engine) as session: - session.add(ban_record) + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) session.commit() logger.debug(f"创建/更新禁言记录: {ban_record}") @@ -109,7 +137,7 @@ class DatabaseManager: user_id = ban_record.user_id group_id = ban_record.group_id with Session(self.engine) as session: - statement = select(BanUser).where(BanUser.user_id == user_id, BanUser.group_id == group_id) + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) if ban_record := session.exec(statement).first(): session.delete(ban_record) session.commit() diff --git a/src/recv_handler/message_handler.py b/src/recv_handler/message_handler.py index 1f65cf6..0cec56f 100644 --- a/src/recv_handler/message_handler.py +++ b/src/recv_handler/message_handler.py @@ -36,14 +36,14 @@ class MessageHandler: self.server_connection: Server.ServerConnection = None self.bot_id_list: Dict[int, bool] = {} - def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" self.server_connection = server_connection async def check_allow_to_chat( self, user_id: int, - group_id: Optional[int], + group_id: Optional[int] = None, ignore_bot: Optional[bool] = False, ignore_global_list: Optional[bool] = False, ) -> bool: diff --git a/src/recv_handler/message_sending.py b/src/recv_handler/message_sending.py index de35399..2e43bbd 100644 --- a/src/recv_handler/message_sending.py +++ b/src/recv_handler/message_sending.py @@ -21,7 +21,7 @@ class MessageSending: try: send_status = await self.maibot_router.send_message(message_base) if not send_status: - raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") + raise RuntimeError("可能是路由未正确配置或连接异常") except Exception as e: logger.error(f"发送消息失败: {str(e)}") logger.error("请检查与MaiBot之间的连接") diff --git a/src/recv_handler/notice_handler.py b/src/recv_handler/notice_handler.py index e4bd468..2d03a49 100644 --- a/src/recv_handler/notice_handler.py +++ b/src/recv_handler/notice_handler.py @@ -35,6 +35,8 @@ class NoticeHandler: """设置Napcat连接""" self.server_connection = server_connection + while self.server_connection.state != Server.State.OPEN: + await asyncio.sleep(0.5) self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) asyncio.create_task(self.auto_lift_detect()) @@ -59,7 +61,7 @@ class NoticeHandler: self.banned_list.append(ban_record) db_manager.create_ban_record(ban_record) # 添加到数据库 - def _lift_operation(self, group_id: int, user_id: Optional[int]) -> None: + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -77,12 +79,9 @@ class NoticeHandler: group_id = raw_message.get("group_id") user_id = raw_message.get("user_id") - # if not await self.check_allow_to_chat(user_id, group_id): - # logger.warning("notice消息被丢弃") - # return None - handled_message: Seg = None user_info: UserInfo = None + system_notice: bool = False match notice_type: case NoticeType.friend_recall: @@ -110,15 +109,17 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.GroupBan.ban: - if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): return None logger.info("处理群禁言") handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + system_notice = True case NoticeType.GroupBan.lift_ban: - if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): return None logger.info("处理解除群禁言") handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + system_notice = True case _: logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") case _: @@ -158,8 +159,11 @@ class NoticeHandler: raw_message=json.dumps(raw_message), ) - logger.info("发送到Maibot处理通知信息") - await message_send_instance.message_send(message_base) + if system_notice: + await self.put_notice(message_base) + else: + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]: self_info: dict = await get_self_info(self.server_connection) @@ -355,6 +359,15 @@ class NoticeHandler: ) return seg_data, operator_info + async def put_notice(self, message_base: MessageBase) -> None: + """ + 将处理后的通知消息放入通知队列 + """ + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + async def handle_natural_lift(self) -> None: while True: if len(self.lifted_list) != 0: @@ -402,11 +415,8 @@ class NoticeHandler: } ), ) - if notice_queue.full() or unsuccessful_notice_queue.full(): - logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") - else: - await notice_queue.put(message_base) + await self.put_notice(message_base) await asyncio.sleep(0.5) # 确保队列处理间隔 else: await asyncio.sleep(5) # 每5秒检查一次 @@ -449,6 +459,9 @@ class NoticeHandler: async def auto_lift_detect(self) -> None: while True: + if len(self.banned_list) == 0: + await asyncio.sleep(5) + continue for ban_record in self.banned_list: if ban_record.user_id == 0 or ban_record.lift_time == -1: continue @@ -457,7 +470,7 @@ class NoticeHandler: logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") self.lifted_list.append(ban_record) self.banned_list.remove(ban_record) - asyncio.sleep(5) + await asyncio.sleep(5) async def send_notice(self) -> None: """ @@ -475,7 +488,7 @@ class NoticeHandler: except Exception as e: logger.error(f"发送通知消息失败: {str(e)}") await unsuccessful_notice_queue.put(to_be_send) - asyncio.sleep(0.2) + await asyncio.sleep(1) continue to_be_send: MessageBase = await notice_queue.get() try: diff --git a/src/send_handler.py b/src/send_handler.py index c375679..6cb3094 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -21,7 +21,7 @@ class SendHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" self.server_connection = server_connection From ed9ecae9dcfbcd6e22efd04e8aec23b353a0c12a Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 28 Jun 2025 12:23:50 +0800 Subject: [PATCH 5/5] =?UTF-8?q?maim=5Fmessage=20logger=E4=BC=A0=E5=85=A5?= =?UTF-8?q?=EF=BC=8C=E7=89=88=E6=9C=AC=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 1 - src/__init__.py | 9 +++++++++ src/mmc_com_layer.py | 2 +- src/utils.py | 12 ++++++------ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index c0ad7e0..54d8729 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,5 @@ requests maim_message loguru pillow -tomli tomlkit rich \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 4298de2..b1ac77e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,4 +1,7 @@ from enum import Enum +import tomlkit +import os +from .logger import logger class CommandType(Enum): @@ -11,3 +14,9 @@ class CommandType(Enum): def __str__(self) -> str: return self.value + + +pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml") +toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read()) +version = toml_data["project"]["version"] +logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n") diff --git a/src/mmc_com_layer.py b/src/mmc_com_layer.py index ab50cca..f7fd1ad 100644 --- a/src/mmc_com_layer.py +++ b/src/mmc_com_layer.py @@ -11,7 +11,7 @@ route_config = RouteConfig( ) } ) -router = Router(route_config) +router = Router(route_config, logger) async def mmc_start_com(): diff --git a/src/utils.py b/src/utils.py index c23ee9f..caa0b56 100644 --- a/src/utils.py +++ b/src/utils.py @@ -23,7 +23,7 @@ class SSLAdapter(urllib3.PoolManager): super().__init__(*args, **kwargs) -async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict: +async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: """ 获取群相关信息 @@ -45,7 +45,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d return socket_response.get("data") -async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict: +async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: """ 获取群详细信息 @@ -67,7 +67,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in return socket_response.get("data") -async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict: +async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None: """ 获取群成员信息 @@ -133,7 +133,7 @@ def convert_image_to_gif(image_base64: str) -> str: return image_base64 -async def get_self_info(websocket: Server.ServerConnection) -> dict: +async def get_self_info(websocket: Server.ServerConnection) -> dict | None: """ 获取自身信息 Parameters: @@ -169,7 +169,7 @@ def get_image_format(raw_data: str) -> str: return Image.open(io.BytesIO(image_bytes)).format.lower() -async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict: +async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None: """ 获取陌生人信息 Parameters: @@ -194,7 +194,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> return response.get("data") -async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict: +async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None: """ 获取消息详情,可能为空 Parameters: