From 5c57ba9c85b1456751071e8a92c961b2c466f2dc Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 28 Jun 2025 11:44:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AEbug=EF=BC=8C=E6=94=B9=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 ++- main.py | 6 ++-- pyproject.toml | 2 +- src/database.py | 54 ++++++++++++++++++++++------- src/recv_handler/message_handler.py | 4 +-- src/recv_handler/message_sending.py | 2 +- src/recv_handler/notice_handler.py | 43 +++++++++++++++-------- src/send_handler.py | 2 +- 8 files changed, 80 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 6c9e2d2..f0977e5 100644 --- a/.gitignore +++ b/.gitignore @@ -272,4 +272,6 @@ $RECYCLE.BIN/ config.toml config.toml.back test -data/NapcatAdapter.db \ No newline at end of file +data/NapcatAdapter.db +data/NapcatAdapter.db-shm +data/NapcatAdapter.db-wal \ No newline at end of file diff --git a/main.py b/main.py index 12657af..a928191 100644 --- a/main.py +++ b/main.py @@ -16,9 +16,9 @@ message_queue = asyncio.Queue() async def message_recv(server_connection: Server.ServerConnection): - message_handler.set_server_connection(server_connection) - notice_handler.set_server_connection(server_connection) - send_handler.set_server_connection(server_connection) + await message_handler.set_server_connection(server_connection) + asyncio.create_task(notice_handler.set_server_connection(server_connection)) + await send_handler.set_server_connection(server_connection) async for raw_message in server_connection: logger.debug( f"{raw_message[:100]}..." diff --git a/pyproject.toml b/pyproject.toml index 0fedfb2..2f6423d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "MaiBotNapcatAdapter" -version = "0.2.6" +version = "0.3.0" description = "A MaiBot adapter for Napcat" [tool.ruff] diff --git a/src/database.py b/src/database.py index 2f58b50..45100cb 100644 --- a/src/database.py +++ b/src/database.py @@ -1,5 +1,6 @@ import os from typing import Optional, List +from dataclasses import dataclass from sqlmodel import Field, Session, SQLModel, create_engine, select from src.logger import logger @@ -13,7 +14,18 @@ from src.logger import logger """ -class BanUser(SQLModel, table=True): +@dataclass +class BanUser: + """ + 程序处理使用的实例 + """ + + user_id: int + group_id: int + lift_time: Optional[int] = Field(default=-1) + + +class DB_BanUser(SQLModel, table=True): """ 表示数据库中的用户禁言记录。 使用双重主键 @@ -24,7 +36,7 @@ class BanUser(SQLModel, table=True): lift_time: Optional[int] # 禁言解除的时间(时间戳) -def is_identical(self, obj1: BanUser, obj2: BanUser) -> bool: +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: """ 检查两个 BanUser 对象是否相同。 """ @@ -51,15 +63,16 @@ class DatabaseManager: logger.success("数据库和表已创建或已存在") def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method """ 更新禁言列表到数据库。 支持在不存在时创建新记录,对于多余的项目自动删除。 """ with Session(self.engine) as session: - all_records = session.exec(select(BanUser)).all() + all_records = session.exec(select(DB_BanUser)).all() for ban_user in ban_list: - statement = select(BanUser).where( - BanUser.user_id == ban_user.user_id, BanUser.group_id == ban_user.group_id + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id ) if existing_record := session.exec(statement).first(): if existing_record.lift_time == ban_user.lift_time: @@ -71,13 +84,24 @@ class DatabaseManager: logger.debug(f"更新禁言记录: {existing_record}") else: # 创建新记录 - session.add(ban_user) + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + ) + session.add(db_record) logger.debug(f"创建新禁言记录: {ban_user}") # 删除不在 ban_list 中的记录 - for record in all_records: + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) if not any(is_identical(record, ban_user) for ban_user in ban_list): - session.delete(record) - logger.debug(f"删除禁言记录: {record}") + statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id + ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") session.commit() logger.info("禁言记录已更新") @@ -87,8 +111,9 @@ class DatabaseManager: 读取所有禁言记录。 """ with Session(self.engine) as session: - statement = select(BanUser) - return session.exec(statement).all() + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] def create_ban_record(self, ban_record: BanUser) -> None: """ @@ -97,7 +122,10 @@ class DatabaseManager: 其同时还是简化版的更新方式。 """ with Session(self.engine) as session: - session.add(ban_record) + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) session.commit() logger.debug(f"创建/更新禁言记录: {ban_record}") @@ -109,7 +137,7 @@ class DatabaseManager: user_id = ban_record.user_id group_id = ban_record.group_id with Session(self.engine) as session: - statement = select(BanUser).where(BanUser.user_id == user_id, BanUser.group_id == group_id) + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) if ban_record := session.exec(statement).first(): session.delete(ban_record) session.commit() diff --git a/src/recv_handler/message_handler.py b/src/recv_handler/message_handler.py index 1f65cf6..0cec56f 100644 --- a/src/recv_handler/message_handler.py +++ b/src/recv_handler/message_handler.py @@ -36,14 +36,14 @@ class MessageHandler: self.server_connection: Server.ServerConnection = None self.bot_id_list: Dict[int, bool] = {} - def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" self.server_connection = server_connection async def check_allow_to_chat( self, user_id: int, - group_id: Optional[int], + group_id: Optional[int] = None, ignore_bot: Optional[bool] = False, ignore_global_list: Optional[bool] = False, ) -> bool: diff --git a/src/recv_handler/message_sending.py b/src/recv_handler/message_sending.py index de35399..2e43bbd 100644 --- a/src/recv_handler/message_sending.py +++ b/src/recv_handler/message_sending.py @@ -21,7 +21,7 @@ class MessageSending: try: send_status = await self.maibot_router.send_message(message_base) if not send_status: - raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常") + raise RuntimeError("可能是路由未正确配置或连接异常") except Exception as e: logger.error(f"发送消息失败: {str(e)}") logger.error("请检查与MaiBot之间的连接") diff --git a/src/recv_handler/notice_handler.py b/src/recv_handler/notice_handler.py index e4bd468..2d03a49 100644 --- a/src/recv_handler/notice_handler.py +++ b/src/recv_handler/notice_handler.py @@ -35,6 +35,8 @@ class NoticeHandler: """设置Napcat连接""" self.server_connection = server_connection + while self.server_connection.state != Server.State.OPEN: + await asyncio.sleep(0.5) self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) asyncio.create_task(self.auto_lift_detect()) @@ -59,7 +61,7 @@ class NoticeHandler: self.banned_list.append(ban_record) db_manager.create_ban_record(ban_record) # 添加到数据库 - def _lift_operation(self, group_id: int, user_id: Optional[int]) -> None: + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -77,12 +79,9 @@ class NoticeHandler: group_id = raw_message.get("group_id") user_id = raw_message.get("user_id") - # if not await self.check_allow_to_chat(user_id, group_id): - # logger.warning("notice消息被丢弃") - # return None - handled_message: Seg = None user_info: UserInfo = None + system_notice: bool = False match notice_type: case NoticeType.friend_recall: @@ -110,15 +109,17 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.GroupBan.ban: - if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): return None logger.info("处理群禁言") handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + system_notice = True case NoticeType.GroupBan.lift_ban: - if await message_handler.check_allow_to_chat(user_id, group_id, True, False): + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): return None logger.info("处理解除群禁言") handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + system_notice = True case _: logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") case _: @@ -158,8 +159,11 @@ class NoticeHandler: raw_message=json.dumps(raw_message), ) - logger.info("发送到Maibot处理通知信息") - await message_send_instance.message_send(message_base) + if system_notice: + await self.put_notice(message_base) + else: + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]: self_info: dict = await get_self_info(self.server_connection) @@ -355,6 +359,15 @@ class NoticeHandler: ) return seg_data, operator_info + async def put_notice(self, message_base: MessageBase) -> None: + """ + 将处理后的通知消息放入通知队列 + """ + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + async def handle_natural_lift(self) -> None: while True: if len(self.lifted_list) != 0: @@ -402,11 +415,8 @@ class NoticeHandler: } ), ) - if notice_queue.full() or unsuccessful_notice_queue.full(): - logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") - else: - await notice_queue.put(message_base) + await self.put_notice(message_base) await asyncio.sleep(0.5) # 确保队列处理间隔 else: await asyncio.sleep(5) # 每5秒检查一次 @@ -449,6 +459,9 @@ class NoticeHandler: async def auto_lift_detect(self) -> None: while True: + if len(self.banned_list) == 0: + await asyncio.sleep(5) + continue for ban_record in self.banned_list: if ban_record.user_id == 0 or ban_record.lift_time == -1: continue @@ -457,7 +470,7 @@ class NoticeHandler: logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") self.lifted_list.append(ban_record) self.banned_list.remove(ban_record) - asyncio.sleep(5) + await asyncio.sleep(5) async def send_notice(self) -> None: """ @@ -475,7 +488,7 @@ class NoticeHandler: except Exception as e: logger.error(f"发送通知消息失败: {str(e)}") await unsuccessful_notice_queue.put(to_be_send) - asyncio.sleep(0.2) + await asyncio.sleep(1) continue to_be_send: MessageBase = await notice_queue.get() try: diff --git a/src/send_handler.py b/src/send_handler.py index c375679..6cb3094 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -21,7 +21,7 @@ class SendHandler: def __init__(self): self.server_connection: Server.ServerConnection = None - def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" self.server_connection = server_connection