全新的process方法完成(Message其他部分仍未完成);对应测试;调整部分注释;数据库检索优化

pull/1496/merge
UnCLAS-Prommer 2026-02-23 21:29:17 +08:00
parent 698b8355a4
commit 0d07e85434
No known key found for this signature in database
7 changed files with 627 additions and 546 deletions

View File

@ -0,0 +1,420 @@
import sys
import asyncio
import pytest
import importlib
import importlib.util
from types import ModuleType
from pathlib import Path
from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
from src.chat.message_receive.message import (
SessionMessage,
TextComponent,
ImageComponent,
EmojiComponent,
VoiceComponent,
AtComponent,
ReplyComponent,
ForwardNodeComponent,
StandardMessageComponents,
)
class DummyLogger:
def __init__(self) -> None:
self.logging_record = []
def debug(self, msg):
print(f"DEBUG: {msg}")
self.logging_record.append(f"DEBUG: {msg}")
def info(self, msg):
print(f"INFO: {msg}")
self.logging_record.append(f"INFO: {msg}")
def warning(self, msg):
print(f"WARNING: {msg}")
self.logging_record.append(f"WARNING: {msg}")
def error(self, msg):
print(f"ERROR: {msg}")
self.logging_record.append(f"ERROR: {msg}")
def critical(self, msg):
print(f"CRITICAL: {msg}")
self.logging_record.append(f"CRITICAL: {msg}")
def get_logger(name):
return DummyLogger()
class DummyDBSession:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def exec(self, statement):
return self
def first(self):
return None
def commit(self):
pass
def all(self):
return []
def get_db_session():
return DummyDBSession()
def get_manual_db_session():
return DummyDBSession()
class DummySelect:
def __init__(self, model):
self.model = model
def filter_by(self, **kwargs):
return self
def where(self, condition):
return self
def limit(self, n):
return self
def select(model):
return DummySelect(model)
async def dummy_get_voice_text(binary_data):
return None # 可以根据需要返回模拟的文本结果
class DummyPersonUtils:
@staticmethod
def get_person_info_by_user_id_and_platform(user_id, platform):
return None # 可以根据需要返回模拟的用户信息
def setup_mocks(monkeypatch):
def _stub_module(name: str) -> ModuleType:
module = ModuleType(name)
monkeypatch.setitem(sys.modules, name, module)
return module
# src.common.logger
logger_mod = _stub_module("src.common.logger")
# Mock the logger
logger_mod.get_logger = get_logger
db_mod = _stub_module("src.common.database.database")
db_mod.get_db_session = get_db_session
db_mod.get_manual_db_session = get_manual_db_session
emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager")
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
msg_utils_mod = _stub_module("src.common.utils.utils_message")
msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
voice_utils_mod.get_voice_text = dummy_get_voice_text
person_utils_mod = _stub_module("src.common.utils.utils_person")
person_utils_mod.PersonUtils = DummyPersonUtils
def load_message_via_file(monkeypatch):
setup_mocks(monkeypatch)
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
spec = importlib.util.spec_from_file_location("message", file_path)
message_module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "message_module", message_module)
spec.loader.exec_module(message_module)
message_module.select = select
SessionMessageClass = message_module.SessionMessage
TextComponentClass = message_module.TextComponent
ImageComponentClass = message_module.ImageComponent
EmojiComponentClass = message_module.EmojiComponent
VoiceComponentClass = message_module.VoiceComponent
AtComponentClass = message_module.AtComponent
ReplyComponentClass = message_module.ReplyComponent
ForwardNodeComponentClass = message_module.ForwardNodeComponent
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
globals()["SessionMessage"] = SessionMessageClass
globals()["TextComponent"] = TextComponentClass
globals()["ImageComponent"] = ImageComponentClass
globals()["EmojiComponent"] = EmojiComponentClass
globals()["VoiceComponent"] = VoiceComponentClass
globals()["AtComponent"] = AtComponentClass
globals()["ReplyComponent"] = ReplyComponentClass
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
globals()["MessageSequence"] = MessageSequenceClass
globals()["ForwardComponent"] = ForwardComponentClass
return message_module
@pytest.mark.asyncio
async def test_process(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "Hello, world!"
@pytest.mark.asyncio
async def test_multiple_text(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")]
await msg.process()
assert msg.processed_plain_text == "Hello, world!"
@pytest.mark.asyncio
async def test_image(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[发了一张图片,网卡了加载不出来] Hello, world!"
@pytest.mark.asyncio
async def test_emoji(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[发了一个表情,网卡了加载不出来] Hello, world!"
@pytest.mark.asyncio
async def test_voice(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!"
@pytest.mark.asyncio
async def test_at_component(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "@114514 Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_fail_to_fetch(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_success(monkeypatch):
module_msg = load_message_via_file(monkeypatch)
class DummyDBSessionWithReply(DummyDBSession):
def exec(self, s):
return self
def first(inner_self):
class DummyRecord:
processed_plain_text = "原消息内容"
user_cardname = "cardname123"
user_nickname = "nickname123"
user_id = "userid123"
return DummyRecord()
module_msg.get_db_session = lambda: DummyDBSessionWithReply()
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_with_db_fail(monkeypatch):
module_msg = load_message_via_file(monkeypatch)
class DummyDBSessionWithError(DummyDBSession):
def exec(self, s):
raise Exception("数据库查询失败")
module_msg.get_db_session = lambda: DummyDBSessionWithError()
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
assert any("数据库查询失败" in log for log in module_msg.logger.logging_record)
@pytest.mark.asyncio
async def test_forward_component(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[TextComponent("转发消息2")],
),
]
),
TextComponent("Hello, world!"),
]
await msg.process()
print("Processed plain text:", msg.processed_plain_text)
expected_forward_text = """【合并转发消息:
-- cardname1: 转发消息1
-- cardname2: 转发消息2
Hello, world!"""
assert msg.processed_plain_text == expected_forward_text
@pytest.mark.asyncio
async def test_forward_with_reply(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
),
]
),
TextComponent("Hello, world!"),
]
await msg.process()
assert (
msg.processed_plain_text
== """【合并转发消息:
-- cardname1: 转发消息1
-- cardname2: [回复了cardname1的消息: 转发消息1] 转发消息2
Hello, world!"""
)
@pytest.mark.asyncio
async def test_multiple_reply_with_delay_in_forward(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now())
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
async def delayed_get_voice_text(binary_data):
await asyncio.sleep(0.5) # 模拟延迟
return "这是语音转文本的结果"
sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
),
ForwardComponent(
message_id="msg3",
user_id="user3",
user_nickname="nickname3",
user_cardname="cardname3",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")],
),
]
),
]
await msg.process()
expected_text = """【合并转发消息:
-- cardname1: [语音: 这是语音转文本的结果] 转发消息1
-- cardname2: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2
-- cardname3: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3
"""
assert msg.processed_plain_text == expected_text

View File

@ -67,6 +67,9 @@ class EmojiManager:
emoji_hash (Optional[str]): 表情包的哈希值如果提供了哈希值则优先使用哈希值查找表情包描述 emoji_hash (Optional[str]): 表情包的哈希值如果提供了哈希值则优先使用哈希值查找表情包描述
Returns: Returns:
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包则返回包含描述和情感标签的元组若没找到则尝试构建表情包描述并返回如果构建失败则返回 None return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包则返回包含描述和情感标签的元组若没找到则尝试构建表情包描述并返回如果构建失败则返回 None
Raises:
ValueError: 如果既没有提供表情包字节数据也没有提供表情包哈希值则抛出异常
Exception: 如果在缓存表情包的过程中发生错误则抛出异常
""" """
# 先查找 # 先查找
if emoji_hash is None and emoji_bytes is not None: if emoji_hash is None and emoji_bytes is not None:

View File

@ -103,7 +103,7 @@ class ChatManager:
# 内存没有就找db # 内存没有就找db
try: try:
with get_db_session() as db_session: with get_db_session() as db_session:
statement = select(ChatSession).filter_by(session_id=session_id) statement = select(ChatSession).filter_by(session_id=session_id).limit(1)
if result := db_session.exec(statement).first(): if result := db_session.exec(statement).first():
session = BotChatSession.from_db_instance(result) session = BotChatSession.from_db_instance(result)
self.sessions[session.session_id] = session self.sessions[session.session_id] = session
@ -229,7 +229,7 @@ class ChatManager:
"""将会话记录保存到数据库""" """将会话记录保存到数据库"""
with get_db_session() as db_session: with get_db_session() as db_session:
db_instance = session.to_db_instance() db_instance = session.to_db_instance()
statement = select(ChatSession).filter_by(session_id=db_instance.session_id) statement = select(ChatSession).filter_by(session_id=db_instance.session_id).limit(1)
if result := db_session.exec(statement).first(): if result := db_session.exec(statement).first():
result.created_timestamp = db_instance.created_timestamp result.created_timestamp = db_instance.created_timestamp
result.last_active_timestamp = db_instance.last_active_timestamp result.last_active_timestamp = db_instance.last_active_timestamp

View File

@ -1,561 +1,204 @@
import time from asyncio import Task
import asyncio
import urllib3
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install from rich.traceback import install
from typing import Optional, Any, List from sqlmodel import select
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from typing import List, Dict, Tuple, Sequence
import asyncio
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.common.database.database import get_db_session
from src.chat.utils.utils_image import get_image_manager from src.common.database.database_model import Messages
from src.chat.utils.utils_voice import get_voice_text from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from .chat_stream import ChatStream from src.common.data_models.message_component_data_model import (
TextComponent,
ImageComponent,
EmojiComponent,
AtComponent,
ReplyComponent,
VoiceComponent,
ForwardNodeComponent,
StandardMessageComponents,
)
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("chat_message") logger = get_logger("chat_message")
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# VLM 处理并发限制(避免同时处理太多图片导致卡死) class MsgIDMapping:
_vlm_semaphore = asyncio.Semaphore(3) def __init__(self):
self.mapping: Dict[str, Tuple[str | Task, UserInfo]] = {}
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass class SessionMessage(MaiMessage):
class Message(MessageBase): async def process(self):
chat_stream: "ChatStream" = None # type: ignore """处理消息内容,识别消息内容并转化为文本"""
reply: Optional["Message"] = None tasks = [self.process_single_component(component, MsgIDMapping()) for component in self.raw_message.components]
processed_plain_text: str = "" results = await asyncio.gather(*tasks, return_exceptions=True)
processed_texts: List[str] = []
for result in results:
if isinstance(result, BaseException):
logger.error(f"处理消息组件时发生错误: {result}")
else:
processed_texts.append(result)
self.processed_plain_text = " ".join(processed_texts)
def __init__( async def process_single_component(
self, self, component: StandardMessageComponents, id_content_map: MsgIDMapping, recursion_depth: int = 0
message_id: str, ) -> str:
chat_stream: "ChatStream", if isinstance(component, TextComponent):
user_info: UserInfo, return component.text
message_segment: Optional[Seg] = None, elif isinstance(component, ImageComponent):
timestamp: Optional[float] = None, return await self.process_image_component(component)
reply: Optional["MessageRecv"] = None, elif isinstance(component, EmojiComponent):
processed_plain_text: str = "", return await self.process_emoji_component(component)
): elif isinstance(component, AtComponent):
# 使用传入的时间戳或当前时间 return await self.process_at_component(component)
current_timestamp = timestamp if timestamp is not None else round(time.time(), 3) elif isinstance(component, VoiceComponent):
# 构造基础消息信息 return await self.process_voice_component(component)
message_info = BaseMessageInfo( elif isinstance(component, ReplyComponent):
platform=chat_stream.platform, return await self.process_reply_component(component, id_content_map)
message_id=message_id, elif isinstance(component, ForwardNodeComponent):
time=current_timestamp, return await self.process_forward_component(component, id_content_map, recursion_depth=recursion_depth + 1)
group_info=chat_stream.group_info,
user_info=user_info,
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
# 回复消息
self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表 - 使用并行处理提升性能
tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore
results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = []
for result in results:
if isinstance(result, Exception):
logger.error(f"处理消息段时出错: {result}")
continue
if result:
segments_text.append(result)
return " ".join(segments_text)
elif segment.type == "forward":
# 处理转发消息 - 使用并行处理
async def process_forward_node(node_dict):
message = MessageBase.from_dict(node_dict) # type: ignore
processed_text = await self._process_message_segments(message.message_segment)
if processed_text:
return f"{global_config.bot.nickname}: {processed_text}"
return None
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = []
for result in results:
if isinstance(result, Exception):
logger.error(f"处理转发节点时出错: {result}")
continue
if result:
segments_text.append(result)
return "[合并消息]: " + "\n-- ".join(segments_text)
else: else:
# 处理单个消息段 raise NotImplementedError(f"暂时不支持的消息组件类型: {type(component)}")
return await self._process_single_segment(segment) # type: ignore
@abstractmethod async def process_image_component(self, component: ImageComponent) -> str:
async def _process_single_segment(self, segment) -> str: if component.content: # 先检查是否处理过
pass return component.content
from src.chat.image_system.image_manager import image_manager
# 获取描述
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_mentioned = None
self.is_at = False
self.reply_probability_boost = 0.0
self.is_notify = False
self.is_command = False
self.intercept_message_level = 0
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
self.key_words = []
self.key_words_lite = []
# 兼容适配器通过 additional_config 传入的 @ 标记
try: try:
msg_info_dict = message_dict.get("message_info", {}) desc = await image_manager.get_image_description(image_bytes=component.binary_data)
add_cfg = msg_info_dict.get("additional_config") or {}
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
# 标记为被提及,提高后续回复优先级
self.is_mentioned = True # type: ignore
except Exception: except Exception:
pass desc = None
def update_chat_stream(self, chat_stream: "ChatStream"): content = f"[图片:{desc}]" if desc else "[发了一张图片,网卡了加载不出来]"
self.chat_stream = chat_stream component.content = content
return content
async def process(self) -> None: async def process_emoji_component(self, component: EmojiComponent) -> str:
"""处理消息内容,生成纯文本和详细文本 if component.content: # 先检查是否处理过
return component.content
from src.chat.emoji_system.emoji_manager import emoji_manager
这个方法必须在创建实例后显式调用因为它包含异步操作 # 获取表情包描述
"""
# print(f"self.message_segment: {self.message_segment}")
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try: try:
if segment.type == "text": tuple_content = await emoji_manager.get_emoji_description(emoji_bytes=component.binary_data)
self.is_picid = False except Exception:
self.is_emoji = False tuple_content = None
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# 使用 semaphore 限制 VLM 并发,避免同时处理太多图片
async with _vlm_semaphore:
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
if isinstance(segment.data, str):
# 使用 semaphore 限制 VLM 并发
async with _vlm_semaphore:
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "video_card":
# 处理视频卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
file_name = segment.data.get("file", "未知视频")
file_size = segment.data.get("file_size", "")
url = segment.data.get("url", "")
text = f"[视频: {file_name}"
if file_size:
text += f", 大小: {file_size}字节"
text += "]"
if url:
text += f" 链接: {url}"
return text
return "[视频]"
elif segment.type == "music_card":
# 处理音乐卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "未知歌曲")
singer = segment.data.get("singer", "")
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
jump_url = segment.data.get("jump_url", "")
music_url = segment.data.get("music_url", "")
text = f"[音乐: {title}"
if singer:
text += f" - {singer}"
if tag:
text += f" ({tag})"
text += "]"
if jump_url:
text += f" 跳转链接: {jump_url}"
if music_url:
text += f" 音乐链接: {music_url}"
return text
return "[音乐]"
elif segment.type == "miniapp_card":
# 处理小程序分享卡片如B站视频分享
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "") # 小程序名称
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"
if desc:
text += f" {desc}"
if source_url:
text += f" 链接: {source_url}"
elif url:
text += f" 链接: {url}"
return text
return "[小程序分享]"
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
if tuple_content:
desc, _ = tuple_content
content = f"[表情包: {desc}]"
else:
content = "[发了一个表情,网卡了加载不出来]"
component.content = content
return content
@dataclass async def process_at_component(self, component: AtComponent) -> str:
class MessageProcessBase(Message): if component.target_user_cardname:
"""消息处理基类,用于处理中和发送中的消息""" return f"@{component.target_user_cardname}"
elif component.target_user_nickname:
return f"@{component.target_user_nickname}"
from src.common.utils.utils_person import PersonUtils
def __init__( if person_info := PersonUtils.get_person_info_by_user_id_and_platform(component.target_user_id, self.platform):
component.target_user_nickname = component.target_user_nickname or person_info.user_nickname
if self.message_info.group_info and person_info.group_cardname_list:
for group_card in person_info.group_cardname_list:
if group_card.group_id == self.message_info.group_info.group_id:
component.target_user_cardname = group_card.group_cardname
break
if component.target_user_cardname:
return f"@{component.target_user_cardname}"
elif component.target_user_nickname:
return f"@{component.target_user_nickname}"
else:
return f"@{component.target_user_id}"
async def process_voice_component(self, component: VoiceComponent) -> str:
if component.content: # 先检查是否处理过
return component.content
from src.common.utils.utils_voice import get_voice_text
text = await get_voice_text(component.binary_data)
content = "[语音消息,转录失败]" if text is None else f"[语音: {text}]"
component.content = content
return content
async def process_reply_component(
self, self,
message_id: str, component: ReplyComponent,
chat_stream: "ChatStream", id_content_map: MsgIDMapping,
bot_user_info: UserInfo, ) -> str:
message_segment: Optional[Seg] = None, if component.target_message_content:
reply: Optional["MessageRecv"] = None, return component.target_message_content
thinking_start_time: float = 0, if result_item := id_content_map.mapping.get(component.target_message_id):
timestamp: Optional[float] = None, content, sender_info = result_item
): if isinstance(content, Task):
# 调用父类初始化,传递时间戳 content = await content
super().__init__( id_content_map.mapping[component.target_message_id] = (content, sender_info) # 更新为实际内容
message_id=message_id, component.target_message_content = content
timestamp=timestamp, tgt_msg_s_name = sender_info.user_cardname or sender_info.user_nickname or sender_info.user_id
chat_stream=chat_stream, component.target_message_sender_cardname = sender_info.user_cardname
user_info=bot_user_info, component.target_message_sender_nickname = sender_info.user_nickname
message_segment=message_segment, component.target_message_sender_id = sender_info.user_id
reply=reply, return f"[回复了{tgt_msg_s_name}的消息: {content}]"
) else:
try:
with get_db_session() as session:
statement = select(Messages).filter_by(message_id=component.target_message_id).limit(1)
if db_msg := session.exec(statement).first():
component.target_message_content = db_msg.processed_plain_text
component.target_message_sender_cardname = db_msg.user_cardname
component.target_message_sender_nickname = db_msg.user_nickname
component.target_message_sender_id = db_msg.user_id
tgt_msg_s_name = db_msg.user_cardname or db_msg.user_nickname or db_msg.user_id
return f"[回复了{tgt_msg_s_name}的消息: {db_msg.processed_plain_text}]"
except Exception as e:
logger.error(f"查询回复消息时发生错误: {e}")
# 处理状态相关属性 return "[回复了一条消息,但原消息已无法访问]"
self.thinking_start_time = thinking_start_time
self.thinking_time = 0
def update_thinking_time(self) -> float: async def process_forward_component(
"""更新思考时间""" self, component: ForwardNodeComponent, id_content_map: MsgIDMapping, recursion_depth: int = 0
self.thinking_time = round(time.time() - self.thinking_start_time, 2) ) -> str:
return self.thinking_time task_list: List[Task] = []
node_user_info_list: List[UserInfo] = []
async def _process_single_segment(self, segment: Seg) -> str: for node in component.forward_components:
"""处理单个消息段 task = asyncio.create_task(
self._process_multiple_components(node.content, id_content_map, recursion_depth + 1)
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
return await get_image_manager().get_image_description(segment.data)
return "[图片,网卡了加载不出来]"
elif segment.type == "emoji":
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_tag(segment.data)
return "[表情,网卡了加载不出来]"
elif segment.type == "voice":
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "at":
return f"[@{segment.data}]"
elif segment.type == "reply":
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return ""
else:
return f"[{segment.type}:{str(segment.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
reply: Optional["MessageRecv"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
selected_expressions: Optional[List[int]] = None,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
thinking_start_time=thinking_start_time,
)
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
self.reply_to = reply_to
# 用于显示发送内容与显示不一致的情况
self.display_message = display_message
self.interest_value = 0.0
self.selected_expressions = selected_expressions
def build_reply(self):
"""设置回复消息"""
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
) )
node_user_info = UserInfo(node.user_id or "未知用户", node.user_nickname, node.user_cardname)
async def process(self) -> None: id_content_map.mapping[node.message_id] = (task, node_user_info)
"""处理消息内容,生成纯文本和详细文本""" task_list.append(task)
if self.message_segment: node_user_info_list.append(node_user_info)
self.processed_plain_text = await self._process_message_segments(self.message_segment) results = await asyncio.gather(*task_list, return_exceptions=True)
forward_texts = []
def to_dict(self): for idx, result in enumerate(results):
ret = super().to_dict() if isinstance(result, BaseException):
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict() logger.error(f"处理转发消息组件时发生错误: {result}")
return ret
def is_private_message(self) -> bool:
"""判断是否为私聊消息"""
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: list[MessageSending] = []
self.time = round(time.time(), 3) # 保留3位小数
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else: else:
right = mid usr_info = node_user_info_list[idx]
msg_sender_name = usr_info.user_cardname or usr_info.user_nickname or usr_info.user_id or "未知用户"
forward_texts.append(f"{'-' * recursion_depth * 2}{msg_sender_name}】: {result}")
return "【合并转发消息: \n" + "\n".join(forward_texts) + "\n"
return self.messages[left] async def _process_multiple_components(
self, components: Sequence[StandardMessageComponents], id_content_map: MsgIDMapping, recursion_depth: int = 0
def clear_messages(self) -> None: ) -> str:
"""清空所有消息""" tasks = [
self.messages.clear() self.process_single_component(component, id_content_map, recursion_depth=recursion_depth)
for component in components
def remove_message(self, message: MessageSending) -> bool: ]
"""移除指定消息""" results = await asyncio.gather(*tasks, return_exceptions=True)
if message in self.messages: processed_texts: List[str] = []
self.messages.remove(message) for result in results:
return True if isinstance(result, BaseException):
return False logger.error(f"处理消息组件时发生错误: {result}")
else:
def __str__(self) -> str: processed_texts.append(result)
return f"MessageSet(id={self.message_id}, count={len(self.messages)})" return " ".join(processed_texts)
def __len__(self) -> int:
return len(self.messages)
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
return MessageRecv(message_dict)
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
# 转换扁平的数据库字典为嵌套结构
message_info_dict = {
"platform": db_dict.get("chat_info_platform"),
"message_id": db_dict.get("message_id"),
"time": db_dict.get("time"),
"group_info": {
"platform": db_dict.get("chat_info_group_platform"),
"group_id": db_dict.get("chat_info_group_id"),
"group_name": db_dict.get("chat_info_group_name"),
},
"user_info": {
"platform": db_dict.get("user_platform"),
"user_id": db_dict.get("user_id"),
"user_nickname": db_dict.get("user_nickname"),
"user_cardname": db_dict.get("user_cardname"),
},
}
processed_text = db_dict.get("processed_plain_text", "")
# 构建 MessageRecv 需要的字典
recv_dict = {
"message_info": message_info_dict,
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
}
# 创建 MessageRecv 实例
msg = MessageRecv(recv_dict)
# 从数据库字典中填充其他可选字段
msg.interest_value = db_dict.get("interest_value", 0.0)
msg.is_mentioned = db_dict.get("is_mentioned")
msg.priority_mode = db_dict.get("priority_mode", "interest")
msg.priority_info = db_dict.get("priority_info")
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
from typing import Optional, List, Union, Dict, Any from typing import Optional, List, Union, Dict, Any, Sequence
import asyncio import asyncio
import hashlib import hashlib
@ -142,9 +142,9 @@ class AtComponent(BaseMessageComponentModel):
) -> None: ) -> None:
self.target_user_id = target_user_id self.target_user_id = target_user_id
"""目标用户ID""" """目标用户ID"""
self.target_user_nickname = target_user_nickname self.target_user_nickname: Optional[str] = target_user_nickname
"""目标用户昵称""" """目标用户昵称"""
self.target_user_cardname = target_user_cardname self.target_user_cardname: Optional[str] = target_user_cardname
"""目标用户备注名""" """目标用户备注名"""
assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型" assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型"
@ -159,10 +159,25 @@ class ReplyComponent(BaseMessageComponentModel):
def format_name(self) -> str: def format_name(self) -> str:
return "reply" return "reply"
def __init__(self, target_message_id: str) -> None: def __init__(
self,
target_message_id: str,
target_message_content: Optional[str] = None,
target_message_sender_id: Optional[str] = None,
target_message_sender_nickname: Optional[str] = None,
target_message_sender_cardname: Optional[str] = None,
) -> None:
assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型" assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型"
self.target_message_id = target_message_id self.target_message_id = target_message_id
"""目标消息ID""" """目标消息ID"""
self.target_message_content: Optional[str] = target_message_content
"""目标消息内容"""
self.target_message_sender_id: Optional[str] = target_message_sender_id
"""目标消息发送者ID"""
self.target_message_sender_nickname: Optional[str] = target_message_sender_nickname
"""目标消息发送者昵称"""
self.target_message_sender_cardname: Optional[str] = target_message_sender_cardname
"""目标消息发送者群昵称"""
async def to_seg(self) -> Seg: async def to_seg(self) -> Seg:
return Seg(type="reply", data=self.target_message_id) return Seg(type="reply", data=self.target_message_id)
@ -224,7 +239,7 @@ class ForwardComponent(BaseMessageComponentModel):
self, self,
user_nickname: str, user_nickname: str,
message_id: str, message_id: str,
content: List[StandardMessageComponents], content: Sequence[StandardMessageComponents],
user_id: Optional[str] = None, user_id: Optional[str] = None,
user_cardname: Optional[str] = None, user_cardname: Optional[str] = None,
): ):
@ -232,7 +247,7 @@ class ForwardComponent(BaseMessageComponentModel):
"""转发节点的发送者昵称""" """转发节点的发送者昵称"""
self.message_id: str = message_id self.message_id: str = message_id
"""转发节点的消息ID""" """转发节点的消息ID"""
self.content: List[StandardMessageComponents] = content self.content: Sequence[StandardMessageComponents] = content
"""消息内容""" """消息内容"""
self.user_id: Optional[str] = user_id self.user_id: Optional[str] = user_id
"""转发节点的发送者ID可能为 None""" """转发节点的发送者ID可能为 None"""
@ -249,7 +264,7 @@ class ForwardComponent(BaseMessageComponentModel):
class MessageSequence: class MessageSequence:
"""消息组件序列,包含一个消息中的所有组件,按照顺序排列""" """消息组件序列,包含一个消息中的所有组件,按照顺序排列"""
def __init__(self, components: List[StandardMessageComponents]): def __init__(self, components: Sequence[StandardMessageComponents]):
""" """
创建一个消息组件序列 创建一个消息组件序列
@ -259,16 +274,16 @@ class MessageSequence:
因此也可以包含多个`ReplyComponent`组件例如回复多条消息 因此也可以包含多个`ReplyComponent`组件例如回复多条消息
如果需要对组件进行去重或校验还请在使用时自行处理 如果需要对组件进行去重或校验还请在使用时自行处理
""" """
self.components: List[StandardMessageComponents] = components self.components: Sequence[StandardMessageComponents] = components
def to_dict(self) -> List[Dict[str, Any]]: def to_dict(self) -> List[Dict[str, Any]]:
"""将消息序列转换为字典列表格式,便于存储或传输""" """将消息序列转换为字典列表格式,便于存储或传输"""
return [self._item_2_dict(comp) for comp in self.components] return [self._item_2_dict(comp) for comp in self.components]
@classmethod @classmethod
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence": def from_dict(cls, data: List[Dict[str, Any]]):
"""从字典列表格式创建消息序列实例""" """从字典列表格式创建消息序列实例"""
components: List[StandardMessageComponents] = [] components: Sequence[StandardMessageComponents] = []
components.extend(cls._dict_2_item(item) for item in data) components.extend(cls._dict_2_item(item) for item in data)
return cls(components=components) return cls(components=components)

View File

@ -1,5 +1,5 @@
from maim_message import MessageBase, Seg from maim_message import MessageBase, Seg
from typing import List, Tuple, Optional from typing import List, Tuple, Optional, Sequence
import base64 import base64
import hashlib import hashlib
@ -35,7 +35,7 @@ class MessageUtils:
def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence: def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence:
"""从maim_message.MessageBase.message_segment转换为MessageSequence""" """从maim_message.MessageBase.message_segment转换为MessageSequence"""
raw_msg_seq = message.message_segment raw_msg_seq = message.message_segment
components: List[StandardMessageComponents] = [] components: Sequence[StandardMessageComponents] = []
if not raw_msg_seq: if not raw_msg_seq:
return MessageSequence(components) return MessageSequence(components)
if raw_msg_seq.type == "seglist": if raw_msg_seq.type == "seglist":

View File

@ -20,7 +20,7 @@ class PersonUtils:
"""根据person_id获取用户信息""" """根据person_id获取用户信息"""
try: try:
with get_db_session() as session: with get_db_session() as session:
statement = select(PersonInfo).filter_by(person_id=person_id) statement = select(PersonInfo).filter_by(person_id=person_id).limit(1)
if result := session.exec(statement).first(): if result := session.exec(statement).first():
return MaiPersonInfo.from_db_instance(result) return MaiPersonInfo.from_db_instance(result)
except Exception as e: except Exception as e: