mirror of https://github.com/Mai-with-u/MaiBot.git
从maim_message的序列化和反序列化;更多消息组件
parent
c0c003a098
commit
daad0ba2f0
|
|
@ -1,5 +1,11 @@
|
|||
from dataclasses import dataclass, field
|
||||
from maim_message import MessageBase
|
||||
from maim_message import (
|
||||
MessageBase,
|
||||
UserInfo as MaimUserInfo,
|
||||
GroupInfo as MaimGroupInfo,
|
||||
BaseMessageInfo as MaimBaseMessageInfo,
|
||||
Seg,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
import json
|
||||
|
|
@ -112,10 +118,59 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
|||
|
||||
@classmethod
|
||||
def from_maim_message(cls, message: MessageBase) -> "MaiMessage":
|
||||
raise NotImplementedError
|
||||
"""从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息"""
|
||||
msg_info = message.message_info
|
||||
assert msg_info, "MessageBase 的 message_info 不能为空"
|
||||
msg_id = msg_info.message_id
|
||||
timestamp = msg_info.time
|
||||
assert isinstance(msg_id, str)
|
||||
assert msg_id
|
||||
assert timestamp
|
||||
obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp))
|
||||
obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message)
|
||||
usr_info = msg_info.user_info
|
||||
assert usr_info
|
||||
assert isinstance(usr_info.user_id, str)
|
||||
assert isinstance(usr_info.user_nickname, str)
|
||||
user_info = UserInfo(
|
||||
user_id=usr_info.user_id,
|
||||
user_nickname=usr_info.user_nickname,
|
||||
user_cardname=usr_info.user_cardname,
|
||||
)
|
||||
if grp_info := msg_info.group_info:
|
||||
assert isinstance(grp_info.group_id, str)
|
||||
assert isinstance(grp_info.group_name, str)
|
||||
group_info = GroupInfo(group_id=grp_info.group_id, group_name=grp_info.group_name)
|
||||
else:
|
||||
group_info = None
|
||||
add_cfg = msg_info.additional_config or {}
|
||||
obj.message_info = MessageInfo(user_info=user_info, group_info=group_info, additional_config=add_cfg)
|
||||
return obj
|
||||
|
||||
def to_maim_message(self) -> MessageBase:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_message_segments(self):
|
||||
raise NotImplementedError
|
||||
async def to_maim_message(self) -> MessageBase:
|
||||
"""
|
||||
从 MaiMessage 实例转换为 maim_message.MessageBase,构建消息内容并设置相关信息
|
||||
"""
|
||||
maim_user_info = MaimUserInfo(
|
||||
user_id=self.message_info.user_info.user_id,
|
||||
user_nickname=self.message_info.user_info.user_nickname,
|
||||
user_cardname=self.message_info.user_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
maim_group_info = None
|
||||
if self.message_info.group_info:
|
||||
maim_group_info = MaimGroupInfo(
|
||||
group_id=self.message_info.group_info.group_id,
|
||||
group_name=self.message_info.group_info.group_name,
|
||||
platform=self.platform,
|
||||
)
|
||||
maim_msg_info = MaimBaseMessageInfo(
|
||||
platform=self.platform,
|
||||
message_id=self.message_id,
|
||||
time=self.timestamp.timestamp(),
|
||||
group_info=maim_group_info,
|
||||
user_info=maim_user_info,
|
||||
additional_config=self.message_info.additional_config,
|
||||
)
|
||||
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
|
||||
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class ByteComponent:
|
|||
|
||||
|
||||
class TextComponent(BaseMessageComponentModel):
|
||||
"""文本组件,包含一个文本消息的内容"""
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型"
|
||||
|
|
@ -42,9 +43,10 @@ class TextComponent(BaseMessageComponentModel):
|
|||
|
||||
|
||||
class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
||||
"""图片组件,包含一个图片消息的二进制数据和一个唯一标识该图片消息的 hash 值"""
|
||||
async def load_image_binary(self):
|
||||
if not self.binary_data:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
async def to_seg(self) -> Seg:
|
||||
if not self.binary_data:
|
||||
|
|
@ -53,6 +55,7 @@ class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
|||
|
||||
|
||||
class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
||||
"""表情组件,包含一个表情消息的二进制数据和一个唯一标识该表情消息的 hash 值"""
|
||||
async def load_emoji_binary(self) -> None:
|
||||
"""
|
||||
加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载
|
||||
|
|
@ -81,6 +84,7 @@ class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
|||
|
||||
|
||||
class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
||||
"""语音组件,包含一个语音消息的二进制数据和一个唯一标识该语音消息的 hash 值"""
|
||||
async def load_voice_binary(self) -> None:
|
||||
if not self.binary_data:
|
||||
from src.common.utils.utils_file import FileUtils
|
||||
|
|
@ -97,9 +101,33 @@ class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
|||
return Seg(type="voice", data=base64.b64encode(self.binary_data).decode())
|
||||
|
||||
|
||||
class AtComponent(BaseMessageComponentModel):
|
||||
"""@组件,包含一个被@的用户的ID,用于表示该组件是一个@某人的消息片段"""
|
||||
def __init__(self, target_user_id: str) -> None:
|
||||
self.target_user_id = target_user_id
|
||||
"""目标用户ID"""
|
||||
assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型"
|
||||
|
||||
async def to_seg(self) -> Seg:
|
||||
return Seg(type="at", data=self.target_user_id)
|
||||
|
||||
|
||||
class ReplyComponent(BaseMessageComponentModel):
|
||||
"""回复组件,包含一个回复消息的 ID,用于表示该组件是对哪条消息的回复"""
|
||||
def __init__(self, target_message_id: str) -> None:
|
||||
assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型"
|
||||
self.target_message_id = target_message_id
|
||||
"""目标消息ID"""
|
||||
|
||||
async def to_seg(self) -> Seg:
|
||||
return Seg(type="reply", data=self.target_message_id)
|
||||
|
||||
|
||||
class ForwardNodeComponent(BaseMessageComponentModel):
|
||||
"""转发节点消息组件,包含一个转发节点的消息,所有组件按照消息顺序排列"""
|
||||
def __init__(self, forward_components: List["ForwardComponent"]):
|
||||
self.forward_components = forward_components
|
||||
"""节点的消息组件列表,按照消息顺序排列"""
|
||||
assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型"
|
||||
assert all(isinstance(comp, ForwardComponent) for comp in forward_components), (
|
||||
"ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型"
|
||||
|
|
@ -128,12 +156,15 @@ StandardMessageComponents = Union[
|
|||
ImageComponent,
|
||||
EmojiComponent,
|
||||
VoiceComponent,
|
||||
AtComponent,
|
||||
ReplyComponent,
|
||||
ForwardNodeComponent,
|
||||
DictComponent,
|
||||
]
|
||||
|
||||
|
||||
class ForwardComponent(BaseMessageComponentModel):
|
||||
"""转发组件,包含一个转发消息中的一个节点的信息,包括发送者信息和该节点的消息内容"""
|
||||
def __init__(
|
||||
self,
|
||||
user_nickname: str,
|
||||
|
|
@ -142,9 +173,13 @@ class ForwardComponent(BaseMessageComponentModel):
|
|||
user_cardname: Optional[str] = None,
|
||||
):
|
||||
self.user_nickname: str = user_nickname
|
||||
"""转发节点的发送者昵称"""
|
||||
self.content: List[StandardMessageComponents] = content
|
||||
"""消息内容"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
"""转发节点的发送者ID,可能为 None"""
|
||||
self.user_cardname: Optional[str] = user_cardname
|
||||
"""转发节点的发送者群名片,可能为 None"""
|
||||
assert self.content, "ForwardComponent 的 content 不能为空"
|
||||
|
||||
async def to_seg(self) -> "Seg":
|
||||
|
|
@ -154,13 +189,33 @@ class ForwardComponent(BaseMessageComponentModel):
|
|||
|
||||
|
||||
class MessageSequence:
|
||||
"""消息组件序列,包含一个消息中的所有组件,按照顺序排列"""
|
||||
|
||||
def __init__(self, components: List[StandardMessageComponents]):
|
||||
"""
|
||||
创建一个消息组件序列
|
||||
|
||||
**消息组件序列不会对组件进行去重或校验。**
|
||||
|
||||
因此同一消息中可以包含多个相同的组件(例如多个文本组件、多个图片组件等)。
|
||||
因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。
|
||||
如果需要对组件进行去重或校验,还请在使用时自行处理。
|
||||
"""
|
||||
self.components: List[StandardMessageComponents] = components
|
||||
|
||||
def to_dict(self) -> List[Dict[str, Any]]:
|
||||
"""将消息序列转换为字典列表格式,便于存储或传输"""
|
||||
return [self._item_2_dict(comp) for comp in self.components]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
|
||||
"""从字典列表格式创建消息序列实例"""
|
||||
components: List[StandardMessageComponents] = []
|
||||
components.extend(cls._dict_2_item(item) for item in data)
|
||||
return cls(components=components)
|
||||
|
||||
def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]:
|
||||
"""内部方法:将单个消息组件转换为字典格式"""
|
||||
if isinstance(item, TextComponent):
|
||||
return {"type": "text", "data": item.text}
|
||||
elif isinstance(item, ImageComponent):
|
||||
|
|
@ -175,6 +230,10 @@ class MessageSequence:
|
|||
if not item.content:
|
||||
raise RuntimeError("VoiceComponent content 未初始化")
|
||||
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
||||
elif isinstance(item, AtComponent):
|
||||
return {"type": "at", "data": item.target_user_id}
|
||||
elif isinstance(item, ReplyComponent):
|
||||
return {"type": "reply", "data": item.target_message_id}
|
||||
elif isinstance(item, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
|
|
@ -192,14 +251,9 @@ class MessageSequence:
|
|||
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
||||
return {"type": "dict", "data": item.data}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
|
||||
components: List[StandardMessageComponents] = []
|
||||
components.extend(cls._dict_2_item(item) for item in data)
|
||||
return cls(components=components)
|
||||
|
||||
@classmethod
|
||||
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""内部方法:将单个消息组件的字典格式转换回组件对象"""
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
return TextComponent(text=item["data"])
|
||||
|
|
@ -209,6 +263,10 @@ class MessageSequence:
|
|||
return EmojiComponent(binary_hash=item["hash"], content=item["data"])
|
||||
elif item_type == "voice":
|
||||
return VoiceComponent(binary_hash=item["hash"], content=item["data"])
|
||||
elif item_type == "at":
|
||||
return AtComponent(target_user_id=item["data"])
|
||||
elif item_type == "reply":
|
||||
return ReplyComponent(target_message_id=item["data"])
|
||||
elif item_type == "forward":
|
||||
forward_components = []
|
||||
for fc in item["data"]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,21 @@
|
|||
from maim_message import MessageBase, Seg
|
||||
from typing import List
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import msgpack
|
||||
|
||||
from src.common.data_models.message_component_model import MessageSequence
|
||||
from src.common.data_models.message_component_model import (
|
||||
MessageSequence,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
EmojiComponent,
|
||||
VoiceComponent,
|
||||
AtComponent,
|
||||
ReplyComponent,
|
||||
DictComponent,
|
||||
)
|
||||
|
||||
|
||||
class MessageUtils:
|
||||
|
|
@ -13,3 +28,60 @@ class MessageUtils:
|
|||
def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes:
|
||||
dict_representation = msg.to_dict()
|
||||
return msgpack.packb(dict_representation) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence:
|
||||
"""从maim_message.MessageBase.message_segment转换为MessageSequence"""
|
||||
raw_msg_seq = message.message_segment
|
||||
components: List[StandardMessageComponents] = []
|
||||
if not raw_msg_seq:
|
||||
return MessageSequence(components)
|
||||
if raw_msg_seq.type == "seglist":
|
||||
assert isinstance(raw_msg_seq.data, list), "seglist类型的message_segment数据应该是一个列表"
|
||||
components.extend(MessageUtils._parse_maim_message_segment_to_component(item) for item in raw_msg_seq.data)
|
||||
elif raw_msg_seq.type in {"text", "image", "emoji", "voice", "at", "reply"}:
|
||||
components.append(MessageUtils._parse_maim_message_segment_to_component(raw_msg_seq))
|
||||
else:
|
||||
raise NotImplementedError(f"暂时不支持的消息片段类型: {raw_msg_seq.type}")
|
||||
return MessageSequence(components)
|
||||
|
||||
@staticmethod
|
||||
async def from_MaiSeq_to_maim_message_segments(msg_seq: MessageSequence) -> List[Seg]:
|
||||
"""从MessageSequence转换为maim_message.MessageBase.message_segment格式的列表"""
|
||||
segments = []
|
||||
for component in msg_seq.components:
|
||||
if isinstance(component, DictComponent):
|
||||
seg = Seg(type="dict", data=component.data) # type: ignore
|
||||
else:
|
||||
seg = await component.to_seg()
|
||||
segments.append(seg)
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def _parse_maim_message_segment_to_component(seg: Seg) -> "StandardMessageComponents":
|
||||
if seg.type == "text":
|
||||
assert isinstance(seg.data, str), "text类型的seg数据应该是字符串"
|
||||
return TextComponent(text=seg.data)
|
||||
elif seg.type == "image":
|
||||
assert isinstance(seg.data, str), "image类型的seg数据应该是base64字符串"
|
||||
image_bytes = base64.b64decode(seg.data)
|
||||
binary_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
return ImageComponent(binary_hash=binary_hash, binary_data=image_bytes)
|
||||
elif seg.type == "emoji":
|
||||
assert isinstance(seg.data, str), "emoji类型的seg数据应该是base64字符串"
|
||||
emoji_bytes = base64.b64decode(seg.data)
|
||||
binary_hash = hashlib.md5(emoji_bytes).hexdigest()
|
||||
return EmojiComponent(binary_hash=binary_hash, binary_data=emoji_bytes)
|
||||
elif seg.type == "voice":
|
||||
assert isinstance(seg.data, str), "voice类型的seg数据应该是base64字符串"
|
||||
voice_bytes = base64.b64decode(seg.data)
|
||||
binary_hash = hashlib.md5(voice_bytes).hexdigest()
|
||||
return VoiceComponent(binary_hash=binary_hash, binary_data=voice_bytes)
|
||||
elif seg.type == "at":
|
||||
assert isinstance(seg.data, str), "at类型的seg数据应该是字符串"
|
||||
return AtComponent(target_user_id=seg.data)
|
||||
elif seg.type == "reply":
|
||||
assert isinstance(seg.data, str), "reply类型的seg数据应该是字符串"
|
||||
return ReplyComponent(target_message_id=seg.data)
|
||||
else:
|
||||
raise NotImplementedError(f"暂时不支持的消息片段类型: {seg.type}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue