From 0d07e85434738e653a574a57084b1542beb94b34 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 23 Feb 2026 21:29:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=A8=E6=96=B0=E7=9A=84process=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E5=AE=8C=E6=88=90=EF=BC=88Message=E5=85=B6=E4=BB=96?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=8D=E6=9C=AA=E5=AE=8C=E6=88=90=EF=BC=89?= =?UTF-8?q?=EF=BC=9B=E5=AF=B9=E5=BA=94=E6=B5=8B=E8=AF=95=EF=BC=9B=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E9=83=A8=E5=88=86=E6=B3=A8=E9=87=8A=EF=BC=9B=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=A3=80=E7=B4=A2=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/message_test/session_message_test.py | 420 +++++++++++ src/chat/emoji_system/emoji_manager.py | 3 + src/chat/message_receive/chat_manager.py | 4 +- src/chat/message_receive/message.py | 705 +++++------------- .../message_component_data_model.py | 35 +- src/common/utils/utils_message.py | 4 +- src/common/utils/utils_person.py | 2 +- 7 files changed, 627 insertions(+), 546 deletions(-) create mode 100644 pytests/message_test/session_message_test.py diff --git a/pytests/message_test/session_message_test.py b/pytests/message_test/session_message_test.py new file mode 100644 index 00000000..9fdd22dd --- /dev/null +++ b/pytests/message_test/session_message_test.py @@ -0,0 +1,420 @@ +import sys +import asyncio +import pytest +import importlib +import importlib.util +from types import ModuleType +from pathlib import Path +from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent + from src.chat.message_receive.message import ( + SessionMessage, + TextComponent, + ImageComponent, + EmojiComponent, + VoiceComponent, + AtComponent, + ReplyComponent, + ForwardNodeComponent, + StandardMessageComponents, + ) + + +class DummyLogger: + def __init__(self) -> None: + self.logging_record = [] + + def debug(self, msg): + print(f"DEBUG: {msg}") + self.logging_record.append(f"DEBUG: {msg}") + + def info(self, msg): + print(f"INFO: {msg}") + self.logging_record.append(f"INFO: {msg}") + + def warning(self, msg): + print(f"WARNING: {msg}") + self.logging_record.append(f"WARNING: {msg}") + + def error(self, msg): + print(f"ERROR: {msg}") + self.logging_record.append(f"ERROR: {msg}") + + def critical(self, msg): + print(f"CRITICAL: {msg}") + self.logging_record.append(f"CRITICAL: {msg}") + + +def get_logger(name): + return DummyLogger() + + +class DummyDBSession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def exec(self, statement): + return self + + def first(self): + return None + + def commit(self): + pass + + def all(self): + return [] + + +def get_db_session(): + return DummyDBSession() + + +def get_manual_db_session(): + return DummyDBSession() + + +class DummySelect: + def __init__(self, model): + self.model = model + + def filter_by(self, **kwargs): + return self + + def where(self, condition): + return self + + def limit(self, n): + return self + + +def select(model): + return DummySelect(model) + + +async def dummy_get_voice_text(binary_data): + return None # 可以根据需要返回模拟的文本结果 + + +class DummyPersonUtils: + @staticmethod + def get_person_info_by_user_id_and_platform(user_id, platform): + return None # 可以根据需要返回模拟的用户信息 + + +def setup_mocks(monkeypatch): + def _stub_module(name: str) -> ModuleType: + module = ModuleType(name) + monkeypatch.setitem(sys.modules, name, module) + return module + + # src.common.logger + logger_mod = _stub_module("src.common.logger") + # Mock the logger + logger_mod.get_logger = get_logger + + db_mod = _stub_module("src.common.database.database") + db_mod.get_db_session = get_db_session + db_mod.get_manual_db_session = get_manual_db_session + + emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager") + emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法 + + image_manager_mod = _stub_module("src.chat.image_system.image_manager") + image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法 + + msg_utils_mod = _stub_module("src.common.utils.utils_message") + msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法 + + voice_utils_mod = _stub_module("src.common.utils.utils_voice") + voice_utils_mod.get_voice_text = dummy_get_voice_text + + person_utils_mod = _stub_module("src.common.utils.utils_person") + person_utils_mod.PersonUtils = DummyPersonUtils + + +def load_message_via_file(monkeypatch): + setup_mocks(monkeypatch) + file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py" + spec = importlib.util.spec_from_file_location("message", file_path) + message_module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "message_module", message_module) + spec.loader.exec_module(message_module) + message_module.select = select + SessionMessageClass = message_module.SessionMessage + TextComponentClass = message_module.TextComponent + ImageComponentClass = message_module.ImageComponent + EmojiComponentClass = message_module.EmojiComponent + VoiceComponentClass = message_module.VoiceComponent + AtComponentClass = message_module.AtComponent + ReplyComponentClass = message_module.ReplyComponent + ForwardNodeComponentClass = message_module.ForwardNodeComponent + MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence + ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent + globals()["SessionMessage"] = SessionMessageClass + globals()["TextComponent"] = TextComponentClass + globals()["ImageComponent"] = ImageComponentClass + globals()["EmojiComponent"] = EmojiComponentClass + globals()["VoiceComponent"] = VoiceComponentClass + globals()["AtComponent"] = AtComponentClass + globals()["ReplyComponent"] = ReplyComponentClass + globals()["ForwardNodeComponent"] = ForwardNodeComponentClass + globals()["MessageSequence"] = MessageSequenceClass + globals()["ForwardComponent"] = ForwardComponentClass + return message_module + + +@pytest.mark.asyncio +async def test_process(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "Hello, world!" + + +@pytest.mark.asyncio +async def test_multiple_text(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")] + await msg.process() + assert msg.processed_plain_text == "Hello, world!" + + +@pytest.mark.asyncio +async def test_image(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[发了一张图片,网卡了加载不出来] Hello, world!" + + +@pytest.mark.asyncio +async def test_emoji(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[发了一个表情,网卡了加载不出来] Hello, world!" + + +@pytest.mark.asyncio +async def test_voice(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!" + + +@pytest.mark.asyncio +async def test_at_component(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "@114514 Hello, world!" + + +@pytest.mark.asyncio +async def test_reply_component_fail_to_fetch(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!" + + +@pytest.mark.asyncio +async def test_reply_component_success(monkeypatch): + module_msg = load_message_via_file(monkeypatch) + + class DummyDBSessionWithReply(DummyDBSession): + def exec(self, s): + return self + + def first(inner_self): + class DummyRecord: + processed_plain_text = "原消息内容" + user_cardname = "cardname123" + user_nickname = "nickname123" + user_id = "userid123" + + return DummyRecord() + + module_msg.get_db_session = lambda: DummyDBSessionWithReply() + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!" + + +@pytest.mark.asyncio +async def test_reply_component_with_db_fail(monkeypatch): + module_msg = load_message_via_file(monkeypatch) + + class DummyDBSessionWithError(DummyDBSession): + def exec(self, s): + raise Exception("数据库查询失败") + + module_msg.get_db_session = lambda: DummyDBSessionWithError() + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] + await msg.process() + assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!" + assert any("数据库查询失败" in log for log in module_msg.logger.logging_record) + + +@pytest.mark.asyncio +async def test_forward_component(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ + ForwardNodeComponent( + forward_components=[ + ForwardComponent( + message_id="msg1", + user_id="user1", + user_nickname="nickname1", + user_cardname="cardname1", + content=[TextComponent("转发消息1")], + ), + ForwardComponent( + message_id="msg2", + user_id="user2", + user_nickname="nickname2", + user_cardname="cardname2", + content=[TextComponent("转发消息2")], + ), + ] + ), + TextComponent("Hello, world!"), + ] + await msg.process() + print("Processed plain text:", msg.processed_plain_text) + expected_forward_text = """【合并转发消息: +-- 【cardname1】: 转发消息1 +-- 【cardname2】: 转发消息2 +】 Hello, world!""" + assert msg.processed_plain_text == expected_forward_text + + +@pytest.mark.asyncio +async def test_forward_with_reply(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + msg.raw_message.components = [ + ForwardNodeComponent( + forward_components=[ + ForwardComponent( + message_id="msg1", + user_id="user1", + user_nickname="nickname1", + user_cardname="cardname1", + content=[TextComponent("转发消息1")], + ), + ForwardComponent( + message_id="msg2", + user_id="user2", + user_nickname="nickname2", + user_cardname="cardname2", + content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")], + ), + ] + ), + TextComponent("Hello, world!"), + ] + await msg.process() + assert ( + msg.processed_plain_text + == """【合并转发消息: +-- 【cardname1】: 转发消息1 +-- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2 +】 Hello, world!""" + ) + + +@pytest.mark.asyncio +async def test_multiple_reply_with_delay_in_forward(monkeypatch): + load_message_via_file(monkeypatch) + msg = SessionMessage("msg123", datetime.now()) + msg.session_id = "session123" + msg.platform = "test_platform" + msg.raw_message = MessageSequence(components=[]) + + async def delayed_get_voice_text(binary_data): + await asyncio.sleep(0.5) # 模拟延迟 + return "这是语音转文本的结果" + + sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text + + msg.raw_message.components = [ + ForwardNodeComponent( + forward_components=[ + ForwardComponent( + message_id="msg1", + user_id="user1", + user_nickname="nickname1", + user_cardname="cardname1", + content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")], + ), + ForwardComponent( + message_id="msg2", + user_id="user2", + user_nickname="nickname2", + user_cardname="cardname2", + content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")], + ), + ForwardComponent( + message_id="msg3", + user_id="user3", + user_nickname="nickname3", + user_cardname="cardname3", + content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")], + ), + ] + ), + ] + await msg.process() + expected_text = """【合并转发消息: +-- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1 +-- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2 +-- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3 +】""" + assert msg.processed_plain_text == expected_text diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index c5b40bd9..2be0bd69 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -67,6 +67,9 @@ class EmojiManager: emoji_hash (Optional[str]): 表情包的哈希值,如果提供了哈希值则优先使用哈希值查找表情包描述 Returns: return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None + Raises: + ValueError: 如果既没有提供表情包字节数据,也没有提供表情包哈希值,则抛出异常 + Exception: 如果在缓存表情包的过程中发生错误,则抛出异常 """ # 先查找 if emoji_hash is None and emoji_bytes is not None: diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py index fa20d876..3dd73050 100644 --- a/src/chat/message_receive/chat_manager.py +++ b/src/chat/message_receive/chat_manager.py @@ -103,7 +103,7 @@ class ChatManager: # 内存没有就找db try: with get_db_session() as db_session: - statement = select(ChatSession).filter_by(session_id=session_id) + statement = select(ChatSession).filter_by(session_id=session_id).limit(1) if result := db_session.exec(statement).first(): session = BotChatSession.from_db_instance(result) self.sessions[session.session_id] = session @@ -229,7 +229,7 @@ class ChatManager: """将会话记录保存到数据库""" with get_db_session() as db_session: db_instance = session.to_db_instance() - statement = select(ChatSession).filter_by(session_id=db_instance.session_id) + statement = select(ChatSession).filter_by(session_id=db_instance.session_id).limit(1) if result := db_session.exec(statement).first(): result.created_timestamp = db_instance.created_timestamp result.last_active_timestamp = db_instance.last_active_timestamp diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 86122b9a..099aada2 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,561 +1,204 @@ -import time -import asyncio -import urllib3 - -from abc import abstractmethod -from dataclasses import dataclass +from asyncio import Task from rich.traceback import install -from typing import Optional, Any, List -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from sqlmodel import select +from typing import List, Dict, Tuple, Sequence + +import asyncio from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.utils.utils_image import get_image_manager -from src.chat.utils.utils_voice import get_voice_text -from .chat_stream import ChatStream +from src.common.database.database import get_db_session +from src.common.database.database_model import Messages +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.common.data_models.message_component_data_model import ( + TextComponent, + ImageComponent, + EmojiComponent, + AtComponent, + ReplyComponent, + VoiceComponent, + ForwardNodeComponent, + StandardMessageComponents, +) + install(extra_lines=3) logger = get_logger("chat_message") -# 禁用SSL警告 -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -# VLM 处理并发限制(避免同时处理太多图片导致卡死) -_vlm_semaphore = asyncio.Semaphore(3) - -# 这个类是消息数据类,用于存储和管理消息数据。 -# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 -# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 +class MsgIDMapping: + def __init__(self): + self.mapping: Dict[str, Tuple[str | Task, UserInfo]] = {} -@dataclass -class Message(MessageBase): - chat_stream: "ChatStream" = None # type: ignore - reply: Optional["Message"] = None - processed_plain_text: str = "" +class SessionMessage(MaiMessage): + async def process(self): + """处理消息内容,识别消息内容并转化为文本""" + tasks = [self.process_single_component(component, MsgIDMapping()) for component in self.raw_message.components] + results = await asyncio.gather(*tasks, return_exceptions=True) + processed_texts: List[str] = [] + for result in results: + if isinstance(result, BaseException): + logger.error(f"处理消息组件时发生错误: {result}") + else: + processed_texts.append(result) + self.processed_plain_text = " ".join(processed_texts) - def __init__( - self, - message_id: str, - chat_stream: "ChatStream", - user_info: UserInfo, - message_segment: Optional[Seg] = None, - timestamp: Optional[float] = None, - reply: Optional["MessageRecv"] = None, - processed_plain_text: str = "", - ): - # 使用传入的时间戳或当前时间 - current_timestamp = timestamp if timestamp is not None else round(time.time(), 3) - # 构造基础消息信息 - message_info = BaseMessageInfo( - platform=chat_stream.platform, - message_id=message_id, - time=current_timestamp, - group_info=chat_stream.group_info, - user_info=user_info, - ) - - # 调用父类初始化 - super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore - - self.chat_stream = chat_stream - # 文本处理相关属性 - self.processed_plain_text = processed_plain_text - - # 回复消息 - self.reply = reply - - async def _process_message_segments(self, segment: Seg) -> str: - # sourcery skip: remove-unnecessary-else, swap-if-else-branches - """递归处理消息段,转换为文字描述 - - Args: - segment: 要处理的消息段 - - Returns: - str: 处理后的文本 - """ - if segment.type == "seglist": - # 处理消息段列表 - 使用并行处理提升性能 - tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore - results = await asyncio.gather(*tasks, return_exceptions=True) - segments_text = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"处理消息段时出错: {result}") - continue - if result: - segments_text.append(result) - return " ".join(segments_text) - elif segment.type == "forward": - # 处理转发消息 - 使用并行处理 - async def process_forward_node(node_dict): - message = MessageBase.from_dict(node_dict) # type: ignore - processed_text = await self._process_message_segments(message.message_segment) - if processed_text: - return f"{global_config.bot.nickname}: {processed_text}" - return None - - tasks = [process_forward_node(node_dict) for node_dict in segment.data] - results = await asyncio.gather(*tasks, return_exceptions=True) - segments_text = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"处理转发节点时出错: {result}") - continue - if result: - segments_text.append(result) - return "[合并消息]: " + "\n-- ".join(segments_text) + async def process_single_component( + self, component: StandardMessageComponents, id_content_map: MsgIDMapping, recursion_depth: int = 0 + ) -> str: + if isinstance(component, TextComponent): + return component.text + elif isinstance(component, ImageComponent): + return await self.process_image_component(component) + elif isinstance(component, EmojiComponent): + return await self.process_emoji_component(component) + elif isinstance(component, AtComponent): + return await self.process_at_component(component) + elif isinstance(component, VoiceComponent): + return await self.process_voice_component(component) + elif isinstance(component, ReplyComponent): + return await self.process_reply_component(component, id_content_map) + elif isinstance(component, ForwardNodeComponent): + return await self.process_forward_component(component, id_content_map, recursion_depth=recursion_depth + 1) else: - # 处理单个消息段 - return await self._process_single_segment(segment) # type: ignore + raise NotImplementedError(f"暂时不支持的消息组件类型: {type(component)}") - @abstractmethod - async def _process_single_segment(self, segment) -> str: - pass + async def process_image_component(self, component: ImageComponent) -> str: + if component.content: # 先检查是否处理过 + return component.content + from src.chat.image_system.image_manager import image_manager - -@dataclass -class MessageRecv(Message): - """接收消息类,用于处理从MessageCQ序列化的消息""" - - def __init__(self, message_dict: dict[str, Any]): - """从MessageCQ的字典初始化 - - Args: - message_dict: MessageCQ序列化后的字典 - """ - self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - self.raw_message = message_dict.get("raw_message") - self.processed_plain_text = message_dict.get("processed_plain_text", "") - self.is_emoji = False - self.has_emoji = False - self.is_picid = False - self.has_picid = False - self.is_voice = False - self.is_mentioned = None - self.is_at = False - self.reply_probability_boost = 0.0 - self.is_notify = False - - self.is_command = False - self.intercept_message_level = 0 - - self.priority_mode = "interest" - self.priority_info = None - self.interest_value: float = None # type: ignore - - self.key_words = [] - self.key_words_lite = [] - - # 兼容适配器通过 additional_config 传入的 @ 标记 + # 获取描述 try: - msg_info_dict = message_dict.get("message_info", {}) - add_cfg = msg_info_dict.get("additional_config") or {} - if isinstance(add_cfg, dict) and add_cfg.get("at_bot"): - # 标记为被提及,提高后续回复优先级 - self.is_mentioned = True # type: ignore + desc = await image_manager.get_image_description(image_bytes=component.binary_data) except Exception: - pass + desc = None - def update_chat_stream(self, chat_stream: "ChatStream"): - self.chat_stream = chat_stream + content = f"[图片:{desc}]" if desc else "[发了一张图片,网卡了加载不出来]" + component.content = content + return content - async def process(self) -> None: - """处理消息内容,生成纯文本和详细文本 + async def process_emoji_component(self, component: EmojiComponent) -> str: + if component.content: # 先检查是否处理过 + return component.content + from src.chat.emoji_system.emoji_manager import emoji_manager - 这个方法必须在创建实例后显式调用,因为它包含异步操作。 - """ - # print(f"self.message_segment: {self.message_segment}") - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - async def _process_single_segment(self, segment: Seg) -> str: - """处理单个消息段 - - Args: - segment: 消息段 - - Returns: - str: 处理后的文本 - """ + # 获取表情包描述 try: - if segment.type == "text": - self.is_picid = False - self.is_emoji = False - return segment.data # type: ignore - elif segment.type == "image": - # 如果是base64图片数据 - if isinstance(segment.data, str): - self.has_picid = True - self.is_picid = True - self.is_emoji = False - image_manager = get_image_manager() - # 使用 semaphore 限制 VLM 并发,避免同时处理太多图片 - async with _vlm_semaphore: - _, processed_text = await image_manager.process_image(segment.data) - return processed_text - return "[发了一张图片,网卡了加载不出来]" - elif segment.type == "emoji": - self.has_emoji = True - self.is_emoji = True - self.is_picid = False - self.is_voice = False - if isinstance(segment.data, str): - # 使用 semaphore 限制 VLM 并发 - async with _vlm_semaphore: - return await get_image_manager().get_emoji_description(segment.data) - return "[发了一个表情包,网卡了加载不出来]" - elif segment.type == "voice": - self.is_picid = False - self.is_emoji = False - self.is_voice = True - if isinstance(segment.data, str): - return await get_voice_text(segment.data) - return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "mention_bot": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - self.is_mentioned = float(segment.data) # type: ignore - return "" - elif segment.type == "priority_info": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - # 处理优先级信息 - self.priority_mode = "priority" - self.priority_info = segment.data - """ - { - 'message_type': 'vip', # vip or normal - 'message_priority': 1.0, # 优先级,大为优先,float - } - """ - return "" - elif segment.type == "video_card": - # 处理视频卡片消息 - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - file_name = segment.data.get("file", "未知视频") - file_size = segment.data.get("file_size", "") - url = segment.data.get("url", "") - text = f"[视频: {file_name}" - if file_size: - text += f", 大小: {file_size}字节" - text += "]" - if url: - text += f" 链接: {url}" - return text - return "[视频]" - elif segment.type == "music_card": - # 处理音乐卡片消息 - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - title = segment.data.get("title", "未知歌曲") - singer = segment.data.get("singer", "") - tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐" - jump_url = segment.data.get("jump_url", "") - music_url = segment.data.get("music_url", "") - text = f"[音乐: {title}" - if singer: - text += f" - {singer}" - if tag: - text += f" ({tag})" - text += "]" - if jump_url: - text += f" 跳转链接: {jump_url}" - if music_url: - text += f" 音乐链接: {music_url}" - return text - return "[音乐]" - elif segment.type == "miniapp_card": - # 处理小程序分享卡片(如B站视频分享) - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - title = segment.data.get("title", "") # 小程序名称 - desc = segment.data.get("desc", "") # 内容描述 - source_url = segment.data.get("source_url", "") # 原始链接 - url = segment.data.get("url", "") # 小程序链接 - text = "[小程序分享" - if title: - text += f" - {title}" - text += "]" - if desc: - text += f" {desc}" - if source_url: - text += f" 链接: {source_url}" - elif url: - text += f" 链接: {url}" - return text - return "[小程序分享]" - else: - return "" - except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" + tuple_content = await emoji_manager.get_emoji_description(emoji_bytes=component.binary_data) + except Exception: + tuple_content = None + if tuple_content: + desc, _ = tuple_content + content = f"[表情包: {desc}]" + else: + content = "[发了一个表情,网卡了加载不出来]" + component.content = content + return content -@dataclass -class MessageProcessBase(Message): - """消息处理基类,用于处理中和发送中的消息""" + async def process_at_component(self, component: AtComponent) -> str: + if component.target_user_cardname: + return f"@{component.target_user_cardname}" + elif component.target_user_nickname: + return f"@{component.target_user_nickname}" + from src.common.utils.utils_person import PersonUtils - def __init__( + if person_info := PersonUtils.get_person_info_by_user_id_and_platform(component.target_user_id, self.platform): + component.target_user_nickname = component.target_user_nickname or person_info.user_nickname + if self.message_info.group_info and person_info.group_cardname_list: + for group_card in person_info.group_cardname_list: + if group_card.group_id == self.message_info.group_info.group_id: + component.target_user_cardname = group_card.group_cardname + break + if component.target_user_cardname: + return f"@{component.target_user_cardname}" + elif component.target_user_nickname: + return f"@{component.target_user_nickname}" + else: + return f"@{component.target_user_id}" + + async def process_voice_component(self, component: VoiceComponent) -> str: + if component.content: # 先检查是否处理过 + return component.content + from src.common.utils.utils_voice import get_voice_text + + text = await get_voice_text(component.binary_data) + content = "[语音消息,转录失败]" if text is None else f"[语音: {text}]" + component.content = content + return content + + async def process_reply_component( self, - message_id: str, - chat_stream: "ChatStream", - bot_user_info: UserInfo, - message_segment: Optional[Seg] = None, - reply: Optional["MessageRecv"] = None, - thinking_start_time: float = 0, - timestamp: Optional[float] = None, - ): - # 调用父类初始化,传递时间戳 - super().__init__( - message_id=message_id, - timestamp=timestamp, - chat_stream=chat_stream, - user_info=bot_user_info, - message_segment=message_segment, - reply=reply, - ) + component: ReplyComponent, + id_content_map: MsgIDMapping, + ) -> str: + if component.target_message_content: + return component.target_message_content + if result_item := id_content_map.mapping.get(component.target_message_id): + content, sender_info = result_item + if isinstance(content, Task): + content = await content + id_content_map.mapping[component.target_message_id] = (content, sender_info) # 更新为实际内容 + component.target_message_content = content + tgt_msg_s_name = sender_info.user_cardname or sender_info.user_nickname or sender_info.user_id + component.target_message_sender_cardname = sender_info.user_cardname + component.target_message_sender_nickname = sender_info.user_nickname + component.target_message_sender_id = sender_info.user_id + return f"[回复了{tgt_msg_s_name}的消息: {content}]" + else: + try: + with get_db_session() as session: + statement = select(Messages).filter_by(message_id=component.target_message_id).limit(1) + if db_msg := session.exec(statement).first(): + component.target_message_content = db_msg.processed_plain_text + component.target_message_sender_cardname = db_msg.user_cardname + component.target_message_sender_nickname = db_msg.user_nickname + component.target_message_sender_id = db_msg.user_id + tgt_msg_s_name = db_msg.user_cardname or db_msg.user_nickname or db_msg.user_id + return f"[回复了{tgt_msg_s_name}的消息: {db_msg.processed_plain_text}]" + except Exception as e: + logger.error(f"查询回复消息时发生错误: {e}") - # 处理状态相关属性 - self.thinking_start_time = thinking_start_time - self.thinking_time = 0 + return "[回复了一条消息,但原消息已无法访问]" - def update_thinking_time(self) -> float: - """更新思考时间""" - self.thinking_time = round(time.time() - self.thinking_start_time, 2) - return self.thinking_time - - async def _process_single_segment(self, segment: Seg) -> str: - """处理单个消息段 - - Args: - segment: 要处理的消息段 - - Returns: - str: 处理后的文本 - """ - try: - if segment.type == "text": - return segment.data # type: ignore - elif segment.type == "image": - # 如果是base64图片数据 - if isinstance(segment.data, str): - return await get_image_manager().get_image_description(segment.data) - return "[图片,网卡了加载不出来]" - elif segment.type == "emoji": - if isinstance(segment.data, str): - return await get_image_manager().get_emoji_tag(segment.data) - return "[表情,网卡了加载不出来]" - elif segment.type == "voice": - if isinstance(segment.data, str): - return await get_voice_text(segment.data) - return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "at": - return f"[@{segment.data}]" - elif segment.type == "reply": - if self.reply and hasattr(self.reply, "processed_plain_text"): - # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") - # print(f"reply: {self.reply}") - return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore - return "" - else: - return f"[{segment.type}:{str(segment.data)}]" - except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" - - def _generate_detailed_text(self) -> str: - """生成详细文本,包含时间和用户信息""" - # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) - timestamp = self.message_info.time - user_info = self.message_info.user_info - - name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore - return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n" - - -@dataclass -class MessageSending(MessageProcessBase): - """发送状态的消息类""" - - def __init__( - self, - message_id: str, - chat_stream: "ChatStream", - bot_user_info: UserInfo, - sender_info: UserInfo | None, # 用来记录发送者信息 - message_segment: Seg, - display_message: str = "", - reply: Optional["MessageRecv"] = None, - is_head: bool = False, - is_emoji: bool = False, - thinking_start_time: float = 0, - apply_set_reply_logic: bool = False, - reply_to: Optional[str] = None, - selected_expressions: Optional[List[int]] = None, - ): - # 调用父类初始化 - super().__init__( - message_id=message_id, - chat_stream=chat_stream, - bot_user_info=bot_user_info, - message_segment=message_segment, - reply=reply, - thinking_start_time=thinking_start_time, - ) - - # 发送状态特有属性 - self.sender_info = sender_info - self.reply_to_message_id = reply.message_info.message_id if reply else None - self.is_head = is_head - self.is_emoji = is_emoji - self.apply_set_reply_logic = apply_set_reply_logic - - self.reply_to = reply_to - - # 用于显示发送内容与显示不一致的情况 - self.display_message = display_message - - self.interest_value = 0.0 - - self.selected_expressions = selected_expressions - - def build_reply(self): - """设置回复消息""" - if self.reply: - self.reply_to_message_id = self.reply.message_info.message_id - self.message_segment = Seg( - type="seglist", - data=[ - Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore - self.message_segment, - ], + async def process_forward_component( + self, component: ForwardNodeComponent, id_content_map: MsgIDMapping, recursion_depth: int = 0 + ) -> str: + task_list: List[Task] = [] + node_user_info_list: List[UserInfo] = [] + for node in component.forward_components: + task = asyncio.create_task( + self._process_multiple_components(node.content, id_content_map, recursion_depth + 1) ) - - async def process(self) -> None: - """处理消息内容,生成纯文本和详细文本""" - if self.message_segment: - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - def to_dict(self): - ret = super().to_dict() - ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict() - return ret - - def is_private_message(self) -> bool: - """判断是否为私聊消息""" - return self.message_info.group_info is None or self.message_info.group_info.group_id is None - - -@dataclass -class MessageSet: - """消息集合类,可以存储多个发送消息""" - - def __init__(self, chat_stream: "ChatStream", message_id: str): - self.chat_stream = chat_stream - self.message_id = message_id - self.messages: list[MessageSending] = [] - self.time = round(time.time(), 3) # 保留3位小数 - - def add_message(self, message: MessageSending) -> None: - """添加消息到集合""" - if not isinstance(message, MessageSending): - raise TypeError("MessageSet只能添加MessageSending类型的消息") - self.messages.append(message) - self.messages.sort(key=lambda x: x.message_info.time) # type: ignore - - def get_message_by_index(self, index: int) -> Optional[MessageSending]: - """通过索引获取消息""" - return self.messages[index] if 0 <= index < len(self.messages) else None - - def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: - """获取最接近指定时间的消息""" - if not self.messages: - return None - - left, right = 0, len(self.messages) - 1 - while left < right: - mid = (left + right) // 2 - if self.messages[mid].message_info.time < target_time: # type: ignore - left = mid + 1 + node_user_info = UserInfo(node.user_id or "未知用户", node.user_nickname, node.user_cardname) + id_content_map.mapping[node.message_id] = (task, node_user_info) + task_list.append(task) + node_user_info_list.append(node_user_info) + results = await asyncio.gather(*task_list, return_exceptions=True) + forward_texts = [] + for idx, result in enumerate(results): + if isinstance(result, BaseException): + logger.error(f"处理转发消息组件时发生错误: {result}") else: - right = mid + usr_info = node_user_info_list[idx] + msg_sender_name = usr_info.user_cardname or usr_info.user_nickname or usr_info.user_id or "未知用户" + forward_texts.append(f"{'-' * recursion_depth * 2} 【{msg_sender_name}】: {result}") + return "【合并转发消息: \n" + "\n".join(forward_texts) + "\n】" - return self.messages[left] - - def clear_messages(self) -> None: - """清空所有消息""" - self.messages.clear() - - def remove_message(self, message: MessageSending) -> bool: - """移除指定消息""" - if message in self.messages: - self.messages.remove(message) - return True - return False - - def __str__(self) -> str: - return f"MessageSet(id={self.message_id}, count={len(self.messages)})" - - def __len__(self) -> int: - return len(self.messages) - - -def message_recv_from_dict(message_dict: dict) -> MessageRecv: - return MessageRecv(message_dict) - - -def message_from_db_dict(db_dict: dict) -> MessageRecv: - """从数据库字典创建MessageRecv实例""" - # 转换扁平的数据库字典为嵌套结构 - message_info_dict = { - "platform": db_dict.get("chat_info_platform"), - "message_id": db_dict.get("message_id"), - "time": db_dict.get("time"), - "group_info": { - "platform": db_dict.get("chat_info_group_platform"), - "group_id": db_dict.get("chat_info_group_id"), - "group_name": db_dict.get("chat_info_group_name"), - }, - "user_info": { - "platform": db_dict.get("user_platform"), - "user_id": db_dict.get("user_id"), - "user_nickname": db_dict.get("user_nickname"), - "user_cardname": db_dict.get("user_cardname"), - }, - } - - processed_text = db_dict.get("processed_plain_text", "") - - # 构建 MessageRecv 需要的字典 - recv_dict = { - "message_info": message_info_dict, - "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 - "raw_message": None, # 数据库中未存储原始消息 - "processed_plain_text": processed_text, - } - - # 创建 MessageRecv 实例 - msg = MessageRecv(recv_dict) - - # 从数据库字典中填充其他可选字段 - msg.interest_value = db_dict.get("interest_value", 0.0) - msg.is_mentioned = db_dict.get("is_mentioned") - msg.priority_mode = db_dict.get("priority_mode", "interest") - msg.priority_info = db_dict.get("priority_info") - msg.is_emoji = db_dict.get("is_emoji", False) - msg.is_picid = db_dict.get("is_picid", False) - - return msg + async def _process_multiple_components( + self, components: Sequence[StandardMessageComponents], id_content_map: MsgIDMapping, recursion_depth: int = 0 + ) -> str: + tasks = [ + self.process_single_component(component, id_content_map, recursion_depth=recursion_depth) + for component in components + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + processed_texts: List[str] = [] + for result in results: + if isinstance(result, BaseException): + logger.error(f"处理消息组件时发生错误: {result}") + else: + processed_texts.append(result) + return " ".join(processed_texts) diff --git a/src/common/data_models/message_component_data_model.py b/src/common/data_models/message_component_data_model.py index 8290fdbe..f54dee0a 100644 --- a/src/common/data_models/message_component_data_model.py +++ b/src/common/data_models/message_component_data_model.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo -from typing import Optional, List, Union, Dict, Any +from typing import Optional, List, Union, Dict, Any, Sequence import asyncio import hashlib @@ -142,9 +142,9 @@ class AtComponent(BaseMessageComponentModel): ) -> None: self.target_user_id = target_user_id """目标用户ID""" - self.target_user_nickname = target_user_nickname + self.target_user_nickname: Optional[str] = target_user_nickname """目标用户昵称""" - self.target_user_cardname = target_user_cardname + self.target_user_cardname: Optional[str] = target_user_cardname """目标用户备注名""" assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型" @@ -159,10 +159,25 @@ class ReplyComponent(BaseMessageComponentModel): def format_name(self) -> str: return "reply" - def __init__(self, target_message_id: str) -> None: + def __init__( + self, + target_message_id: str, + target_message_content: Optional[str] = None, + target_message_sender_id: Optional[str] = None, + target_message_sender_nickname: Optional[str] = None, + target_message_sender_cardname: Optional[str] = None, + ) -> None: assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型" self.target_message_id = target_message_id """目标消息ID""" + self.target_message_content: Optional[str] = target_message_content + """目标消息内容""" + self.target_message_sender_id: Optional[str] = target_message_sender_id + """目标消息发送者ID""" + self.target_message_sender_nickname: Optional[str] = target_message_sender_nickname + """目标消息发送者昵称""" + self.target_message_sender_cardname: Optional[str] = target_message_sender_cardname + """目标消息发送者群昵称""" async def to_seg(self) -> Seg: return Seg(type="reply", data=self.target_message_id) @@ -224,7 +239,7 @@ class ForwardComponent(BaseMessageComponentModel): self, user_nickname: str, message_id: str, - content: List[StandardMessageComponents], + content: Sequence[StandardMessageComponents], user_id: Optional[str] = None, user_cardname: Optional[str] = None, ): @@ -232,7 +247,7 @@ class ForwardComponent(BaseMessageComponentModel): """转发节点的发送者昵称""" self.message_id: str = message_id """转发节点的消息ID""" - self.content: List[StandardMessageComponents] = content + self.content: Sequence[StandardMessageComponents] = content """消息内容""" self.user_id: Optional[str] = user_id """转发节点的发送者ID,可能为 None""" @@ -249,7 +264,7 @@ class ForwardComponent(BaseMessageComponentModel): class MessageSequence: """消息组件序列,包含一个消息中的所有组件,按照顺序排列""" - def __init__(self, components: List[StandardMessageComponents]): + def __init__(self, components: Sequence[StandardMessageComponents]): """ 创建一个消息组件序列 @@ -259,16 +274,16 @@ class MessageSequence: 因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。 如果需要对组件进行去重或校验,还请在使用时自行处理。 """ - self.components: List[StandardMessageComponents] = components + self.components: Sequence[StandardMessageComponents] = components def to_dict(self) -> List[Dict[str, Any]]: """将消息序列转换为字典列表格式,便于存储或传输""" return [self._item_2_dict(comp) for comp in self.components] @classmethod - def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence": + def from_dict(cls, data: List[Dict[str, Any]]): """从字典列表格式创建消息序列实例""" - components: List[StandardMessageComponents] = [] + components: Sequence[StandardMessageComponents] = [] components.extend(cls._dict_2_item(item) for item in data) return cls(components=components) diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 96481baa..8ac898c3 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -1,5 +1,5 @@ from maim_message import MessageBase, Seg -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Sequence import base64 import hashlib @@ -35,7 +35,7 @@ class MessageUtils: def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence: """从maim_message.MessageBase.message_segment转换为MessageSequence""" raw_msg_seq = message.message_segment - components: List[StandardMessageComponents] = [] + components: Sequence[StandardMessageComponents] = [] if not raw_msg_seq: return MessageSequence(components) if raw_msg_seq.type == "seglist": diff --git a/src/common/utils/utils_person.py b/src/common/utils/utils_person.py index 6089e2d2..f0d9886e 100644 --- a/src/common/utils/utils_person.py +++ b/src/common/utils/utils_person.py @@ -20,7 +20,7 @@ class PersonUtils: """根据person_id获取用户信息""" try: with get_db_session() as session: - statement = select(PersonInfo).filter_by(person_id=person_id) + statement = select(PersonInfo).filter_by(person_id=person_id).limit(1) if result := session.exec(statement).first(): return MaiPersonInfo.from_db_instance(result) except Exception as e: