转发功能,重构接收,增加发送

pull/2/head
UnCLAS-Prommer 2025-04-08 00:42:29 +08:00
parent 6ec55f781b
commit 3cb262be9a
8 changed files with 287 additions and 57 deletions

View File

@ -37,6 +37,7 @@ enable_temp = false
- [x] 文本解析
- [x] 图片解析
- [x] 文本与消息混合解析
- [x] 转发解析(含图片动态解析)
- [ ] 链接解析
- [ ] 戳一戳解析
- [ ] 语音解析(?)

36
main.py
View File

@ -1,39 +1,54 @@
import asyncio
import sys
import json
import websockets.asyncio.server as Server
import websockets as Server
from src.logger import logger
from src.recv_handler import recv_handler
from src.send_handler import send_handler
from src.config import global_config
from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
from src.message_queue import recv_queue
from src.message_queue import recv_queue, message_queue
async def message_recv(server_connection: Server.ServerConnection):
recv_handler.server_connection = server_connection
send_handler.server_connection = server_connection
# asyncio.create_task(send_handler.test_send())
async for raw_message in server_connection:
logger.debug(raw_message)
logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message)
decoded_raw_message: dict = json.loads(raw_message)
post_type = decoded_raw_message.get("post_type")
if post_type == "meta_event":
await recv_handler.handle_meta_event(decoded_raw_message)
await message_queue.put(decoded_raw_message)
elif post_type == "message":
await recv_handler.handle_raw_message(decoded_raw_message)
await message_queue.put(decoded_raw_message)
elif post_type == "notice":
pass
elif post_type is None:
recv_queue.put(decoded_raw_message)
await recv_queue.put(decoded_raw_message)
async def message_process():
while True:
message = await message_queue.get()
post_type = message.get("post_type")
if post_type == "message":
await recv_handler.handle_raw_message(message)
elif post_type == "meta_event":
await recv_handler.handle_meta_event(message)
elif post_type == "notice":
pass
else:
logger.warning(f"未知的post_type: {post_type}")
message_queue.task_done()
await asyncio.sleep(0.05)
async def main():
recv_handler.maibot_router = router
_ = await asyncio.gather(mmc_server(), mmc_start_com())
_ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process())
async def mmc_server():
async def napcat_server():
logger.info("正在启动adapter...")
async with Server.serve(
message_recv, global_config.server_host, global_config.server_port
@ -50,7 +65,8 @@ async def graceful_shutdown():
await mmc_stop_com()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:

View File

@ -13,4 +13,4 @@ def handle_output(message: str):
logger.info(message)
builtins.print = handle_output
# builtins.print = handle_output

View File

@ -1,9 +1,9 @@
import queue
import asyncio
recv_queue = queue.Queue()
recv_queue = asyncio.Queue()
message_queue = asyncio.Queue()
async def get_response():
while recv_queue.empty():
await asyncio.sleep(0.5)
return recv_queue.get()
response = await recv_queue.get()
recv_queue.task_done()
return response

View File

@ -4,7 +4,7 @@ import time
import asyncio
import json
import websockets.asyncio.server as Server
from typing import List
from typing import List, Tuple
from . import MetaEventType, RealMessageType, MessageType
from maim_message import (
@ -19,6 +19,7 @@ from maim_message import (
)
from .utils import get_group_info, get_member_info, get_image_base64, get_self_info
from .message_queue import get_response
class RecvHandler:
@ -59,10 +60,8 @@ class RecvHandler:
"""
从Napcat接受的原始消息处理
参数:
Parameters:
raw_message: dict: 原始消息
返回值:
None
"""
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
@ -70,7 +69,6 @@ class RecvHandler:
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
format_info: FormatInfo = None # 格式化信息,暂时为空,等待启用
if message_type == MessageType.private:
sub_type = raw_message.get("sub_type")
if sub_type == MessageType.Private.friend:
@ -116,7 +114,7 @@ class RecvHandler:
# -------------------这里需要群信息吗?-------------------
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info: dict = get_group_info(
fetched_group_info: dict = await get_group_info(
self.server_connection, raw_message.get("group_id")
)
group_name = ""
@ -194,18 +192,16 @@ class RecvHandler:
message_segment=submit_seg,
raw_message=raw_message.get("raw_message"),
)
# 不启用发送消息
await self.message_process(message_base)
logger.debug("我处理!")
logger.info("发送到Maibot处理信息")
await self.message_process(message_base)
async def handle_real_message(self, raw_message: dict) -> List[Seg]:
"""
处理实际消息
参数:
Parameters:
real_message: dict: 实际消息
返回值:
Returns:
seg_message: list[Seg]: 处理后的消息段列表
"""
real_message: list = raw_message.get("message")
@ -243,7 +239,7 @@ class RecvHandler:
logger.warning("暂时不支持猜拳魔法表情解析")
pass
case RealMessageType.dice:
logger.warning("暂时不支持子表情解析")
logger.warning("暂时不支持子表情解析")
pass
case RealMessageType.shake:
# 预计等价于戳一戳
@ -267,9 +263,17 @@ class RecvHandler:
}
)
await self.server_connection.send(payload)
response = await self.server_connection.recv()
logger.critical(response)
logger.critical(json.loads(response))
# response = await self.server_connection.recv()
response: dict = await get_response()
logger.debug(
f"转发消息原始格式:{json.dumps(response)[:80]}..."
if len(json.dumps(response)) > 80
else json.dumps(response)
)
messages = response.get("data").get("messages")
ret_seg = await self.handle_forward_message(messages)
if ret_seg:
seg_message.append(ret_seg)
case RealMessageType.node:
logger.warning("不支持转发消息节点解析")
pass
@ -278,10 +282,9 @@ class RecvHandler:
async def handle_text_message(self, raw_message: dict) -> Seg:
"""
处理纯文本信息
参数:
Parameters:
raw_message: dict: 原始消息
返回值:
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
@ -300,10 +303,9 @@ class RecvHandler:
async def handle_image_message(self, raw_message: dict) -> Seg:
"""
处理图片消息与表情包消息
参数:
Parameters:
raw_message: dict: 原始消息
返回值:
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
@ -325,28 +327,204 @@ class RecvHandler:
) -> Seg:
"""
处理at消息
Parameters:
raw_message: dict: 原始消息
self_id: int: 机器人QQ号
group_id: int: 群号
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
if message_data:
qq_id = message_data.get("qq")
if str(self_id) == str(qq_id):
self_info: dict = get_self_info()
self_info: dict = await get_self_info(self.server_connection)
if self_info:
return Seg(type="text", data=f"@{self_info.get('nickname')} ")
return Seg(
type=RealMessageType.text, data=f"@{self_info.get('nickname')}"
)
else:
return None
else:
member_info: dict = get_member_info(
member_info: dict = await get_member_info(
self.server_connection, group_id=group_id, user_id=self_id
)
if member_info:
return Seg(type="text", data=f"@{member_info.get('nickname')} ")
return Seg(
type=RealMessageType.text,
data=f"@{member_info.get('nickname')}",
)
else:
return None
async def handle_poke_message(self) -> None:
pass
async def handle_forward_message(self, message_list: list) -> Seg:
"""
递归处理转发消息并按照动态方式确定图片处理方式
Parameters:
message_list: list: 转发消息列表
"""
handled_message, image_count = await self._handle_forward_message(
message_list, 0
)
handled_message: Seg
image_count: int
if not handled_message:
return None
if image_count < 5 and image_count > 0:
# 处理图片数量小于5的情况此时解析图片为base64
parsed_handled_message = await self._recursive_parse_image_seg(
handled_message, True
)
return parsed_handled_message
elif image_count > 0:
# 处理图片数量大于等于5的情况此时解析图片为占位符
parsed_handled_message = await self._recursive_parse_image_seg(
handled_message, False
)
return parsed_handled_message
else:
# 处理没有图片的情况,此时直接返回
return handled_message
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
if to_image:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
image_url = seg_data.data
encoded_image = await get_image_base64(image_url)
return Seg(type="image", data=encoded_image)
elif seg_data.type == "emoji":
image_url = seg_data.data
encoded_image = await get_image_base64(image_url)
return Seg(type="emoji", data=encoded_image)
else:
return seg_data
else:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
image_url = seg_data.data
return Seg(type="text", data="【图片】")
elif seg_data.type == "emoji":
image_url = seg_data.data
return Seg(type="text", data="【动画表情】")
else:
return seg_data
async def _handle_forward_message(
self, message_list: list, layer: int
) -> Tuple[Seg, int]:
"""
递归处理实际转发消息
Parameters:
message_list: list: 转发消息列表首层对应messages字段后面对应content字段
layer: int: 当前层级
Returns:
seg_data: Seg: 处理后的消息段
image_count: int: 图片数量
"""
seg_list = []
image_count = 0
if message_list is None or len(message_list) == 0:
return None, 0
for sub_message in message_list:
sub_message: dict
sender_info: dict = sub_message.get("sender")
user_nickname: str = sender_info.get("nickname", "QQ用户")
user_nickname_str = f"{user_nickname}】:"
break_seg = Seg(type="text", data="\n")
message_of_sub_message: dict = sub_message.get("message")[0]
if message_of_sub_message.get("type") == RealMessageType.forward:
if layer >= 3:
full_seg_data = (
Seg(
type="text",
data=("--" * layer) + f"{user_nickname}】:【转发消息】\n",
),
0,
)
else:
contents = message_of_sub_message.get("data").get("content")
seg_data, count = await self._handle_forward_message(
contents, layer + 1
)
image_count += count
head_tip = Seg(
type="text",
data=("--" * layer)
+ f"{user_nickname}】: 合并转发消息内容:\n",
)
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
seg_list.append(full_seg_data)
elif message_of_sub_message.get("type") == RealMessageType.text:
text_message = message_of_sub_message.get("data").get("text")
seg_data = Seg(type="text", data=text_message)
if layer > 0:
seg_list.append(
Seg(
type="seglist",
data=[
Seg(
type="text", data=("--" * layer) + user_nickname_str
),
seg_data,
break_seg,
],
)
)
else:
seg_list.append(
Seg(
type="seglist",
data=[
Seg(type="text", data=user_nickname_str),
seg_data,
break_seg,
],
)
)
elif message_of_sub_message.get("type") == RealMessageType.image:
image_count += 1
image_data = message_of_sub_message.get("data")
sub_type = image_data.get("sub_type")
image_url = image_data.get("url")
if sub_type == 0:
seg_data = Seg(type="image", data=image_url)
else:
seg_data = Seg(type="emoji", data=image_url)
if layer > 0:
full_seg_data = Seg(
type="seglist",
data=[
Seg(type="text", data=("--" * layer) + user_nickname_str),
seg_data,
break_seg,
],
)
else:
full_seg_data = Seg(
type="seglist",
data=[
Seg(type="text", data=user_nickname_str),
seg_data,
break_seg,
],
)
seg_list.append(full_seg_data)
return Seg(type="seglist", data=seg_list), image_count
async def message_process(self, message_base: MessageBase) -> None:
await self.maibot_router.send_message(message_base)

View File

@ -14,13 +14,13 @@ from maim_message import (
MessageBase,
)
from .utils import get_image_format, convert_image_to_gif
class SendHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
async def handle_seg(self, raw_message_base_str: str) -> None:
logger.critical(raw_message_base_str)
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_str)
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
@ -127,9 +127,13 @@ class SendHandler:
def handle_emoji_message(self, encoded_emoji: str) -> dict:
"""处理表情消息"""
encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji)
if image_format != 'gif':
encoded_image = convert_image_to_gif(encoded_emoji)
return {
"type": "image",
"data": {"file": f"base64://{encoded_emoji}", "subtype": 1},
"data": {"file": f"base64://{encoded_image}", "subtype": 1},
}
async def test_send(self):
@ -145,7 +149,7 @@ class SendHandler:
else:
logger.warning(f"消息发送失败napcat返回{str(response)}")
async def send_message_to_napcat(self, action: str, params: dict) -> None:
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
payload = json.dumps({"action": action, "params": params})
await self.server_connection.send(payload)
response = await get_response()

View File

@ -2,6 +2,7 @@ import websockets.asyncio.server as Server
import json
import base64
from .logger import logger
from .message_queue import get_response
import requests
import ssl
@ -10,6 +11,7 @@ from requests.adapters import HTTPAdapter
from PIL import Image
import io
class SSLAdapter(HTTPAdapter):
def init_poolmanager(self, *args, **kwargs):
"""
@ -35,9 +37,9 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
"""
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}})
await websocket.send(payload)
socket_response = await websocket.recv()
socket_response: dict = await get_response()
logger.debug(socket_response)
return json.loads(socket_response).get("data")
return socket_response.get("data")
async def get_member_info(
@ -55,9 +57,9 @@ async def get_member_info(
}
)
await websocket.send(payload)
socket_response = await websocket.recv()
socket_response: dict = await get_response()
logger.debug(socket_response)
return json.loads(socket_response).get("data")
return socket_response.get("data")
async def get_image_base64(url: str) -> str:
@ -65,10 +67,7 @@ async def get_image_base64(url: str) -> str:
try:
sess = requests.session()
sess.mount("https://", SSLAdapter()) # 将上面定义的SSLAdapter 应用起来
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
}
response = sess.get(url, headers=headers, timeout=10, verify=True)
response = sess.get(url, timeout=10, verify=True)
response.raise_for_status()
image_bytes = response.content
return base64.b64encode(image_bytes).decode("utf-8")
@ -77,16 +76,30 @@ async def get_image_base64(url: str) -> str:
raise
def convert_image_to_gif(image_base64: str) -> str:
try:
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
output_buffer = io.BytesIO()
image.save(output_buffer, format="GIF")
output_buffer.seek(0)
return base64.b64encode(output_buffer.read()).decode("utf-8")
except Exception as e:
logger.error(f"图片转换为GIF失败: {str(e)}")
return image_base64
async def get_self_info(websocket: Server.ServerConnection) -> str:
"""
获取自身信息
"""
payload = json.dumps({"action": "get_login_info", "params": {}})
await websocket.send(payload)
response = await websocket.recv()
response: dict = await get_response()
logger.debug(response)
return json.loads(response).get("data")
return response.get("data")
async def get_image_format(raw_data: str) -> str:
def get_image_format(raw_data: str) -> str:
image_bytes = base64.b64decode(raw_data)
return Image.open(io.BytesIO(image_bytes)).format.lower()
return Image.open(io.BytesIO(image_bytes)).format.lower()

18
test.py 100644
View File

@ -0,0 +1,18 @@
import asyncio
import queue
message = queue.Queue()
async def test():
await asyncio.sleep(5)
message.put("123")
async def test2():
while message.empty():
await asyncio.sleep(0.5)
print("等回复")
print(message.get())
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(test(), test2()))