diff --git a/main.py b/main.py index 07b3002..64d8c32 100644 --- a/main.py +++ b/main.py @@ -52,7 +52,7 @@ async def main(): async def napcat_server(): logger.info("正在启动adapter...") - async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port) as server: + async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26) as server: logger.info( f"Adapter已启动,监听地址: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" ) diff --git a/pyproject.toml b/pyproject.toml index 61d7c1d..42e56eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "MaiBotNapcatAdapter" -version = "0.4.4" +version = "0.4.7" description = "A MaiBot adapter for Napcat" [tool.ruff] diff --git a/src/database.py b/src/database.py index f9a94d1..af193da 100644 --- a/src/database.py +++ b/src/database.py @@ -123,14 +123,26 @@ class DatabaseManager: 其同时还是简化版的更新方式。 """ with Session(self.engine) as session: - db_record = DB_BanUser( - user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + # 检查记录是否已存在 + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id ) - session.add(db_record) + existing_record = session.exec(statement).first() + if existing_record: + # 如果记录已存在,更新 lift_time + existing_record.lift_time = ban_record.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {ban_record}") + else: + # 如果记录不存在,创建新记录 + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_record}") session.commit() - logger.debug(f"创建/更新禁言记录: {ban_record}") - def delete_ban_record(self, ban_record: BanUser) -> bool: + def delete_ban_record(self, ban_record: BanUser): """ 删除特定用户在特定群组中的禁言记录。 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 diff --git a/src/recv_handler/__init__.py b/src/recv_handler/__init__.py index 0deaede..422041b 100644 --- a/src/recv_handler/__init__.py +++ b/src/recv_handler/__init__.py @@ -84,4 +84,4 @@ class CommandType(Enum): return self.value -ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command"] +ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"] diff --git a/src/recv_handler/message_handler.py b/src/recv_handler/message_handler.py index aa327e7..82aeaf9 100644 --- a/src/recv_handler/message_handler.py +++ b/src/recv_handler/message_handler.py @@ -60,20 +60,6 @@ class MessageHandler: bool: 是否允许聊天 """ logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") - if global_config.chat.ban_qq_bot and group_id and not ignore_bot: - logger.debug("开始判断是否为机器人") - member_info = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - is_bot = member_info.get("is_robot") - if is_bot is None: - logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") - else: - if is_bot: - logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") - self.bot_id_list[user_id] = True - return False - else: - self.bot_id_list[user_id] = False logger.debug("开始检查聊天白名单/黑名单") if group_id: if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list: @@ -92,6 +78,22 @@ class MessageHandler: if user_id in global_config.chat.ban_user_id and not ignore_global_list: logger.warning("用户在全局黑名单中,消息被丢弃") return False + + if global_config.chat.ban_qq_bot and group_id and not ignore_bot: + logger.debug("开始判断是否为机器人") + member_info = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + is_bot = member_info.get("is_robot") + if is_bot is None: + logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + else: + if is_bot: + logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") + self.bot_id_list[user_id] = True + return False + else: + self.bot_id_list[user_id] = False + return True async def handle_raw_message(self, raw_message: dict) -> None: @@ -422,8 +424,14 @@ class MessageHandler: """ message_data: dict = raw_message.get("data") file: str = message_data.get("file") + if not file: + logger.warning("语音消息缺少文件信息") + return None try: record_detail = await get_record_detail(self.server_connection, file) + if not record_detail: + logger.warning("获取语音消息详情失败") + return None audio_base64: str = record_detail.get("base64") except Exception as e: logger.error(f"语音消息处理失败: {str(e)}") diff --git a/src/response_pool.py b/src/response_pool.py index c41ed7f..41feb9e 100644 --- a/src/response_pool.py +++ b/src/response_pool.py @@ -8,19 +8,19 @@ response_dict: Dict = {} response_time_dict: Dict = {} -async def get_response(request_id: str) -> dict: - retry_count = 0 - max_retries = 50 # 10秒超时 - while request_id not in response_dict: - retry_count += 1 - if retry_count >= max_retries: - raise TimeoutError(f"请求超时,未收到响应,request_id: {request_id}") - await asyncio.sleep(0.2) - response = response_dict.pop(request_id) +async def get_response(request_id: str, timeout: int = 10) -> dict: + response = await asyncio.wait_for(_get_response(request_id), timeout) _ = response_time_dict.pop(request_id) logger.trace(f"响应信息id: {request_id} 已从响应字典中取出") return response +async def _get_response(request_id: str) -> dict: + """ + 内部使用的获取响应函数,主要用于在需要时获取响应 + """ + while request_id not in response_dict: + await asyncio.sleep(0.2) + return response_dict.pop(request_id) async def put_response(response: dict): echo_id = response.get("echo") diff --git a/src/send_handler.py b/src/send_handler.py index 70cfa79..cf64a44 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -178,6 +178,9 @@ class SendHandler: elif seg.type == "videourl": video_url = seg.data new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False) + elif seg.type == "file": + file_path = seg.data + new_payload = self.build_payload(payload, self.handle_file_message(file_path), False) return new_payload def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list: @@ -261,6 +264,13 @@ class SendHandler: "data": {"file": video_url}, } + def handle_file_message(self, file_path: str) -> dict: + """处理文件消息""" + return { + "type": "file", + "data": {"file": f"file://{file_path}"}, + } + def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理封禁命令 diff --git a/src/utils.py b/src/utils.py index 6e07da4..78b0d0c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}) try: await websocket.send(payload) - response: dict = await get_response(request_uuid) + response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 except TimeoutError: logger.error(f"获取消息详情超时,消息ID: {message_id}") return None @@ -242,7 +242,7 @@ async def get_record_detail( ) try: await websocket.send(payload) - response: dict = await get_response(request_uuid) + response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 except TimeoutError: logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}") return None