diff --git a/src/__init__.py b/src/__init__.py index a35ff0e..e12aede 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,77 +1,78 @@ -from enum import Enum - - -class MetaEventType: - lifecycle = "lifecycle" # 生命周期 - - class Lifecycle: - connect = "connect" # 生命周期 - WebSocket 连接成功 - - heartbeat = "heartbeat" # 心跳 - - -class MessageType: # 接受消息大类 - private = "private" # 私聊消息 - - class Private: - friend = "friend" # 私聊消息 - 好友 - group = "group" # 私聊消息 - 群临时 - group_self = "group_self" # 私聊消息 - 群中自身发送 - other = "other" # 私聊消息 - 其他 - - group = "group" # 群聊消息 - - class Group: - normal = "normal" # 群聊消息 - 普通 - anonymous = "anonymous" # 群聊消息 - 匿名消息 - notice = "notice" # 群聊消息 - 系统提示 - - -class NoticeType: # 通知事件 - friend_recall = "friend_recall" # 私聊消息撤回 - group_recall = "group_recall" # 群聊消息撤回 - notify = "notify" - - class Notify: - poke = "poke" # 戳一戳 - - -class RealMessageType: # 实际消息分类 - text = "text" # 纯文本 - face = "face" # qq表情 - image = "image" # 图片 - record = "record" # 语音 - video = "video" # 视频 - at = "at" # @某人 - rps = "rps" # 猜拳魔法表情 - dice = "dice" # 骰子 - shake = "shake" # 私聊窗口抖动(只收) - poke = "poke" # 群聊戳一戳 - share = "share" # 链接分享(json形式) - reply = "reply" # 回复消息 - forward = "forward" # 转发消息 - node = "node" # 转发消息节点 - - -class MessageSentType: - private = "private" - - class Private: - friend = "friend" - group = "group" - - group = "group" - - class Group: - normal = "normal" - - -class CommandType(Enum): - """命令类型""" - - GROUP_BAN = "set_group_ban" # 禁言用户 - GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 - GROUP_KICK = "set_group_kick" # 踢出群聊 - - def __str__(self) -> str: - return self.value +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + + class Notify: + poke = "poke" # 戳一戳 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + + def __str__(self) -> str: + return self.value diff --git a/src/send_handler.py b/src/send_handler.py index 74646b6..8aae094 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -1,313 +1,342 @@ -import json -import websockets as Server -import uuid -from maim_message import ( - UserInfo, - GroupInfo, - Seg, - BaseMessageInfo, - MessageBase, -) -from typing import Dict, Any, Tuple - -from . import CommandType -from .config import global_config -from .response_pool import get_response -from .logger import logger -from .utils import get_image_format, convert_image_to_gif - - -class SendHandler: - def __init__(self): - self.server_connection: Server.ServerConnection = None - - async def handle_message(self, raw_message_base_dict: dict) -> None: - raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) - message_segment: Seg = raw_message_base.message_segment - logger.info("接收到来自MaiBot的消息,处理中") - if message_segment.type == "command": - return await self.send_command(raw_message_base) - else: - return await self.send_normal_message(raw_message_base) - - async def send_normal_message(self, raw_message_base: MessageBase) -> None: - """ - 处理普通消息发送 - """ - logger.info("处理普通信息中") - message_info: BaseMessageInfo = raw_message_base.message_info - message_segment: Seg = raw_message_base.message_segment - group_info: GroupInfo = message_info.group_info - user_info: UserInfo = message_info.user_info - target_id: int = None - action: str = None - id_name: str = None - processed_message: list = [] - try: - processed_message = await self.handle_seg_recursive(message_segment) - except Exception as e: - logger.error(f"处理消息时发生错误: {e}") - return - - if not processed_message: - logger.critical("现在暂时不支持解析此回复!") - return None - - if group_info and user_info: - logger.debug("发送群聊消息") - target_id = group_info.group_id - action = "send_group_msg" - id_name = "group_id" - elif user_info: - logger.debug("发送私聊消息") - target_id = user_info.user_id - action = "send_private_msg" - id_name = "user_id" - else: - logger.error("无法识别的消息类型") - return - logger.info("尝试发送到napcat") - response = await self.send_message_to_napcat( - action, - { - id_name: target_id, - "message": processed_message, - }, - ) - if response.get("status") == "ok": - logger.info("消息发送成功") - else: - logger.warning(f"消息发送失败,napcat返回:{str(response)}") - - async def send_command(self, raw_message_base: MessageBase) -> None: - """ - 处理命令类 - """ - logger.info("处理命令中") - message_info: BaseMessageInfo = raw_message_base.message_info - message_segment: Seg = raw_message_base.message_segment - group_info: GroupInfo = message_info.group_info - seg_data: Dict[str, Any] = message_segment.data - command_name: str = seg_data.get("name") - try: - match command_name: - case CommandType.GROUP_BAN.name: - command, args_dict = self.handle_ban_command(seg_data.get("args"), group_info) - case CommandType.GROUP_WHOLE_BAN.name: - command, args_dict = self.handle_whole_ban_command(seg_data.get("args"), group_info) - case CommandType.GROUP_KICK.name: - command, args_dict = self.handle_kick_command(seg_data.get("args"), group_info) - case _: - logger.error(f"未知命令: {command_name}") - return - except Exception as e: - logger.error(f"处理命令时发生错误: {e}") - return None - - if not command or not args_dict: - logger.error("命令或参数缺失") - return None - - response = await self.send_message_to_napcat(command, args_dict) - if response.get("status") == "ok": - logger.info(f"命令 {command_name} 执行成功") - else: - logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") - - def get_level(self, seg_data: Seg) -> int: - if seg_data.type == "seglist": - return 1 + max(self.get_level(seg) for seg in seg_data.data) - else: - return 1 - - async def handle_seg_recursive(self, seg_data: Seg) -> list: - payload: list = [] - if seg_data.type == "seglist": - # level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用 - if not seg_data.data: - return [] - for seg in seg_data.data: - payload = self.process_message_by_type(seg, payload) - else: - payload = self.process_message_by_type(seg_data, payload) - return payload - - def process_message_by_type(self, seg: Seg, payload: list) -> list: - # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression - new_payload = payload - if seg.type == "reply": - target_id = seg.data - if target_id == "notice": - return payload - new_payload = self.build_payload(payload, self.handle_reply_message(target_id), True) - elif seg.type == "text": - text = seg.data - if not text: - return payload - new_payload = self.build_payload(payload, self.handle_text_message(text), False) - elif seg.type == "face": - logger.warning("MaiBot 发送了qq原生表情,暂时不支持") - elif seg.type == "image": - image = seg.data - new_payload = self.build_payload(payload, self.handle_image_message(image), False) - elif seg.type == "emoji": - emoji = seg.data - new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False) - elif seg.type == "voice": - voice = seg.data - new_payload = self.build_payload(payload, self.handle_voice_message(voice), False) - return new_payload - - def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list: - # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator - """构建发送的消息体""" - if is_reply: - temp_list = [] - temp_list.append(addon) - for i in payload: - if i.get("type") == "reply": - logger.debug("检测到多个回复,使用最新的回复") - continue - temp_list.append(i) - return temp_list - else: - payload.append(addon) - return payload - - def handle_reply_message(self, id: str) -> dict: - """处理回复消息""" - return {"type": "reply", "data": {"id": id}} - - def handle_text_message(self, message: str) -> dict: - """处理文本消息""" - return {"type": "text", "data": {"text": message}} - - def handle_image_message(self, encoded_image: str) -> dict: - """处理图片消息""" - return { - "type": "image", - "data": { - "file": f"base64://{encoded_image}", - "subtype": 0, - }, - } # base64 编码的图片 - - def handle_emoji_message(self, encoded_emoji: str) -> dict: - """处理表情消息""" - encoded_image = encoded_emoji - image_format = get_image_format(encoded_emoji) - if image_format != "gif": - encoded_image = convert_image_to_gif(encoded_emoji) - return { - "type": "image", - "data": { - "file": f"base64://{encoded_image}", - "subtype": 1, - "summary": "[动画表情]", - }, - } - - def handle_voice_message(self, encoded_voice: str) -> dict: - """处理语音消息""" - if not global_config.voice.use_tts: - logger.warning("未启用语音消息处理") - return {} - if not encoded_voice: - return {} - return { - "type": "record", - "data": {"file": f"base64://{encoded_voice}"}, - } - - def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理封禁命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - duration: int = int(args["duration"]) - user_id: int = int(args["qq_id"]) - group_id: int = int(group_info.group_id) - if duration <= 0: - raise ValueError("封禁时间必须大于0") - if not user_id or not group_id: - raise ValueError("封禁命令缺少必要参数") - if duration > 2592000: - raise ValueError("封禁时间不能超过30天") - return ( - CommandType.GROUP_BAN.value, - { - "group_id": group_id, - "user_id": user_id, - "duration": duration, - }, - ) - - def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理全体禁言命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - enable = args["enable"] - assert isinstance(enable, bool), "enable参数必须是布尔值" - group_id: int = int(group_info.group_id) - if group_id <= 0: - raise ValueError("群组ID无效") - return ( - CommandType.GROUP_WHOLE_BAN.value, - { - "group_id": group_id, - "enable": enable, - }, - ) - - def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理群成员踢出命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - user_id: int = int(args["qq_id"]) - group_id: int = int(group_info.group_id) - if group_id <= 0: - raise ValueError("群组ID无效") - if user_id <= 0: - raise ValueError("用户ID无效") - return ( - CommandType.GROUP_KICK.value, - { - "group_id": group_id, - "user_id": user_id, - "reject_add_request": False, # 不拒绝加群请求 - }, - ) - - async def send_message_to_napcat(self, action: str, params: dict) -> dict: - request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) - await self.server_connection.send(payload) - try: - response = await get_response(request_uuid) - except TimeoutError: - logger.error("发送消息超时,未收到响应") - return {"status": "error", "message": "timeout"} - except Exception as e: - logger.error(f"发送消息失败: {e}") - return {"status": "error", "message": str(e)} - return response - - -send_handler = SendHandler() +import json +import websockets as Server +import uuid +from maim_message import ( + UserInfo, + GroupInfo, + Seg, + BaseMessageInfo, + MessageBase, +) +from typing import Dict, Any, Tuple + +from . import CommandType +from .config import global_config +from .response_pool import get_response +from .logger import logger +from .utils import get_image_format, convert_image_to_gif + + +class SendHandler: + def __init__(self): + self.server_connection: Server.ServerConnection = None + + async def handle_message(self, raw_message_base_dict: dict) -> None: + raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) + message_segment: Seg = raw_message_base.message_segment + logger.info("接收到来自MaiBot的消息,处理中") + if message_segment.type == "command": + return await self.send_command(raw_message_base) + else: + return await self.send_normal_message(raw_message_base) + + async def send_normal_message(self, raw_message_base: MessageBase) -> None: + """ + 处理普通消息发送 + """ + logger.info("处理普通信息中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: GroupInfo = message_info.group_info + user_info: UserInfo = message_info.user_info + target_id: int = None + action: str = None + id_name: str = None + processed_message: list = [] + try: + processed_message = await self.handle_seg_recursive(message_segment) + except Exception as e: + logger.error(f"处理消息时发生错误: {e}") + return + + if not processed_message: + logger.critical("现在暂时不支持解析此回复!") + return None + + if group_info and user_info: + logger.debug("发送群聊消息") + target_id = group_info.group_id + action = "send_group_msg" + id_name = "group_id" + elif user_info: + logger.debug("发送私聊消息") + target_id = user_info.user_id + action = "send_private_msg" + id_name = "user_id" + else: + logger.error("无法识别的消息类型") + return + logger.info("尝试发送到napcat") + response = await self.send_message_to_napcat( + action, + { + id_name: target_id, + "message": processed_message, + }, + ) + if response.get("status") == "ok": + logger.info("消息发送成功") + else: + logger.warning(f"消息发送失败,napcat返回:{str(response)}") + + async def send_command(self, raw_message_base: MessageBase) -> None: + """ + 处理命令类 + """ + logger.info("处理命令中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: GroupInfo = message_info.group_info + seg_data: Dict[str, Any] = message_segment.data + command_name: str = seg_data.get("name") + try: + match command_name: + case CommandType.GROUP_BAN.name: + command, args_dict = self.handle_ban_command(seg_data.get("args"), group_info) + case CommandType.GROUP_WHOLE_BAN.name: + command, args_dict = self.handle_whole_ban_command(seg_data.get("args"), group_info) + case CommandType.GROUP_KICK.name: + command, args_dict = self.handle_kick_command(seg_data.get("args"), group_info) + case CommandType.SEND_POKE.name: + command, args_dict = self.handle_poke_command(seg_data.get("args"), group_info) + case _: + logger.error(f"未知命令: {command_name}") + return + except Exception as e: + logger.error(f"处理命令时发生错误: {e}") + return None + + if not command or not args_dict: + logger.error("命令或参数缺失") + return None + + response = await self.send_message_to_napcat(command, args_dict) + if response.get("status") == "ok": + logger.info(f"命令 {command_name} 执行成功") + else: + logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") + + def get_level(self, seg_data: Seg) -> int: + if seg_data.type == "seglist": + return 1 + max(self.get_level(seg) for seg in seg_data.data) + else: + return 1 + + async def handle_seg_recursive(self, seg_data: Seg) -> list: + payload: list = [] + if seg_data.type == "seglist": + # level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用 + if not seg_data.data: + return [] + for seg in seg_data.data: + payload = self.process_message_by_type(seg, payload) + else: + payload = self.process_message_by_type(seg_data, payload) + return payload + + def process_message_by_type(self, seg: Seg, payload: list) -> list: + # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression + new_payload = payload + if seg.type == "reply": + target_id = seg.data + if target_id == "notice": + return payload + new_payload = self.build_payload(payload, self.handle_reply_message(target_id), True) + elif seg.type == "text": + text = seg.data + if not text: + return payload + new_payload = self.build_payload(payload, self.handle_text_message(text), False) + elif seg.type == "face": + logger.warning("MaiBot 发送了qq原生表情,暂时不支持") + elif seg.type == "image": + image = seg.data + new_payload = self.build_payload(payload, self.handle_image_message(image), False) + elif seg.type == "emoji": + emoji = seg.data + new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False) + elif seg.type == "voice": + voice = seg.data + new_payload = self.build_payload(payload, self.handle_voice_message(voice), False) + return new_payload + + def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list: + # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator + """构建发送的消息体""" + if is_reply: + temp_list = [] + temp_list.append(addon) + for i in payload: + if i.get("type") == "reply": + logger.debug("检测到多个回复,使用最新的回复") + continue + temp_list.append(i) + return temp_list + else: + payload.append(addon) + return payload + + def handle_reply_message(self, id: str) -> dict: + """处理回复消息""" + return {"type": "reply", "data": {"id": id}} + + def handle_text_message(self, message: str) -> dict: + """处理文本消息""" + return {"type": "text", "data": {"text": message}} + + def handle_image_message(self, encoded_image: str) -> dict: + """处理图片消息""" + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 0, + }, + } # base64 编码的图片 + + def handle_emoji_message(self, encoded_emoji: str) -> dict: + """处理表情消息""" + encoded_image = encoded_emoji + image_format = get_image_format(encoded_emoji) + if image_format != "gif": + encoded_image = convert_image_to_gif(encoded_emoji) + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 1, + "summary": "[动画表情]", + }, + } + + def handle_voice_message(self, encoded_voice: str) -> dict: + """处理语音消息""" + if not global_config.voice.use_tts: + logger.warning("未启用语音消息处理") + return {} + if not encoded_voice: + return {} + return { + "type": "record", + "data": {"file": f"base64://{encoded_voice}"}, + } + + def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + """处理封禁命令 + + Args: + args (Dict[str, Any]): 参数字典 + group_info (GroupInfo): 群聊信息(对应目标群聊) + + Returns: + Tuple[CommandType, Dict[str, Any]] + """ + duration: int = int(args["duration"]) + user_id: int = int(args["qq_id"]) + group_id: int = int(group_info.group_id) + if duration <= 0: + raise ValueError("封禁时间必须大于0") + if not user_id or not group_id: + raise ValueError("封禁命令缺少必要参数") + if duration > 2592000: + raise ValueError("封禁时间不能超过30天") + return ( + CommandType.GROUP_BAN.value, + { + "group_id": group_id, + "user_id": user_id, + "duration": duration, + }, + ) + + def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + """处理全体禁言命令 + + Args: + args (Dict[str, Any]): 参数字典 + group_info (GroupInfo): 群聊信息(对应目标群聊) + + Returns: + Tuple[CommandType, Dict[str, Any]] + """ + enable = args["enable"] + assert isinstance(enable, bool), "enable参数必须是布尔值" + group_id: int = int(group_info.group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + return ( + CommandType.GROUP_WHOLE_BAN.value, + { + "group_id": group_id, + "enable": enable, + }, + ) + + def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + """处理群成员踢出命令 + + Args: + args (Dict[str, Any]): 参数字典 + group_info (GroupInfo): 群聊信息(对应目标群聊) + + Returns: + Tuple[CommandType, Dict[str, Any]] + """ + user_id: int = int(args["qq_id"]) + group_id: int = int(group_info.group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + if user_id <= 0: + raise ValueError("用户ID无效") + return ( + CommandType.GROUP_KICK.value, + { + "group_id": group_id, + "user_id": user_id, + "reject_add_request": False, # 不拒绝加群请求 + }, + ) + + def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + """处理戳一戳命令 + + Args: + args (Dict[str, Any]): 参数字典 + group_info (GroupInfo): 群聊信息(对应目标群聊) + + Returns: + Tuple[CommandType, Dict[str, Any]] + """ + user_id: int = int(args["qq_id"]) + if group_info == None: + group_id = None + else: + group_id: int = int(group_info.group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + if user_id <= 0: + raise ValueError("用户ID无效") + return ( + CommandType.SEND_POKE.value, + { + "group_id": group_id, + "user_id": user_id, + }, + ) + + async def send_message_to_napcat(self, action: str, params: dict) -> dict: + request_uuid = str(uuid.uuid4()) + payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) + await self.server_connection.send(payload) + try: + response = await get_response(request_uuid) + except TimeoutError: + logger.error("发送消息超时,未收到响应") + return {"status": "error", "message": "timeout"} + except Exception as e: + logger.error(f"发送消息失败: {e}") + return {"status": "error", "message": str(e)} + return response + + +send_handler = SendHandler()