mirror of https://github.com/Mai-with-u/MaiBot.git
从maim_message的序列化和反序列化;更多消息组件
parent
c0c003a098
commit
daad0ba2f0
|
|
@ -1,5 +1,11 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from maim_message import MessageBase
|
from maim_message import (
|
||||||
|
MessageBase,
|
||||||
|
UserInfo as MaimUserInfo,
|
||||||
|
GroupInfo as MaimGroupInfo,
|
||||||
|
BaseMessageInfo as MaimBaseMessageInfo,
|
||||||
|
Seg,
|
||||||
|
)
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -112,10 +118,59 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||||
|
|
||||||
@classmethod
def from_maim_message(cls, message: MessageBase) -> "MaiMessage":
    """Build a MaiMessage from a ``maim_message.MessageBase``.

    The message id and timestamp are taken from ``message.message_info``;
    the timestamp is converted with ``datetime.fromtimestamp``. The message
    segments are converted to a ``MessageSequence`` and stored in
    ``raw_message``; sender, optional group and additional config are
    stored in ``message_info``.

    Args:
        message: Incoming message to convert.

    Returns:
        The populated ``MaiMessage`` instance.

    Raises:
        AssertionError: If ``message_info`` or its ``user_info`` is missing,
            if ``message_id`` is not a non-empty string, if ``time`` is
            falsy, or if the user/group identifiers and names are not
            strings.
    """
    info = message.message_info
    assert info, "MessageBase 的 message_info 不能为空"

    message_id, sent_at = info.message_id, info.time
    assert isinstance(message_id, str)
    assert message_id
    assert sent_at

    # NOTE(review): info.platform is not copied onto the instance here,
    # although to_maim_message() reads self.platform - confirm the platform
    # is populated elsewhere.
    instance = cls(message_id=message_id, timestamp=datetime.fromtimestamp(sent_at))
    instance.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message)

    sender = info.user_info
    assert sender
    assert isinstance(sender.user_id, str)
    assert isinstance(sender.user_nickname, str)
    user_info = UserInfo(
        user_id=sender.user_id,
        user_nickname=sender.user_nickname,
        user_cardname=sender.user_cardname,
    )

    # Group information is optional; it stays None when no group is attached.
    group_info = None
    group = info.group_info
    if group:
        assert isinstance(group.group_id, str)
        assert isinstance(group.group_name, str)
        group_info = GroupInfo(group_id=group.group_id, group_name=group.group_name)

    instance.message_info = MessageInfo(
        user_info=user_info,
        group_info=group_info,
        additional_config=info.additional_config or {},
    )
    return instance
|
||||||
|
|
||||||
async def to_maim_message(self) -> MessageBase:
    """Convert this MaiMessage into a ``maim_message.MessageBase``.

    Sender, optional group and base message info are rebuilt from
    ``self.message_info``, ``self.platform``, ``self.message_id`` and
    ``self.timestamp``. The components in ``self.raw_message`` are converted
    to segments and wrapped in a single ``Seg`` of type ``"seglist"``.

    Returns:
        The assembled ``MessageBase``.
    """
    info = self.message_info
    sender = info.user_info
    maim_user_info = MaimUserInfo(
        user_id=sender.user_id,
        user_nickname=sender.user_nickname,
        user_cardname=sender.user_cardname,
        platform=self.platform,
    )

    # Group info is only emitted when the message belongs to a group.
    group = info.group_info
    maim_group_info = (
        MaimGroupInfo(
            group_id=group.group_id,
            group_name=group.group_name,
            platform=self.platform,
        )
        if group
        else None
    )

    maim_msg_info = MaimBaseMessageInfo(
        platform=self.platform,
        message_id=self.message_id,
        time=self.timestamp.timestamp(),
        group_info=maim_group_info,
        user_info=maim_user_info,
        additional_config=info.additional_config,
    )

    segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
    wrapper = Seg(type="seglist", data=segments)
    return MessageBase(message_info=maim_msg_info, message_segment=wrapper)
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ class ByteComponent:
|
||||||
|
|
||||||
|
|
||||||
class TextComponent(BaseMessageComponentModel):
|
class TextComponent(BaseMessageComponentModel):
|
||||||
|
"""文本组件,包含一个文本消息的内容"""
|
||||||
def __init__(self, text: str):
|
def __init__(self, text: str):
|
||||||
self.text = text
|
self.text = text
|
||||||
assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型"
|
assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型"
|
||||||
|
|
@ -42,9 +43,10 @@ class TextComponent(BaseMessageComponentModel):
|
||||||
|
|
||||||
|
|
||||||
class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
"""图片组件,包含一个图片消息的二进制数据和一个唯一标识该图片消息的 hash 值"""
|
||||||
async def load_image_binary(self):
|
async def load_image_binary(self):
|
||||||
if not self.binary_data:
|
if not self.binary_data:
|
||||||
...
|
raise NotImplementedError
|
||||||
|
|
||||||
async def to_seg(self) -> Seg:
|
async def to_seg(self) -> Seg:
|
||||||
if not self.binary_data:
|
if not self.binary_data:
|
||||||
|
|
@ -53,6 +55,7 @@ class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
|
||||||
|
|
||||||
class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
"""表情组件,包含一个表情消息的二进制数据和一个唯一标识该表情消息的 hash 值"""
|
||||||
async def load_emoji_binary(self) -> None:
|
async def load_emoji_binary(self) -> None:
|
||||||
"""
|
"""
|
||||||
加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载
|
加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载
|
||||||
|
|
@ -81,6 +84,7 @@ class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
|
||||||
|
|
||||||
class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
"""语音组件,包含一个语音消息的二进制数据和一个唯一标识该语音消息的 hash 值"""
|
||||||
async def load_voice_binary(self) -> None:
|
async def load_voice_binary(self) -> None:
|
||||||
if not self.binary_data:
|
if not self.binary_data:
|
||||||
from src.common.utils.utils_file import FileUtils
|
from src.common.utils.utils_file import FileUtils
|
||||||
|
|
@ -97,9 +101,33 @@ class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
return Seg(type="voice", data=base64.b64encode(self.binary_data).decode())
|
return Seg(type="voice", data=base64.b64encode(self.binary_data).decode())
|
||||||
|
|
||||||
|
|
||||||
|
class AtComponent(BaseMessageComponentModel):
    """Mention ("@") component holding the ID of the user being mentioned."""

    def __init__(self, target_user_id: str) -> None:
        """Store the mentioned user's ID.

        Raises:
            AssertionError: If ``target_user_id`` is not a string.
        """
        assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型"
        # ID of the user targeted by the mention.
        self.target_user_id = target_user_id

    async def to_seg(self) -> Seg:
        """Return an ``"at"`` segment whose data is the target user ID."""
        user_id = self.target_user_id
        return Seg(type="at", data=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplyComponent(BaseMessageComponentModel):
    """Reply component holding the ID of the message being replied to."""

    def __init__(self, target_message_id: str) -> None:
        """Store the ID of the replied-to message.

        Raises:
            AssertionError: If ``target_message_id`` is not a string.
        """
        is_text = isinstance(target_message_id, str)
        assert is_text, "ReplyComponent 的 target_message_id 必须是字符串类型"
        # ID of the message this component replies to.
        self.target_message_id = target_message_id

    async def to_seg(self) -> Seg:
        """Return a ``"reply"`` segment whose data is the target message ID."""
        message_id = self.target_message_id
        return Seg(type="reply", data=message_id)
|
||||||
|
|
||||||
|
|
||||||
class ForwardNodeComponent(BaseMessageComponentModel):
|
class ForwardNodeComponent(BaseMessageComponentModel):
|
||||||
|
"""转发节点消息组件,包含一个转发节点的消息,所有组件按照消息顺序排列"""
|
||||||
def __init__(self, forward_components: List["ForwardComponent"]):
|
def __init__(self, forward_components: List["ForwardComponent"]):
|
||||||
self.forward_components = forward_components
|
self.forward_components = forward_components
|
||||||
|
"""节点的消息组件列表,按照消息顺序排列"""
|
||||||
assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型"
|
assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型"
|
||||||
assert all(isinstance(comp, ForwardComponent) for comp in forward_components), (
|
assert all(isinstance(comp, ForwardComponent) for comp in forward_components), (
|
||||||
"ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型"
|
"ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型"
|
||||||
|
|
@ -128,12 +156,15 @@ StandardMessageComponents = Union[
|
||||||
ImageComponent,
|
ImageComponent,
|
||||||
EmojiComponent,
|
EmojiComponent,
|
||||||
VoiceComponent,
|
VoiceComponent,
|
||||||
|
AtComponent,
|
||||||
|
ReplyComponent,
|
||||||
ForwardNodeComponent,
|
ForwardNodeComponent,
|
||||||
DictComponent,
|
DictComponent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ForwardComponent(BaseMessageComponentModel):
|
class ForwardComponent(BaseMessageComponentModel):
|
||||||
|
"""转发组件,包含一个转发消息中的一个节点的信息,包括发送者信息和该节点的消息内容"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_nickname: str,
|
user_nickname: str,
|
||||||
|
|
@ -142,9 +173,13 @@ class ForwardComponent(BaseMessageComponentModel):
|
||||||
user_cardname: Optional[str] = None,
|
user_cardname: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.user_nickname: str = user_nickname
|
self.user_nickname: str = user_nickname
|
||||||
|
"""转发节点的发送者昵称"""
|
||||||
self.content: List[StandardMessageComponents] = content
|
self.content: List[StandardMessageComponents] = content
|
||||||
|
"""消息内容"""
|
||||||
self.user_id: Optional[str] = user_id
|
self.user_id: Optional[str] = user_id
|
||||||
|
"""转发节点的发送者ID,可能为 None"""
|
||||||
self.user_cardname: Optional[str] = user_cardname
|
self.user_cardname: Optional[str] = user_cardname
|
||||||
|
"""转发节点的发送者群名片,可能为 None"""
|
||||||
assert self.content, "ForwardComponent 的 content 不能为空"
|
assert self.content, "ForwardComponent 的 content 不能为空"
|
||||||
|
|
||||||
async def to_seg(self) -> "Seg":
|
async def to_seg(self) -> "Seg":
|
||||||
|
|
@ -154,13 +189,33 @@ class ForwardComponent(BaseMessageComponentModel):
|
||||||
|
|
||||||
|
|
||||||
class MessageSequence:
|
class MessageSequence:
|
||||||
|
"""消息组件序列,包含一个消息中的所有组件,按照顺序排列"""
|
||||||
|
|
||||||
def __init__(self, components: List[StandardMessageComponents]):
|
def __init__(self, components: List[StandardMessageComponents]):
|
||||||
|
"""
|
||||||
|
创建一个消息组件序列
|
||||||
|
|
||||||
|
**消息组件序列不会对组件进行去重或校验。**
|
||||||
|
|
||||||
|
因此同一消息中可以包含多个相同的组件(例如多个文本组件、多个图片组件等)。
|
||||||
|
因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。
|
||||||
|
如果需要对组件进行去重或校验,还请在使用时自行处理。
|
||||||
|
"""
|
||||||
self.components: List[StandardMessageComponents] = components
|
self.components: List[StandardMessageComponents] = components
|
||||||
|
|
||||||
def to_dict(self) -> List[Dict[str, Any]]:
|
def to_dict(self) -> List[Dict[str, Any]]:
|
||||||
|
"""将消息序列转换为字典列表格式,便于存储或传输"""
|
||||||
return [self._item_2_dict(comp) for comp in self.components]
|
return [self._item_2_dict(comp) for comp in self.components]
|
||||||
|
|
||||||
|
@classmethod
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
    """Create a message sequence from its list-of-dicts representation.

    Each entry is converted with ``_dict_2_item``; the order of ``data`` is
    preserved.
    """
    restored: List[StandardMessageComponents] = [cls._dict_2_item(entry) for entry in data]
    return cls(components=restored)
|
||||||
|
|
||||||
def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]:
|
def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]:
|
||||||
|
"""内部方法:将单个消息组件转换为字典格式"""
|
||||||
if isinstance(item, TextComponent):
|
if isinstance(item, TextComponent):
|
||||||
return {"type": "text", "data": item.text}
|
return {"type": "text", "data": item.text}
|
||||||
elif isinstance(item, ImageComponent):
|
elif isinstance(item, ImageComponent):
|
||||||
|
|
@ -175,6 +230,10 @@ class MessageSequence:
|
||||||
if not item.content:
|
if not item.content:
|
||||||
raise RuntimeError("VoiceComponent content 未初始化")
|
raise RuntimeError("VoiceComponent content 未初始化")
|
||||||
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
||||||
|
elif isinstance(item, AtComponent):
|
||||||
|
return {"type": "at", "data": item.target_user_id}
|
||||||
|
elif isinstance(item, ReplyComponent):
|
||||||
|
return {"type": "reply", "data": item.target_message_id}
|
||||||
elif isinstance(item, ForwardNodeComponent):
|
elif isinstance(item, ForwardNodeComponent):
|
||||||
return {
|
return {
|
||||||
"type": "forward",
|
"type": "forward",
|
||||||
|
|
@ -192,14 +251,9 @@ class MessageSequence:
|
||||||
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
||||||
return {"type": "dict", "data": item.data}
|
return {"type": "dict", "data": item.data}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
|
|
||||||
components: List[StandardMessageComponents] = []
|
|
||||||
components.extend(cls._dict_2_item(item) for item in data)
|
|
||||||
return cls(components=components)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||||
|
"""内部方法:将单个消息组件的字典格式转换回组件对象"""
|
||||||
item_type = item.get("type")
|
item_type = item.get("type")
|
||||||
if item_type == "text":
|
if item_type == "text":
|
||||||
return TextComponent(text=item["data"])
|
return TextComponent(text=item["data"])
|
||||||
|
|
@ -209,6 +263,10 @@ class MessageSequence:
|
||||||
return EmojiComponent(binary_hash=item["hash"], content=item["data"])
|
return EmojiComponent(binary_hash=item["hash"], content=item["data"])
|
||||||
elif item_type == "voice":
|
elif item_type == "voice":
|
||||||
return VoiceComponent(binary_hash=item["hash"], content=item["data"])
|
return VoiceComponent(binary_hash=item["hash"], content=item["data"])
|
||||||
|
elif item_type == "at":
|
||||||
|
return AtComponent(target_user_id=item["data"])
|
||||||
|
elif item_type == "reply":
|
||||||
|
return ReplyComponent(target_message_id=item["data"])
|
||||||
elif item_type == "forward":
|
elif item_type == "forward":
|
||||||
forward_components = []
|
forward_components = []
|
||||||
for fc in item["data"]:
|
for fc in item["data"]:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,21 @@
|
||||||
|
from maim_message import MessageBase, Seg
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
import msgpack
|
import msgpack
|
||||||
|
|
||||||
from src.common.data_models.message_component_model import MessageSequence
|
from src.common.data_models.message_component_model import (
|
||||||
|
MessageSequence,
|
||||||
|
StandardMessageComponents,
|
||||||
|
TextComponent,
|
||||||
|
ImageComponent,
|
||||||
|
EmojiComponent,
|
||||||
|
VoiceComponent,
|
||||||
|
AtComponent,
|
||||||
|
ReplyComponent,
|
||||||
|
DictComponent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MessageUtils:
|
class MessageUtils:
|
||||||
|
|
@ -13,3 +28,60 @@ class MessageUtils:
|
||||||
def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes:
|
def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes:
|
||||||
dict_representation = msg.to_dict()
|
dict_representation = msg.to_dict()
|
||||||
return msgpack.packb(dict_representation) # type: ignore
|
return msgpack.packb(dict_representation) # type: ignore
|
||||||
|
|
||||||
|
@staticmethod
def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence:
    """Convert ``message.message_segment`` into a ``MessageSequence``.

    A missing segment yields an empty sequence. A ``"seglist"`` segment is
    expanded element by element; a single text/image/emoji/voice/at/reply
    segment yields a one-component sequence.

    Raises:
        AssertionError: If a ``"seglist"`` segment's data is not a list.
        NotImplementedError: If the segment type is not supported.
    """
    root = message.message_segment
    parsed: List[StandardMessageComponents] = []
    if not root:
        return MessageSequence(parsed)

    if root.type == "seglist":
        assert isinstance(root.data, list), "seglist类型的message_segment数据应该是一个列表"
        for child in root.data:
            parsed.append(MessageUtils._parse_maim_message_segment_to_component(child))
        return MessageSequence(parsed)

    # A bare (non-list) segment must be one of the directly supported types.
    if root.type not in {"text", "image", "emoji", "voice", "at", "reply"}:
        raise NotImplementedError(f"暂时不支持的消息片段类型: {root.type}")

    parsed.append(MessageUtils._parse_maim_message_segment_to_component(root))
    return MessageSequence(parsed)
|
||||||
|
|
||||||
|
@staticmethod
async def from_MaiSeq_to_maim_message_segments(msg_seq: MessageSequence) -> List[Seg]:
    """Convert a ``MessageSequence`` into a list of ``maim_message`` segments.

    ``DictComponent`` items are wrapped directly in a ``"dict"`` segment;
    every other component is converted through its own ``to_seg()``
    coroutine. Components are awaited one at a time, in sequence order.
    """
    return [
        Seg(type="dict", data=part.data)  # type: ignore
        if isinstance(part, DictComponent)
        else await part.to_seg()
        for part in msg_seq.components
    ]
|
||||||
|
|
||||||
|
@staticmethod
def _parse_maim_message_segment_to_component(seg: Seg) -> "StandardMessageComponents":
    """Convert a single ``maim_message`` segment into a message component.

    ``text``, ``at`` and ``reply`` segments carry their string payload
    through unchanged. ``image``, ``emoji`` and ``voice`` segments are
    base64-decoded, and the MD5 hex digest of the decoded bytes is used as
    the component's ``binary_hash``.

    Raises:
        AssertionError: If the segment data is not a string.
        NotImplementedError: If the segment type is not supported.
    """
    kind, payload = seg.type, seg.data

    # (segment type, component factory, assertion message) for string payloads.
    plain_kinds = (
        ("text", lambda value: TextComponent(text=value), "text类型的seg数据应该是字符串"),
        ("at", lambda value: AtComponent(target_user_id=value), "at类型的seg数据应该是字符串"),
        ("reply", lambda value: ReplyComponent(target_message_id=value), "reply类型的seg数据应该是字符串"),
    )
    for type_name, build, error_text in plain_kinds:
        if kind == type_name:
            assert isinstance(payload, str), error_text
            return build(payload)

    # (segment type, component class, assertion message) for base64 payloads.
    binary_kinds = (
        ("image", ImageComponent, "image类型的seg数据应该是base64字符串"),
        ("emoji", EmojiComponent, "emoji类型的seg数据应该是base64字符串"),
        ("voice", VoiceComponent, "voice类型的seg数据应该是base64字符串"),
    )
    for type_name, component_cls, error_text in binary_kinds:
        if kind == type_name:
            assert isinstance(payload, str), error_text
            raw_bytes = base64.b64decode(payload)
            digest = hashlib.md5(raw_bytes).hexdigest()
            return component_cls(binary_hash=digest, binary_data=raw_bytes)

    raise NotImplementedError(f"暂时不支持的消息片段类型: {seg.type}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue