diff --git a/src/common/data_models/mai_message_data_model.py b/src/common/data_models/mai_message_data_model.py index 0bd159c7..0e881508 100644 --- a/src/common/data_models/mai_message_data_model.py +++ b/src/common/data_models/mai_message_data_model.py @@ -1,5 +1,11 @@ from dataclasses import dataclass, field -from maim_message import MessageBase +from maim_message import ( + MessageBase, + UserInfo as MaimUserInfo, + GroupInfo as MaimGroupInfo, + BaseMessageInfo as MaimBaseMessageInfo, + Seg, +) from typing import Optional import json @@ -112,10 +118,59 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): @classmethod def from_maim_message(cls, message: MessageBase) -> "MaiMessage": - raise NotImplementedError + """从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息""" + msg_info = message.message_info + assert msg_info, "MessageBase 的 message_info 不能为空" + msg_id = msg_info.message_id + timestamp = msg_info.time + assert isinstance(msg_id, str) + assert msg_id + assert timestamp + obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp)) + obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message) + usr_info = msg_info.user_info + assert usr_info + assert isinstance(usr_info.user_id, str) + assert isinstance(usr_info.user_nickname, str) + user_info = UserInfo( + user_id=usr_info.user_id, + user_nickname=usr_info.user_nickname, + user_cardname=usr_info.user_cardname, + ) + if grp_info := msg_info.group_info: + assert isinstance(grp_info.group_id, str) + assert isinstance(grp_info.group_name, str) + group_info = GroupInfo(group_id=grp_info.group_id, group_name=grp_info.group_name) + else: + group_info = None + add_cfg = msg_info.additional_config or {} + obj.message_info = MessageInfo(user_info=user_info, group_info=group_info, additional_config=add_cfg) + return obj - def to_maim_message(self) -> MessageBase: - raise NotImplementedError - - def parse_message_segments(self): - raise NotImplementedError + async def to_maim_message(self) -> MessageBase: + """ + 从 MaiMessage 实例转换为 maim_message.MessageBase,构建消息内容并设置相关信息 + """ + maim_user_info = MaimUserInfo( + user_id=self.message_info.user_info.user_id, + user_nickname=self.message_info.user_info.user_nickname, + user_cardname=self.message_info.user_info.user_cardname, + platform=self.platform, + ) + maim_group_info = None + if self.message_info.group_info: + maim_group_info = MaimGroupInfo( + group_id=self.message_info.group_info.group_id, + group_name=self.message_info.group_info.group_name, + platform=self.platform, + ) + maim_msg_info = MaimBaseMessageInfo( + platform=self.platform, + message_id=self.message_id, + time=self.timestamp.timestamp(), + group_info=maim_group_info, + user_info=maim_user_info, + additional_config=self.message_info.additional_config, + ) + msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message) + return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments)) diff --git a/src/common/data_models/message_component_model.py b/src/common/data_models/message_component_model.py index 7733ddce..7bb132b8 100644 --- a/src/common/data_models/message_component_model.py +++ b/src/common/data_models/message_component_model.py @@ -33,6 +33,7 @@ class ByteComponent: class TextComponent(BaseMessageComponentModel): + """文本组件,包含一个文本消息的内容""" def __init__(self, text: str): self.text = text assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型" @@ -42,9 +43,10 @@ class TextComponent(BaseMessageComponentModel): class ImageComponent(BaseMessageComponentModel, ByteComponent): + """图片组件,包含一个图片消息的二进制数据和一个唯一标识该图片消息的 hash 值""" async def load_image_binary(self): if not self.binary_data: - ... + raise NotImplementedError async def to_seg(self) -> Seg: if not self.binary_data: @@ -53,6 +55,7 @@ class ImageComponent(BaseMessageComponentModel, ByteComponent): class EmojiComponent(BaseMessageComponentModel, ByteComponent): + """表情组件,包含一个表情消息的二进制数据和一个唯一标识该表情消息的 hash 值""" async def load_emoji_binary(self) -> None: """ 加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载 @@ -81,6 +84,7 @@ class EmojiComponent(BaseMessageComponentModel, ByteComponent): class VoiceComponent(BaseMessageComponentModel, ByteComponent): + """语音组件,包含一个语音消息的二进制数据和一个唯一标识该语音消息的 hash 值""" async def load_voice_binary(self) -> None: if not self.binary_data: from src.common.utils.utils_file import FileUtils @@ -97,9 +101,33 @@ class VoiceComponent(BaseMessageComponentModel, ByteComponent): return Seg(type="voice", data=base64.b64encode(self.binary_data).decode()) +class AtComponent(BaseMessageComponentModel): + """@组件,包含一个被@的用户的ID,用于表示该组件是一个@某人的消息片段""" + def __init__(self, target_user_id: str) -> None: + self.target_user_id = target_user_id + """目标用户ID""" + assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型" + + async def to_seg(self) -> Seg: + return Seg(type="at", data=self.target_user_id) + + +class ReplyComponent(BaseMessageComponentModel): + """回复组件,包含一个回复消息的 ID,用于表示该组件是对哪条消息的回复""" + def __init__(self, target_message_id: str) -> None: + assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型" + self.target_message_id = target_message_id + """目标消息ID""" + + async def to_seg(self) -> Seg: + return Seg(type="reply", data=self.target_message_id) + + class ForwardNodeComponent(BaseMessageComponentModel): + """转发节点消息组件,包含一个转发节点的消息,所有组件按照消息顺序排列""" def __init__(self, forward_components: List["ForwardComponent"]): self.forward_components = forward_components + """节点的消息组件列表,按照消息顺序排列""" assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型" assert all(isinstance(comp, ForwardComponent) for comp in forward_components), ( "ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型" @@ -128,12 +156,15 @@ StandardMessageComponents = Union[ ImageComponent, EmojiComponent, VoiceComponent, + AtComponent, + ReplyComponent, ForwardNodeComponent, DictComponent, ] class ForwardComponent(BaseMessageComponentModel): + """转发组件,包含一个转发消息中的一个节点的信息,包括发送者信息和该节点的消息内容""" def __init__( self, user_nickname: str, @@ -142,9 +173,13 @@ class ForwardComponent(BaseMessageComponentModel): user_cardname: Optional[str] = None, ): self.user_nickname: str = user_nickname + """转发节点的发送者昵称""" self.content: List[StandardMessageComponents] = content + """消息内容""" self.user_id: Optional[str] = user_id + """转发节点的发送者ID,可能为 None""" self.user_cardname: Optional[str] = user_cardname + """转发节点的发送者群名片,可能为 None""" assert self.content, "ForwardComponent 的 content 不能为空" async def to_seg(self) -> "Seg": @@ -154,13 +189,33 @@ class ForwardComponent(BaseMessageComponentModel): class MessageSequence: + """消息组件序列,包含一个消息中的所有组件,按照顺序排列""" + def __init__(self, components: List[StandardMessageComponents]): + """ + 创建一个消息组件序列 + + **消息组件序列不会对组件进行去重或校验。** + + 因此同一消息中可以包含多个相同的组件(例如多个文本组件、多个图片组件等)。 + 因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。 + 如果需要对组件进行去重或校验,还请在使用时自行处理。 + """ self.components: List[StandardMessageComponents] = components def to_dict(self) -> List[Dict[str, Any]]: + """将消息序列转换为字典列表格式,便于存储或传输""" return [self._item_2_dict(comp) for comp in self.components] + @classmethod + def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence": + """从字典列表格式创建消息序列实例""" + components: List[StandardMessageComponents] = [] + components.extend(cls._dict_2_item(item) for item in data) + return cls(components=components) + def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]: + """内部方法:将单个消息组件转换为字典格式""" if isinstance(item, TextComponent): return {"type": "text", "data": item.text} elif isinstance(item, ImageComponent): @@ -175,6 +230,10 @@ class MessageSequence: if not item.content: raise RuntimeError("VoiceComponent content 未初始化") return {"type": "voice", "data": item.content, "hash": item.binary_hash} + elif isinstance(item, AtComponent): + return {"type": "at", "data": item.target_user_id} + elif isinstance(item, ReplyComponent): + return {"type": "reply", "data": item.target_message_id} elif isinstance(item, ForwardNodeComponent): return { "type": "forward", @@ -192,14 +251,9 @@ class MessageSequence: logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent") return {"type": "dict", "data": item.data} - @classmethod - def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence": - components: List[StandardMessageComponents] = [] - components.extend(cls._dict_2_item(item) for item in data) - return cls(components=components) - @classmethod def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents: + """内部方法:将单个消息组件的字典格式转换回组件对象""" item_type = item.get("type") if item_type == "text": return TextComponent(text=item["data"]) @@ -209,6 +263,10 @@ class MessageSequence: return EmojiComponent(binary_hash=item["hash"], content=item["data"]) elif item_type == "voice": return VoiceComponent(binary_hash=item["hash"], content=item["data"]) + elif item_type == "at": + return AtComponent(target_user_id=item["data"]) + elif item_type == "reply": + return ReplyComponent(target_message_id=item["data"]) elif item_type == "forward": forward_components = [] for fc in item["data"]: diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 41164623..aead8777 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -1,6 +1,21 @@ +from maim_message import MessageBase, Seg +from typing import List + +import base64 +import hashlib import msgpack -from src.common.data_models.message_component_model import MessageSequence +from src.common.data_models.message_component_model import ( + MessageSequence, + StandardMessageComponents, + TextComponent, + ImageComponent, + EmojiComponent, + VoiceComponent, + AtComponent, + ReplyComponent, + DictComponent, +) class MessageUtils: @@ -13,3 +28,60 @@ class MessageUtils: def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes: dict_representation = msg.to_dict() return msgpack.packb(dict_representation) # type: ignore + + @staticmethod + def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence: + """从maim_message.MessageBase.message_segment转换为MessageSequence""" + raw_msg_seq = message.message_segment + components: List[StandardMessageComponents] = [] + if not raw_msg_seq: + return MessageSequence(components) + if raw_msg_seq.type == "seglist": + assert isinstance(raw_msg_seq.data, list), "seglist类型的message_segment数据应该是一个列表" + components.extend(MessageUtils._parse_maim_message_segment_to_component(item) for item in raw_msg_seq.data) + elif raw_msg_seq.type in {"text", "image", "emoji", "voice", "at", "reply"}: + components.append(MessageUtils._parse_maim_message_segment_to_component(raw_msg_seq)) + else: + raise NotImplementedError(f"暂时不支持的消息片段类型: {raw_msg_seq.type}") + return MessageSequence(components) + + @staticmethod + async def from_MaiSeq_to_maim_message_segments(msg_seq: MessageSequence) -> List[Seg]: + """从MessageSequence转换为maim_message.MessageBase.message_segment格式的列表""" + segments = [] + for component in msg_seq.components: + if isinstance(component, DictComponent): + seg = Seg(type="dict", data=component.data) # type: ignore + else: + seg = await component.to_seg() + segments.append(seg) + return segments + + @staticmethod + def _parse_maim_message_segment_to_component(seg: Seg) -> "StandardMessageComponents": + if seg.type == "text": + assert isinstance(seg.data, str), "text类型的seg数据应该是字符串" + return TextComponent(text=seg.data) + elif seg.type == "image": + assert isinstance(seg.data, str), "image类型的seg数据应该是base64字符串" + image_bytes = base64.b64decode(seg.data) + binary_hash = hashlib.md5(image_bytes).hexdigest() + return ImageComponent(binary_hash=binary_hash, binary_data=image_bytes) + elif seg.type == "emoji": + assert isinstance(seg.data, str), "emoji类型的seg数据应该是base64字符串" + emoji_bytes = base64.b64decode(seg.data) + binary_hash = hashlib.md5(emoji_bytes).hexdigest() + return EmojiComponent(binary_hash=binary_hash, binary_data=emoji_bytes) + elif seg.type == "voice": + assert isinstance(seg.data, str), "voice类型的seg数据应该是base64字符串" + voice_bytes = base64.b64decode(seg.data) + binary_hash = hashlib.md5(voice_bytes).hexdigest() + return VoiceComponent(binary_hash=binary_hash, binary_data=voice_bytes) + elif seg.type == "at": + assert isinstance(seg.data, str), "at类型的seg数据应该是字符串" + return AtComponent(target_user_id=seg.data) + elif seg.type == "reply": + assert isinstance(seg.data, str), "reply类型的seg数据应该是字符串" + return ReplyComponent(target_message_id=seg.data) + else: + raise NotImplementedError(f"暂时不支持的消息片段类型: {seg.type}")