转发功能,重构接收,增加发送
parent
6ec55f781b
commit
3cb262be9a
|
|
@ -37,6 +37,7 @@ enable_temp = false
|
|||
- [x] 文本解析
|
||||
- [x] 图片解析
|
||||
- [x] 文本与消息混合解析
|
||||
- [x] 转发解析(含图片动态解析)
|
||||
- [ ] 链接解析
|
||||
- [ ] 戳一戳解析
|
||||
- [ ] 语音解析(?)
|
||||
|
|
|
|||
36
main.py
36
main.py
|
|
@ -1,39 +1,54 @@
|
|||
import asyncio
|
||||
import sys
|
||||
import json
|
||||
import websockets.asyncio.server as Server
|
||||
import websockets as Server
|
||||
from src.logger import logger
|
||||
from src.recv_handler import recv_handler
|
||||
from src.send_handler import send_handler
|
||||
from src.config import global_config
|
||||
from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
|
||||
from src.message_queue import recv_queue
|
||||
from src.message_queue import recv_queue, message_queue
|
||||
|
||||
|
||||
async def message_recv(server_connection: Server.ServerConnection):
|
||||
recv_handler.server_connection = server_connection
|
||||
send_handler.server_connection = server_connection
|
||||
# asyncio.create_task(send_handler.test_send())
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(raw_message)
|
||||
logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
if post_type == "meta_event":
|
||||
await recv_handler.handle_meta_event(decoded_raw_message)
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type == "message":
|
||||
await recv_handler.handle_raw_message(decoded_raw_message)
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type == "notice":
|
||||
pass
|
||||
elif post_type is None:
|
||||
recv_queue.put(decoded_raw_message)
|
||||
await recv_queue.put(decoded_raw_message)
|
||||
|
||||
|
||||
async def message_process():
|
||||
while True:
|
||||
message = await message_queue.get()
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await recv_handler.handle_raw_message(message)
|
||||
elif post_type == "meta_event":
|
||||
await recv_handler.handle_meta_event(message)
|
||||
elif post_type == "notice":
|
||||
pass
|
||||
else:
|
||||
logger.warning(f"未知的post_type: {post_type}")
|
||||
message_queue.task_done()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
async def main():
|
||||
recv_handler.maibot_router = router
|
||||
_ = await asyncio.gather(mmc_server(), mmc_start_com())
|
||||
_ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process())
|
||||
|
||||
|
||||
async def mmc_server():
|
||||
async def napcat_server():
|
||||
logger.info("正在启动adapter...")
|
||||
async with Server.serve(
|
||||
message_recv, global_config.server_host, global_config.server_port
|
||||
|
|
@ -50,7 +65,8 @@ async def graceful_shutdown():
|
|||
await mmc_stop_com()
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -13,4 +13,4 @@ def handle_output(message: str):
|
|||
logger.info(message)
|
||||
|
||||
|
||||
builtins.print = handle_output
|
||||
# builtins.print = handle_output
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import queue
|
||||
import asyncio
|
||||
|
||||
recv_queue = queue.Queue()
|
||||
recv_queue = asyncio.Queue()
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
async def get_response():
|
||||
while recv_queue.empty():
|
||||
await asyncio.sleep(0.5)
|
||||
return recv_queue.get()
|
||||
response = await recv_queue.get()
|
||||
recv_queue.task_done()
|
||||
return response
|
||||
|
|
@ -4,7 +4,7 @@ import time
|
|||
import asyncio
|
||||
import json
|
||||
import websockets.asyncio.server as Server
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
from . import MetaEventType, RealMessageType, MessageType
|
||||
from maim_message import (
|
||||
|
|
@ -19,6 +19,7 @@ from maim_message import (
|
|||
)
|
||||
|
||||
from .utils import get_group_info, get_member_info, get_image_base64, get_self_info
|
||||
from .message_queue import get_response
|
||||
|
||||
|
||||
class RecvHandler:
|
||||
|
|
@ -59,10 +60,8 @@ class RecvHandler:
|
|||
"""
|
||||
从Napcat接受的原始消息处理
|
||||
|
||||
参数:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
返回值:
|
||||
None
|
||||
"""
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
|
|
@ -70,7 +69,6 @@ class RecvHandler:
|
|||
|
||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||
format_info: FormatInfo = None # 格式化信息,暂时为空,等待启用
|
||||
|
||||
if message_type == MessageType.private:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Private.friend:
|
||||
|
|
@ -116,7 +114,7 @@ class RecvHandler:
|
|||
# -------------------这里需要群信息吗?-------------------
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info: dict = get_group_info(
|
||||
fetched_group_info: dict = await get_group_info(
|
||||
self.server_connection, raw_message.get("group_id")
|
||||
)
|
||||
group_name = ""
|
||||
|
|
@ -194,18 +192,16 @@ class RecvHandler:
|
|||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message"),
|
||||
)
|
||||
# 不启用发送消息
|
||||
await self.message_process(message_base)
|
||||
|
||||
logger.debug("我处理!")
|
||||
logger.info("发送到Maibot处理信息")
|
||||
await self.message_process(message_base)
|
||||
|
||||
async def handle_real_message(self, raw_message: dict) -> List[Seg]:
|
||||
"""
|
||||
处理实际消息
|
||||
|
||||
参数:
|
||||
Parameters:
|
||||
real_message: dict: 实际消息
|
||||
返回值:
|
||||
Returns:
|
||||
seg_message: list[Seg]: 处理后的消息段列表
|
||||
"""
|
||||
real_message: list = raw_message.get("message")
|
||||
|
|
@ -243,7 +239,7 @@ class RecvHandler:
|
|||
logger.warning("暂时不支持猜拳魔法表情解析")
|
||||
pass
|
||||
case RealMessageType.dice:
|
||||
logger.warning("暂时不支持筛子表情解析")
|
||||
logger.warning("暂时不支持骰子表情解析")
|
||||
pass
|
||||
case RealMessageType.shake:
|
||||
# 预计等价于戳一戳
|
||||
|
|
@ -267,9 +263,17 @@ class RecvHandler:
|
|||
}
|
||||
)
|
||||
await self.server_connection.send(payload)
|
||||
response = await self.server_connection.recv()
|
||||
logger.critical(response)
|
||||
logger.critical(json.loads(response))
|
||||
# response = await self.server_connection.recv()
|
||||
response: dict = await get_response()
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{json.dumps(response)[:80]}..."
|
||||
if len(json.dumps(response)) > 80
|
||||
else json.dumps(response)
|
||||
)
|
||||
messages = response.get("data").get("messages")
|
||||
ret_seg = await self.handle_forward_message(messages)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
case RealMessageType.node:
|
||||
logger.warning("不支持转发消息节点解析")
|
||||
pass
|
||||
|
|
@ -278,10 +282,9 @@ class RecvHandler:
|
|||
async def handle_text_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理纯文本信息
|
||||
|
||||
参数:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
返回值:
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
|
|
@ -300,10 +303,9 @@ class RecvHandler:
|
|||
async def handle_image_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理图片消息与表情包消息
|
||||
|
||||
参数:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
返回值:
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
|
|
@ -325,28 +327,204 @@ class RecvHandler:
|
|||
) -> Seg:
|
||||
"""
|
||||
处理at消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
self_id: int: 机器人QQ号
|
||||
group_id: int: 群号
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
if message_data:
|
||||
qq_id = message_data.get("qq")
|
||||
if str(self_id) == str(qq_id):
|
||||
self_info: dict = get_self_info()
|
||||
self_info: dict = await get_self_info(self.server_connection)
|
||||
if self_info:
|
||||
return Seg(type="text", data=f"@{self_info.get('nickname')} ")
|
||||
return Seg(
|
||||
type=RealMessageType.text, data=f"@{self_info.get('nickname')}"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
member_info: dict = get_member_info(
|
||||
member_info: dict = await get_member_info(
|
||||
self.server_connection, group_id=group_id, user_id=self_id
|
||||
)
|
||||
if member_info:
|
||||
return Seg(type="text", data=f"@{member_info.get('nickname')} ")
|
||||
return Seg(
|
||||
type=RealMessageType.text,
|
||||
data=f"@{member_info.get('nickname')}",
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def handle_poke_message(self) -> None:
|
||||
pass
|
||||
|
||||
async def handle_forward_message(self, message_list: list) -> Seg:
|
||||
"""
|
||||
递归处理转发消息,并按照动态方式确定图片处理方式
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表
|
||||
"""
|
||||
handled_message, image_count = await self._handle_forward_message(
|
||||
message_list, 0
|
||||
)
|
||||
handled_message: Seg
|
||||
image_count: int
|
||||
if not handled_message:
|
||||
return None
|
||||
if image_count < 5 and image_count > 0:
|
||||
# 处理图片数量小于5的情况,此时解析图片为base64
|
||||
parsed_handled_message = await self._recursive_parse_image_seg(
|
||||
handled_message, True
|
||||
)
|
||||
return parsed_handled_message
|
||||
elif image_count > 0:
|
||||
# 处理图片数量大于等于5的情况,此时解析图片为占位符
|
||||
parsed_handled_message = await self._recursive_parse_image_seg(
|
||||
handled_message, False
|
||||
)
|
||||
return parsed_handled_message
|
||||
else:
|
||||
# 处理没有图片的情况,此时直接返回
|
||||
return handled_message
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
|
||||
if to_image:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
image_url = seg_data.data
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
return Seg(type="image", data=encoded_image)
|
||||
elif seg_data.type == "emoji":
|
||||
image_url = seg_data.data
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
return Seg(type="emoji", data=encoded_image)
|
||||
else:
|
||||
return seg_data
|
||||
else:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
image_url = seg_data.data
|
||||
return Seg(type="text", data="【图片】")
|
||||
elif seg_data.type == "emoji":
|
||||
image_url = seg_data.data
|
||||
return Seg(type="text", data="【动画表情】")
|
||||
else:
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(
|
||||
self, message_list: list, layer: int
|
||||
) -> Tuple[Seg, int]:
|
||||
"""
|
||||
递归处理实际转发消息
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段
|
||||
layer: int: 当前层级
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
image_count: int: 图片数量
|
||||
"""
|
||||
seg_list = []
|
||||
image_count = 0
|
||||
if message_list is None or len(message_list) == 0:
|
||||
return None, 0
|
||||
for sub_message in message_list:
|
||||
sub_message: dict
|
||||
sender_info: dict = sub_message.get("sender")
|
||||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg = Seg(type="text", data="\n")
|
||||
message_of_sub_message: dict = sub_message.get("message")[0]
|
||||
if message_of_sub_message.get("type") == RealMessageType.forward:
|
||||
if layer >= 3:
|
||||
full_seg_data = (
|
||||
Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
|
||||
),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
contents = message_of_sub_message.get("data").get("content")
|
||||
seg_data, count = await self._handle_forward_message(
|
||||
contents, layer + 1
|
||||
)
|
||||
image_count += count
|
||||
head_tip = Seg(
|
||||
type="text",
|
||||
data=("--" * layer)
|
||||
+ f"【{user_nickname}】: 合并转发消息内容:\n",
|
||||
)
|
||||
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
|
||||
seg_list.append(full_seg_data)
|
||||
elif message_of_sub_message.get("type") == RealMessageType.text:
|
||||
text_message = message_of_sub_message.get("data").get("text")
|
||||
seg_data = Seg(type="text", data=text_message)
|
||||
if layer > 0:
|
||||
seg_list.append(
|
||||
Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(
|
||||
type="text", data=("--" * layer) + user_nickname_str
|
||||
),
|
||||
seg_data,
|
||||
break_seg,
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
seg_list.append(
|
||||
Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
],
|
||||
)
|
||||
)
|
||||
elif message_of_sub_message.get("type") == RealMessageType.image:
|
||||
image_count += 1
|
||||
image_data = message_of_sub_message.get("data")
|
||||
sub_type = image_data.get("sub_type")
|
||||
image_url = image_data.get("url")
|
||||
if sub_type == 0:
|
||||
seg_data = Seg(type="image", data=image_url)
|
||||
else:
|
||||
seg_data = Seg(type="emoji", data=image_url)
|
||||
if layer > 0:
|
||||
full_seg_data = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
],
|
||||
)
|
||||
else:
|
||||
full_seg_data = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
],
|
||||
)
|
||||
seg_list.append(full_seg_data)
|
||||
return Seg(type="seglist", data=seg_list), image_count
|
||||
|
||||
async def message_process(self, message_base: MessageBase) -> None:
|
||||
await self.maibot_router.send_message(message_base)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from maim_message import (
|
|||
MessageBase,
|
||||
)
|
||||
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
|
||||
class SendHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
|
||||
async def handle_seg(self, raw_message_base_str: str) -> None:
|
||||
logger.critical(raw_message_base_str)
|
||||
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_str)
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
|
|
@ -127,9 +127,13 @@ class SendHandler:
|
|||
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
if image_format != 'gif':
|
||||
encoded_image = convert_image_to_gif(encoded_emoji)
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {"file": f"base64://{encoded_emoji}", "subtype": 1},
|
||||
"data": {"file": f"base64://{encoded_image}", "subtype": 1},
|
||||
}
|
||||
|
||||
async def test_send(self):
|
||||
|
|
@ -145,7 +149,7 @@ class SendHandler:
|
|||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict) -> None:
|
||||
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
|
||||
payload = json.dumps({"action": action, "params": params})
|
||||
await self.server_connection.send(payload)
|
||||
response = await get_response()
|
||||
|
|
|
|||
37
src/utils.py
37
src/utils.py
|
|
@ -2,6 +2,7 @@ import websockets.asyncio.server as Server
|
|||
import json
|
||||
import base64
|
||||
from .logger import logger
|
||||
from .message_queue import get_response
|
||||
|
||||
import requests
|
||||
import ssl
|
||||
|
|
@ -10,6 +11,7 @@ from requests.adapters import HTTPAdapter
|
|||
from PIL import Image
|
||||
import io
|
||||
|
||||
|
||||
class SSLAdapter(HTTPAdapter):
|
||||
def init_poolmanager(self, *args, **kwargs):
|
||||
"""
|
||||
|
|
@ -35,9 +37,9 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
|
|||
"""
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}})
|
||||
await websocket.send(payload)
|
||||
socket_response = await websocket.recv()
|
||||
socket_response: dict = await get_response()
|
||||
logger.debug(socket_response)
|
||||
return json.loads(socket_response).get("data")
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_member_info(
|
||||
|
|
@ -55,9 +57,9 @@ async def get_member_info(
|
|||
}
|
||||
)
|
||||
await websocket.send(payload)
|
||||
socket_response = await websocket.recv()
|
||||
socket_response: dict = await get_response()
|
||||
logger.debug(socket_response)
|
||||
return json.loads(socket_response).get("data")
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
|
|
@ -65,10 +67,7 @@ async def get_image_base64(url: str) -> str:
|
|||
try:
|
||||
sess = requests.session()
|
||||
sess.mount("https://", SSLAdapter()) # 将上面定义的SSLAdapter 应用起来
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
|
||||
}
|
||||
response = sess.get(url, headers=headers, timeout=10, verify=True)
|
||||
response = sess.get(url, timeout=10, verify=True)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
|
@ -77,16 +76,30 @@ async def get_image_base64(url: str) -> str:
|
|||
raise
|
||||
|
||||
|
||||
def convert_image_to_gif(image_base64: str) -> str:
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="GIF")
|
||||
output_buffer.seek(0)
|
||||
return base64.b64encode(output_buffer.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为GIF失败: {str(e)}")
|
||||
return image_base64
|
||||
|
||||
|
||||
async def get_self_info(websocket: Server.ServerConnection) -> str:
|
||||
"""
|
||||
获取自身信息
|
||||
"""
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}})
|
||||
await websocket.send(payload)
|
||||
response = await websocket.recv()
|
||||
response: dict = await get_response()
|
||||
logger.debug(response)
|
||||
return json.loads(response).get("data")
|
||||
return response.get("data")
|
||||
|
||||
async def get_image_format(raw_data: str) -> str:
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
image_bytes = base64.b64decode(raw_data)
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
import asyncio
|
||||
import queue
|
||||
|
||||
message = queue.Queue()
|
||||
|
||||
async def test():
|
||||
await asyncio.sleep(5)
|
||||
message.put("123")
|
||||
|
||||
async def test2():
|
||||
while message.empty():
|
||||
await asyncio.sleep(0.5)
|
||||
print("等回复")
|
||||
print(message.get())
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(asyncio.gather(test(), test2()))
|
||||
Loading…
Reference in New Issue