diff --git a/src/utils.py b/src/utils.py index 78b0d0c..c4481e8 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,5 +1,6 @@ import websockets as Server import json +import asyncio import base64 import uuid import urllib3 @@ -95,6 +96,32 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use return socket_response.get("data") +async def get_group_member_list(websocket: Server.ServerConnection, group_id: int) -> list | None: + """ + 获取群成员列表 + + 返回值需要处理可能为空的情况 + """ + logger.debug(f"获取群成员列表中,群号: {group_id}") + request_uuid = str(uuid.uuid4()) + payload = json.dumps({ + "action": "get_group_member_list", + "params": {"group_id": group_id, "no_cache": False}, + "echo": request_uuid, + }) + try: + await websocket.send(payload) + socket_response: dict = await get_response(request_uuid, 30) + except TimeoutError: + logger.error(f"获取群成员列表超时,群号: {group_id}") + return None + except Exception as e: + logger.error(f"获取群成员列表失败: {e}") + return None + logger.debug(socket_response) + return socket_response.get("data") + + async def get_image_base64(url: str) -> str: # sourcery skip: raise-specific-error """获取图片/表情包的Base64""" @@ -271,34 +298,37 @@ async def read_ban_list( ban_list = db_manager.get_ban_records() lifted_list: List[BanUser] = [] logger.info("已经读取禁言列表") - for ban_record in ban_list: - if ban_record.user_id == 0: - fetched_group_info = await get_group_info(websocket, ban_record.group_id) - if fetched_group_info is None: - logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除") - lifted_list.append(ban_record) - ban_list.remove(ban_record) - continue - group_all_shut: int = fetched_group_info.get("group_all_shut") - if group_all_shut == 0: - lifted_list.append(ban_record) - ban_list.remove(ban_record) - continue - else: - fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id) - if fetched_member_info is None: - logger.warning( - f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除" - ) - lifted_list.append(ban_record) - ban_list.remove(ban_record) - continue - lift_ban_time: int = fetched_member_info.get("shut_up_timestamp") - if lift_ban_time == 0: - lifted_list.append(ban_record) - ban_list.remove(ban_record) - else: - ban_record.lift_time = lift_ban_time + tasks = [] + for record in ban_list: + if record.user_id == 0: # 群全体禁言 + async def handle_group(r=record): + try: + group_info = await get_group_info(websocket, r.group_id) + except Exception as e: + logger.warning(f"获取群信息失败(群号: {r.group_id}):{e},保留禁言状态") + return None + if not group_info or group_info.get("group_all_shut") == 0: + return r + return None + tasks.append(handle_group()) + else: # 普通用户 + async def handle_user(r=record): + try: + member_info = await get_member_info(websocket, r.group_id, r.user_id) + except Exception as e: + logger.warning(f"获取成员信息失败(群号: {r.group_id} 用户ID: {r.user_id}):{e}") + return None + if not member_info: + return None + lift_time = member_info.get("shut_up_timestamp", 0) + if lift_time == 0: + return r + r.lift_time = lift_time + return None + tasks.append(handle_user()) + results = await asyncio.gather(*tasks) + lifted_list = [r for r in results if r is not None] + ban_list = [r for r in ban_list if r not in lifted_list] db_manager.update_ban_record(ban_list) return ban_list, lifted_list except Exception as e: