From 3cb262be9ae1131009e02cb90cc60e9ccefed339 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 8 Apr 2025 00:42:29 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BD=AC=E5=8F=91=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=8E=A5=E6=94=B6=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=8F=91=E9=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 1 + main.py | 36 +++++-- src/logger.py | 2 +- src/message_queue.py | 10 +- src/recv_handler.py | 230 ++++++++++++++++++++++++++++++++++++++----- src/send_handler.py | 10 +- src/utils.py | 37 ++++--- test.py | 18 ++++ 8 files changed, 287 insertions(+), 57 deletions(-) create mode 100644 test.py diff --git a/README.md b/README.md index 9cae1b5..29bb834 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ enable_temp = false - [x] 文本解析 - [x] 图片解析 - [x] 文本与消息混合解析 + - [x] 转发解析(含图片动态解析) - [ ] 链接解析 - [ ] 戳一戳解析 - [ ] 语音解析(?) diff --git a/main.py b/main.py index 26fff8f..d9d13fa 100644 --- a/main.py +++ b/main.py @@ -1,39 +1,54 @@ import asyncio import sys import json -import websockets.asyncio.server as Server +import websockets as Server from src.logger import logger from src.recv_handler import recv_handler from src.send_handler import send_handler from src.config import global_config from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router -from src.message_queue import recv_queue +from src.message_queue import recv_queue, message_queue async def message_recv(server_connection: Server.ServerConnection): recv_handler.server_connection = server_connection send_handler.server_connection = server_connection - # asyncio.create_task(send_handler.test_send()) async for raw_message in server_connection: - logger.debug(raw_message) + logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message) decoded_raw_message: dict = json.loads(raw_message) post_type = decoded_raw_message.get("post_type") if post_type == "meta_event": - await recv_handler.handle_meta_event(decoded_raw_message) + await message_queue.put(decoded_raw_message) elif post_type == "message": - await recv_handler.handle_raw_message(decoded_raw_message) + await message_queue.put(decoded_raw_message) elif post_type == "notice": pass elif post_type is None: - recv_queue.put(decoded_raw_message) + await recv_queue.put(decoded_raw_message) + + +async def message_process(): + while True: + message = await message_queue.get() + post_type = message.get("post_type") + if post_type == "message": + await recv_handler.handle_raw_message(message) + elif post_type == "meta_event": + await recv_handler.handle_meta_event(message) + elif post_type == "notice": + pass + else: + logger.warning(f"未知的post_type: {post_type}") + message_queue.task_done() + await asyncio.sleep(0.05) async def main(): recv_handler.maibot_router = router - _ = await asyncio.gather(mmc_server(), mmc_start_com()) + _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process()) -async def mmc_server(): +async def napcat_server(): logger.info("正在启动adapter...") async with Server.serve( message_recv, global_config.server_host, global_config.server_port @@ -50,7 +65,8 @@ async def graceful_shutdown(): await mmc_stop_com() tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: - task.cancel() + if not task.done(): + task.cancel() await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: diff --git a/src/logger.py b/src/logger.py index 854f5a2..7ada699 100644 --- a/src/logger.py +++ b/src/logger.py @@ -13,4 +13,4 @@ def handle_output(message: str): logger.info(message) -builtins.print = handle_output +# builtins.print = handle_output diff --git a/src/message_queue.py b/src/message_queue.py index cca41e1..9d2bb2c 100644 --- a/src/message_queue.py +++ b/src/message_queue.py @@ -1,9 +1,9 @@ -import queue import asyncio -recv_queue = queue.Queue() +recv_queue = asyncio.Queue() +message_queue = asyncio.Queue() async def get_response(): - while recv_queue.empty(): - await asyncio.sleep(0.5) - return recv_queue.get() + response = await recv_queue.get() + recv_queue.task_done() + return response \ No newline at end of file diff --git a/src/recv_handler.py b/src/recv_handler.py index 0ecc057..5467604 100644 --- a/src/recv_handler.py +++ b/src/recv_handler.py @@ -4,7 +4,7 @@ import time import asyncio import json import websockets.asyncio.server as Server -from typing import List +from typing import List, Tuple from . import MetaEventType, RealMessageType, MessageType from maim_message import ( @@ -19,6 +19,7 @@ from maim_message import ( ) from .utils import get_group_info, get_member_info, get_image_base64, get_self_info +from .message_queue import get_response class RecvHandler: @@ -59,10 +60,8 @@ class RecvHandler: """ 从Napcat接受的原始消息处理 - 参数: + Parameters: raw_message: dict: 原始消息 - 返回值: - None """ message_type: str = raw_message.get("message_type") message_id: int = raw_message.get("message_id") @@ -70,7 +69,6 @@ class RecvHandler: template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用 format_info: FormatInfo = None # 格式化信息,暂时为空,等待启用 - if message_type == MessageType.private: sub_type = raw_message.get("sub_type") if sub_type == MessageType.Private.friend: @@ -116,7 +114,7 @@ class RecvHandler: # -------------------这里需要群信息吗?------------------- # 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有 - fetched_group_info: dict = get_group_info( + fetched_group_info: dict = await get_group_info( self.server_connection, raw_message.get("group_id") ) group_name = "" @@ -194,18 +192,16 @@ class RecvHandler: message_segment=submit_seg, raw_message=raw_message.get("raw_message"), ) - # 不启用发送消息 - await self.message_process(message_base) - logger.debug("我处理!") + logger.info("发送到Maibot处理信息") + await self.message_process(message_base) async def handle_real_message(self, raw_message: dict) -> List[Seg]: """ 处理实际消息 - - 参数: + Parameters: real_message: dict: 实际消息 - 返回值: + Returns: seg_message: list[Seg]: 处理后的消息段列表 """ real_message: list = raw_message.get("message") @@ -243,7 +239,7 @@ class RecvHandler: logger.warning("暂时不支持猜拳魔法表情解析") pass case RealMessageType.dice: - logger.warning("暂时不支持筛子表情解析") + logger.warning("暂时不支持骰子表情解析") pass case RealMessageType.shake: # 预计等价于戳一戳 @@ -267,9 +263,17 @@ class RecvHandler: } ) await self.server_connection.send(payload) - response = await self.server_connection.recv() - logger.critical(response) - logger.critical(json.loads(response)) + # response = await self.server_connection.recv() + response: dict = await get_response() + logger.debug( + f"转发消息原始格式:{json.dumps(response)[:80]}..." + if len(json.dumps(response)) > 80 + else json.dumps(response) + ) + messages = response.get("data").get("messages") + ret_seg = await self.handle_forward_message(messages) + if ret_seg: + seg_message.append(ret_seg) case RealMessageType.node: logger.warning("不支持转发消息节点解析") pass @@ -278,10 +282,9 @@ class RecvHandler: async def handle_text_message(self, raw_message: dict) -> Seg: """ 处理纯文本信息 - - 参数: + Parameters: raw_message: dict: 原始消息 - 返回值: + Returns: seg_data: Seg: 处理后的消息段 """ message_data: dict = raw_message.get("data") @@ -300,10 +303,9 @@ class RecvHandler: async def handle_image_message(self, raw_message: dict) -> Seg: """ 处理图片消息与表情包消息 - - 参数: + Parameters: raw_message: dict: 原始消息 - 返回值: + Returns: seg_data: Seg: 处理后的消息段 """ message_data: dict = raw_message.get("data") @@ -325,28 +327,204 @@ class RecvHandler: ) -> Seg: """ 处理at消息 + Parameters: + raw_message: dict: 原始消息 + self_id: int: 机器人QQ号 + group_id: int: 群号 + Returns: + seg_data: Seg: 处理后的消息段 """ message_data: dict = raw_message.get("data") if message_data: qq_id = message_data.get("qq") if str(self_id) == str(qq_id): - self_info: dict = get_self_info() + self_info: dict = await get_self_info(self.server_connection) if self_info: - return Seg(type="text", data=f"@{self_info.get('nickname')} ") + return Seg( + type=RealMessageType.text, data=f"@{self_info.get('nickname')}" + ) else: return None else: - member_info: dict = get_member_info( + member_info: dict = await get_member_info( self.server_connection, group_id=group_id, user_id=self_id ) if member_info: - return Seg(type="text", data=f"@{member_info.get('nickname')} ") + return Seg( + type=RealMessageType.text, + data=f"@{member_info.get('nickname')}", + ) else: return None async def handle_poke_message(self) -> None: pass + async def handle_forward_message(self, message_list: list) -> Seg: + """ + 递归处理转发消息,并按照动态方式确定图片处理方式 + Parameters: + message_list: list: 转发消息列表 + """ + handled_message, image_count = await self._handle_forward_message( + message_list, 0 + ) + handled_message: Seg + image_count: int + if not handled_message: + return None + if image_count < 5 and image_count > 0: + # 处理图片数量小于5的情况,此时解析图片为base64 + parsed_handled_message = await self._recursive_parse_image_seg( + handled_message, True + ) + return parsed_handled_message + elif image_count > 0: + # 处理图片数量大于等于5的情况,此时解析图片为占位符 + parsed_handled_message = await self._recursive_parse_image_seg( + handled_message, False + ) + return parsed_handled_message + else: + # 处理没有图片的情况,此时直接返回 + return handled_message + + async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg: + if to_image: + if seg_data.type == "seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + image_url = seg_data.data + encoded_image = await get_image_base64(image_url) + return Seg(type="image", data=encoded_image) + elif seg_data.type == "emoji": + image_url = seg_data.data + encoded_image = await get_image_base64(image_url) + return Seg(type="emoji", data=encoded_image) + else: + return seg_data + else: + if seg_data.type == "seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + image_url = seg_data.data + return Seg(type="text", data="【图片】") + elif seg_data.type == "emoji": + image_url = seg_data.data + return Seg(type="text", data="【动画表情】") + else: + return seg_data + + async def _handle_forward_message( + self, message_list: list, layer: int + ) -> Tuple[Seg, int]: + """ + 递归处理实际转发消息 + Parameters: + message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段 + layer: int: 当前层级 + Returns: + seg_data: Seg: 处理后的消息段 + image_count: int: 图片数量 + """ + seg_list = [] + image_count = 0 + if message_list is None or len(message_list) == 0: + return None, 0 + for sub_message in message_list: + sub_message: dict + sender_info: dict = sub_message.get("sender") + user_nickname: str = sender_info.get("nickname", "QQ用户") + user_nickname_str = f"【{user_nickname}】:" + break_seg = Seg(type="text", data="\n") + message_of_sub_message: dict = sub_message.get("message")[0] + if message_of_sub_message.get("type") == RealMessageType.forward: + if layer >= 3: + full_seg_data = ( + Seg( + type="text", + data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n", + ), + 0, + ) + else: + contents = message_of_sub_message.get("data").get("content") + seg_data, count = await self._handle_forward_message( + contents, layer + 1 + ) + image_count += count + head_tip = Seg( + type="text", + data=("--" * layer) + + f"【{user_nickname}】: 合并转发消息内容:\n", + ) + full_seg_data = Seg(type="seglist", data=[head_tip, seg_data]) + seg_list.append(full_seg_data) + elif message_of_sub_message.get("type") == RealMessageType.text: + text_message = message_of_sub_message.get("data").get("text") + seg_data = Seg(type="text", data=text_message) + if layer > 0: + seg_list.append( + Seg( + type="seglist", + data=[ + Seg( + type="text", data=("--" * layer) + user_nickname_str + ), + seg_data, + break_seg, + ], + ) + ) + else: + seg_list.append( + Seg( + type="seglist", + data=[ + Seg(type="text", data=user_nickname_str), + seg_data, + break_seg, + ], + ) + ) + elif message_of_sub_message.get("type") == RealMessageType.image: + image_count += 1 + image_data = message_of_sub_message.get("data") + sub_type = image_data.get("sub_type") + image_url = image_data.get("url") + if sub_type == 0: + seg_data = Seg(type="image", data=image_url) + else: + seg_data = Seg(type="emoji", data=image_url) + if layer > 0: + full_seg_data = Seg( + type="seglist", + data=[ + Seg(type="text", data=("--" * layer) + user_nickname_str), + seg_data, + break_seg, + ], + ) + else: + full_seg_data = Seg( + type="seglist", + data=[ + Seg(type="text", data=user_nickname_str), + seg_data, + break_seg, + ], + ) + seg_list.append(full_seg_data) + return Seg(type="seglist", data=seg_list), image_count + async def message_process(self, message_base: MessageBase) -> None: await self.maibot_router.send_message(message_base) diff --git a/src/send_handler.py b/src/send_handler.py index 208dabb..53560e7 100644 --- a/src/send_handler.py +++ b/src/send_handler.py @@ -14,13 +14,13 @@ from maim_message import ( MessageBase, ) +from .utils import get_image_format, convert_image_to_gif class SendHandler: def __init__(self): self.server_connection: Server.ServerConnection = None async def handle_seg(self, raw_message_base_str: str) -> None: - logger.critical(raw_message_base_str) raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_str) message_info: BaseMessageInfo = raw_message_base.message_info message_segment: Seg = raw_message_base.message_segment @@ -127,9 +127,13 @@ class SendHandler: def handle_emoji_message(self, encoded_emoji: str) -> dict: """处理表情消息""" + encoded_image = encoded_emoji + image_format = get_image_format(encoded_emoji) + if image_format != 'gif': + encoded_image = convert_image_to_gif(encoded_emoji) return { "type": "image", - "data": {"file": f"base64://{encoded_emoji}", "subtype": 1}, + "data": {"file": f"base64://{encoded_image}", "subtype": 1}, } async def test_send(self): @@ -145,7 +149,7 @@ class SendHandler: else: logger.warning(f"消息发送失败,napcat返回:{str(response)}") - async def send_message_to_napcat(self, action: str, params: dict) -> None: + async def send_message_to_napcat(self, action: str, params: dict) -> dict: payload = json.dumps({"action": action, "params": params}) await self.server_connection.send(payload) response = await get_response() diff --git a/src/utils.py b/src/utils.py index 86f4314..e25c7f1 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,6 +2,7 @@ import websockets.asyncio.server as Server import json import base64 from .logger import logger +from .message_queue import get_response import requests import ssl @@ -10,6 +11,7 @@ from requests.adapters import HTTPAdapter from PIL import Image import io + class SSLAdapter(HTTPAdapter): def init_poolmanager(self, *args, **kwargs): """ @@ -35,9 +37,9 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d """ payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}}) await websocket.send(payload) - socket_response = await websocket.recv() + socket_response: dict = await get_response() logger.debug(socket_response) - return json.loads(socket_response).get("data") + return socket_response.get("data") async def get_member_info( @@ -55,9 +57,9 @@ async def get_member_info( } ) await websocket.send(payload) - socket_response = await websocket.recv() + socket_response: dict = await get_response() logger.debug(socket_response) - return json.loads(socket_response).get("data") + return socket_response.get("data") async def get_image_base64(url: str) -> str: @@ -65,10 +67,7 @@ async def get_image_base64(url: str) -> str: try: sess = requests.session() sess.mount("https://", SSLAdapter()) # 将上面定义的SSLAdapter 应用起来 - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" - } - response = sess.get(url, headers=headers, timeout=10, verify=True) + response = sess.get(url, timeout=10, verify=True) response.raise_for_status() image_bytes = response.content return base64.b64encode(image_bytes).decode("utf-8") @@ -77,16 +76,30 @@ async def get_image_base64(url: str) -> str: raise +def convert_image_to_gif(image_base64: str) -> str: + try: + image_bytes = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_bytes)) + output_buffer = io.BytesIO() + image.save(output_buffer, format="GIF") + output_buffer.seek(0) + return base64.b64encode(output_buffer.read()).decode("utf-8") + except Exception as e: + logger.error(f"图片转换为GIF失败: {str(e)}") + return image_base64 + + async def get_self_info(websocket: Server.ServerConnection) -> str: """ 获取自身信息 """ payload = json.dumps({"action": "get_login_info", "params": {}}) await websocket.send(payload) - response = await websocket.recv() + response: dict = await get_response() logger.debug(response) - return json.loads(response).get("data") + return response.get("data") -async def get_image_format(raw_data: str) -> str: + +def get_image_format(raw_data: str) -> str: image_bytes = base64.b64decode(raw_data) - return Image.open(io.BytesIO(image_bytes)).format.lower() \ No newline at end of file + return Image.open(io.BytesIO(image_bytes)).format.lower() diff --git a/test.py b/test.py new file mode 100644 index 0000000..3d8945b --- /dev/null +++ b/test.py @@ -0,0 +1,18 @@ +import asyncio +import queue + +message = queue.Queue() + +async def test(): + await asyncio.sleep(5) + message.put("123") + +async def test2(): + while message.empty(): + await asyncio.sleep(0.5) + print("等回复") + print(message.get()) + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.run_until_complete(asyncio.gather(test(), test2())) \ No newline at end of file