mirror of https://github.com/Mai-with-u/MaiBot.git
全新的process方法完成(Message其他部分仍未完成);对应测试;调整部分注释;数据库检索优化
parent
698b8355a4
commit
0d07e85434
|
|
@ -0,0 +1,420 @@
|
|||
import sys
|
||||
import asyncio
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
|
||||
from src.chat.message_receive.message import (
|
||||
SessionMessage,
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
EmojiComponent,
|
||||
VoiceComponent,
|
||||
AtComponent,
|
||||
ReplyComponent,
|
||||
ForwardNodeComponent,
|
||||
StandardMessageComponents,
|
||||
)
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self) -> None:
|
||||
self.logging_record = []
|
||||
|
||||
def debug(self, msg):
|
||||
print(f"DEBUG: {msg}")
|
||||
self.logging_record.append(f"DEBUG: {msg}")
|
||||
|
||||
def info(self, msg):
|
||||
print(f"INFO: {msg}")
|
||||
self.logging_record.append(f"INFO: {msg}")
|
||||
|
||||
def warning(self, msg):
|
||||
print(f"WARNING: {msg}")
|
||||
self.logging_record.append(f"WARNING: {msg}")
|
||||
|
||||
def error(self, msg):
|
||||
print(f"ERROR: {msg}")
|
||||
self.logging_record.append(f"ERROR: {msg}")
|
||||
|
||||
def critical(self, msg):
|
||||
print(f"CRITICAL: {msg}")
|
||||
self.logging_record.append(f"CRITICAL: {msg}")
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
return DummyLogger()
|
||||
|
||||
|
||||
class DummyDBSession:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def exec(self, statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
def commit(self):
|
||||
pass
|
||||
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
|
||||
def get_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
def get_manual_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def filter_by(self, **kwargs):
|
||||
return self
|
||||
|
||||
def where(self, condition):
|
||||
return self
|
||||
|
||||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
def select(model):
|
||||
return DummySelect(model)
|
||||
|
||||
|
||||
async def dummy_get_voice_text(binary_data):
|
||||
return None # 可以根据需要返回模拟的文本结果
|
||||
|
||||
|
||||
class DummyPersonUtils:
|
||||
@staticmethod
|
||||
def get_person_info_by_user_id_and_platform(user_id, platform):
|
||||
return None # 可以根据需要返回模拟的用户信息
|
||||
|
||||
|
||||
def setup_mocks(monkeypatch):
|
||||
def _stub_module(name: str) -> ModuleType:
|
||||
module = ModuleType(name)
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
return module
|
||||
|
||||
# src.common.logger
|
||||
logger_mod = _stub_module("src.common.logger")
|
||||
# Mock the logger
|
||||
logger_mod.get_logger = get_logger
|
||||
|
||||
db_mod = _stub_module("src.common.database.database")
|
||||
db_mod.get_db_session = get_db_session
|
||||
db_mod.get_manual_db_session = get_manual_db_session
|
||||
|
||||
emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager")
|
||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
||||
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
msg_utils_mod = _stub_module("src.common.utils.utils_message")
|
||||
msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
|
||||
voice_utils_mod.get_voice_text = dummy_get_voice_text
|
||||
|
||||
person_utils_mod = _stub_module("src.common.utils.utils_person")
|
||||
person_utils_mod.PersonUtils = DummyPersonUtils
|
||||
|
||||
|
||||
def load_message_via_file(monkeypatch):
|
||||
setup_mocks(monkeypatch)
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
|
||||
spec = importlib.util.spec_from_file_location("message", file_path)
|
||||
message_module = importlib.util.module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, "message_module", message_module)
|
||||
spec.loader.exec_module(message_module)
|
||||
message_module.select = select
|
||||
SessionMessageClass = message_module.SessionMessage
|
||||
TextComponentClass = message_module.TextComponent
|
||||
ImageComponentClass = message_module.ImageComponent
|
||||
EmojiComponentClass = message_module.EmojiComponent
|
||||
VoiceComponentClass = message_module.VoiceComponent
|
||||
AtComponentClass = message_module.AtComponent
|
||||
ReplyComponentClass = message_module.ReplyComponent
|
||||
ForwardNodeComponentClass = message_module.ForwardNodeComponent
|
||||
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
|
||||
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
|
||||
globals()["SessionMessage"] = SessionMessageClass
|
||||
globals()["TextComponent"] = TextComponentClass
|
||||
globals()["ImageComponent"] = ImageComponentClass
|
||||
globals()["EmojiComponent"] = EmojiComponentClass
|
||||
globals()["VoiceComponent"] = VoiceComponentClass
|
||||
globals()["AtComponent"] = AtComponentClass
|
||||
globals()["ReplyComponent"] = ReplyComponentClass
|
||||
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
|
||||
globals()["MessageSequence"] = MessageSequenceClass
|
||||
globals()["ForwardComponent"] = ForwardComponentClass
|
||||
return message_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_text(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[发了一张图片,网卡了加载不出来] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emoji(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[发了一个表情,网卡了加载不出来] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_at_component(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "@114514 Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_fail_to_fetch(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_success(monkeypatch):
|
||||
module_msg = load_message_via_file(monkeypatch)
|
||||
|
||||
class DummyDBSessionWithReply(DummyDBSession):
|
||||
def exec(self, s):
|
||||
return self
|
||||
|
||||
def first(inner_self):
|
||||
class DummyRecord:
|
||||
processed_plain_text = "原消息内容"
|
||||
user_cardname = "cardname123"
|
||||
user_nickname = "nickname123"
|
||||
user_id = "userid123"
|
||||
|
||||
return DummyRecord()
|
||||
|
||||
module_msg.get_db_session = lambda: DummyDBSessionWithReply()
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_with_db_fail(monkeypatch):
|
||||
module_msg = load_message_via_file(monkeypatch)
|
||||
|
||||
class DummyDBSessionWithError(DummyDBSession):
|
||||
def exec(self, s):
|
||||
raise Exception("数据库查询失败")
|
||||
|
||||
module_msg.get_db_session = lambda: DummyDBSessionWithError()
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
||||
assert any("数据库查询失败" in log for log in module_msg.logger.logging_record)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_component(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[TextComponent("转发消息2")],
|
||||
),
|
||||
]
|
||||
),
|
||||
TextComponent("Hello, world!"),
|
||||
]
|
||||
await msg.process()
|
||||
print("Processed plain text:", msg.processed_plain_text)
|
||||
expected_forward_text = """【合并转发消息:
|
||||
-- 【cardname1】: 转发消息1
|
||||
-- 【cardname2】: 转发消息2
|
||||
】 Hello, world!"""
|
||||
assert msg.processed_plain_text == expected_forward_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_with_reply(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
||||
),
|
||||
]
|
||||
),
|
||||
TextComponent("Hello, world!"),
|
||||
]
|
||||
await msg.process()
|
||||
assert (
|
||||
msg.processed_plain_text
|
||||
== """【合并转发消息:
|
||||
-- 【cardname1】: 转发消息1
|
||||
-- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2
|
||||
】 Hello, world!"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_reply_with_delay_in_forward(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
|
||||
async def delayed_get_voice_text(binary_data):
|
||||
await asyncio.sleep(0.5) # 模拟延迟
|
||||
return "这是语音转文本的结果"
|
||||
|
||||
sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text
|
||||
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg3",
|
||||
user_id="user3",
|
||||
user_nickname="nickname3",
|
||||
user_cardname="cardname3",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")],
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
await msg.process()
|
||||
expected_text = """【合并转发消息:
|
||||
-- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1
|
||||
-- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2
|
||||
-- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3
|
||||
】"""
|
||||
assert msg.processed_plain_text == expected_text
|
||||
|
|
@ -67,6 +67,9 @@ class EmojiManager:
|
|||
emoji_hash (Optional[str]): 表情包的哈希值,如果提供了哈希值则优先使用哈希值查找表情包描述
|
||||
Returns:
|
||||
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None
|
||||
Raises:
|
||||
ValueError: 如果既没有提供表情包字节数据,也没有提供表情包哈希值,则抛出异常
|
||||
Exception: 如果在缓存表情包的过程中发生错误,则抛出异常
|
||||
"""
|
||||
# 先查找
|
||||
if emoji_hash is None and emoji_bytes is not None:
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class ChatManager:
|
|||
# 内存没有就找db
|
||||
try:
|
||||
with get_db_session() as db_session:
|
||||
statement = select(ChatSession).filter_by(session_id=session_id)
|
||||
statement = select(ChatSession).filter_by(session_id=session_id).limit(1)
|
||||
if result := db_session.exec(statement).first():
|
||||
session = BotChatSession.from_db_instance(result)
|
||||
self.sessions[session.session_id] = session
|
||||
|
|
@ -229,7 +229,7 @@ class ChatManager:
|
|||
"""将会话记录保存到数据库"""
|
||||
with get_db_session() as db_session:
|
||||
db_instance = session.to_db_instance()
|
||||
statement = select(ChatSession).filter_by(session_id=db_instance.session_id)
|
||||
statement = select(ChatSession).filter_by(session_id=db_instance.session_id).limit(1)
|
||||
if result := db_session.exec(statement).first():
|
||||
result.created_timestamp = db_instance.created_timestamp
|
||||
result.last_active_timestamp = db_instance.last_active_timestamp
|
||||
|
|
|
|||
|
|
@ -1,561 +1,204 @@
|
|||
import time
|
||||
import asyncio
|
||||
import urllib3
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from asyncio import Task
|
||||
from rich.traceback import install
|
||||
from typing import Optional, Any, List
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from sqlmodel import select
|
||||
from typing import List, Dict, Tuple, Sequence
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from .chat_stream import ChatStream
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
EmojiComponent,
|
||||
AtComponent,
|
||||
ReplyComponent,
|
||||
VoiceComponent,
|
||||
ForwardNodeComponent,
|
||||
StandardMessageComponents,
|
||||
)
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_message")
|
||||
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# VLM 处理并发限制(避免同时处理太多图片导致卡死)
|
||||
_vlm_semaphore = asyncio.Semaphore(3)
|
||||
|
||||
# 这个类是消息数据类,用于存储和管理消息数据。
|
||||
# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
||||
# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||
class MsgIDMapping:
|
||||
def __init__(self):
|
||||
self.mapping: Dict[str, Tuple[str | Task, UserInfo]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message(MessageBase):
|
||||
chat_stream: "ChatStream" = None # type: ignore
|
||||
reply: Optional["Message"] = None
|
||||
processed_plain_text: str = ""
|
||||
class SessionMessage(MaiMessage):
|
||||
async def process(self):
|
||||
"""处理消息内容,识别消息内容并转化为文本"""
|
||||
tasks = [self.process_single_component(component, MsgIDMapping()) for component in self.raw_message.components]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
processed_texts: List[str] = []
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"处理消息组件时发生错误: {result}")
|
||||
else:
|
||||
processed_texts.append(result)
|
||||
self.processed_plain_text = " ".join(processed_texts)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 使用传入的时间戳或当前时间
|
||||
current_timestamp = timestamp if timestamp is not None else round(time.time(), 3)
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=current_timestamp,
|
||||
group_info=chat_stream.group_info,
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
# 文本处理相关属性
|
||||
self.processed_plain_text = processed_plain_text
|
||||
|
||||
# 回复消息
|
||||
self.reply = reply
|
||||
|
||||
async def _process_message_segments(self, segment: Seg) -> str:
|
||||
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表 - 使用并行处理提升性能
|
||||
tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
segments_text = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"处理消息段时出错: {result}")
|
||||
continue
|
||||
if result:
|
||||
segments_text.append(result)
|
||||
return " ".join(segments_text)
|
||||
elif segment.type == "forward":
|
||||
# 处理转发消息 - 使用并行处理
|
||||
async def process_forward_node(node_dict):
|
||||
message = MessageBase.from_dict(node_dict) # type: ignore
|
||||
processed_text = await self._process_message_segments(message.message_segment)
|
||||
if processed_text:
|
||||
return f"{global_config.bot.nickname}: {processed_text}"
|
||||
return None
|
||||
|
||||
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
segments_text = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"处理转发节点时出错: {result}")
|
||||
continue
|
||||
if result:
|
||||
segments_text.append(result)
|
||||
return "[合并消息]: " + "\n-- ".join(segments_text)
|
||||
async def process_single_component(
|
||||
self, component: StandardMessageComponents, id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
) -> str:
|
||||
if isinstance(component, TextComponent):
|
||||
return component.text
|
||||
elif isinstance(component, ImageComponent):
|
||||
return await self.process_image_component(component)
|
||||
elif isinstance(component, EmojiComponent):
|
||||
return await self.process_emoji_component(component)
|
||||
elif isinstance(component, AtComponent):
|
||||
return await self.process_at_component(component)
|
||||
elif isinstance(component, VoiceComponent):
|
||||
return await self.process_voice_component(component)
|
||||
elif isinstance(component, ReplyComponent):
|
||||
return await self.process_reply_component(component, id_content_map)
|
||||
elif isinstance(component, ForwardNodeComponent):
|
||||
return await self.process_forward_component(component, id_content_map, recursion_depth=recursion_depth + 1)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment) # type: ignore
|
||||
raise NotImplementedError(f"暂时不支持的消息组件类型: {type(component)}")
|
||||
|
||||
@abstractmethod
|
||||
async def _process_single_segment(self, segment) -> str:
|
||||
pass
|
||||
async def process_image_component(self, component: ImageComponent) -> str:
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.chat.image_system.image_manager import image_manager
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecv(Message):
|
||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
||||
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
"""从MessageCQ的字典初始化
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.is_emoji = False
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = None
|
||||
self.is_at = False
|
||||
self.reply_probability_boost = 0.0
|
||||
self.is_notify = False
|
||||
|
||||
self.is_command = False
|
||||
self.intercept_message_level = 0
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = None # type: ignore
|
||||
|
||||
self.key_words = []
|
||||
self.key_words_lite = []
|
||||
|
||||
# 兼容适配器通过 additional_config 传入的 @ 标记
|
||||
# 获取描述
|
||||
try:
|
||||
msg_info_dict = message_dict.get("message_info", {})
|
||||
add_cfg = msg_info_dict.get("additional_config") or {}
|
||||
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
|
||||
# 标记为被提及,提高后续回复优先级
|
||||
self.is_mentioned = True # type: ignore
|
||||
desc = await image_manager.get_image_description(image_bytes=component.binary_data)
|
||||
except Exception:
|
||||
pass
|
||||
desc = None
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
content = f"[图片:{desc}]" if desc else "[发了一张图片,网卡了加载不出来]"
|
||||
component.content = content
|
||||
return content
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
async def process_emoji_component(self, component: EmojiComponent) -> str:
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
# print(f"self.message_segment: {self.message_segment}")
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
# 获取表情包描述
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# 使用 semaphore 限制 VLM 并发,避免同时处理太多图片
|
||||
async with _vlm_semaphore:
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, str):
|
||||
# 使用 semaphore 限制 VLM 并发
|
||||
async with _vlm_semaphore:
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "video_card":
|
||||
# 处理视频卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get("file", "未知视频")
|
||||
file_size = segment.data.get("file_size", "")
|
||||
url = segment.data.get("url", "")
|
||||
text = f"[视频: {file_name}"
|
||||
if file_size:
|
||||
text += f", 大小: {file_size}字节"
|
||||
text += "]"
|
||||
if url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[视频]"
|
||||
elif segment.type == "music_card":
|
||||
# 处理音乐卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "未知歌曲")
|
||||
singer = segment.data.get("singer", "")
|
||||
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
|
||||
jump_url = segment.data.get("jump_url", "")
|
||||
music_url = segment.data.get("music_url", "")
|
||||
text = f"[音乐: {title}"
|
||||
if singer:
|
||||
text += f" - {singer}"
|
||||
if tag:
|
||||
text += f" ({tag})"
|
||||
text += "]"
|
||||
if jump_url:
|
||||
text += f" 跳转链接: {jump_url}"
|
||||
if music_url:
|
||||
text += f" 音乐链接: {music_url}"
|
||||
return text
|
||||
return "[音乐]"
|
||||
elif segment.type == "miniapp_card":
|
||||
# 处理小程序分享卡片(如B站视频分享)
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "") # 小程序名称
|
||||
desc = segment.data.get("desc", "") # 内容描述
|
||||
source_url = segment.data.get("source_url", "") # 原始链接
|
||||
url = segment.data.get("url", "") # 小程序链接
|
||||
text = "[小程序分享"
|
||||
if title:
|
||||
text += f" - {title}"
|
||||
text += "]"
|
||||
if desc:
|
||||
text += f" {desc}"
|
||||
if source_url:
|
||||
text += f" 链接: {source_url}"
|
||||
elif url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[小程序分享]"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
tuple_content = await emoji_manager.get_emoji_description(emoji_bytes=component.binary_data)
|
||||
except Exception:
|
||||
tuple_content = None
|
||||
|
||||
if tuple_content:
|
||||
desc, _ = tuple_content
|
||||
content = f"[表情包: {desc}]"
|
||||
else:
|
||||
content = "[发了一个表情,网卡了加载不出来]"
|
||||
component.content = content
|
||||
return content
|
||||
|
||||
@dataclass
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
async def process_at_component(self, component: AtComponent) -> str:
|
||||
if component.target_user_cardname:
|
||||
return f"@{component.target_user_cardname}"
|
||||
elif component.target_user_nickname:
|
||||
return f"@{component.target_user_nickname}"
|
||||
from src.common.utils.utils_person import PersonUtils
|
||||
|
||||
def __init__(
|
||||
if person_info := PersonUtils.get_person_info_by_user_id_and_platform(component.target_user_id, self.platform):
|
||||
component.target_user_nickname = component.target_user_nickname or person_info.user_nickname
|
||||
if self.message_info.group_info and person_info.group_cardname_list:
|
||||
for group_card in person_info.group_cardname_list:
|
||||
if group_card.group_id == self.message_info.group_info.group_id:
|
||||
component.target_user_cardname = group_card.group_cardname
|
||||
break
|
||||
if component.target_user_cardname:
|
||||
return f"@{component.target_user_cardname}"
|
||||
elif component.target_user_nickname:
|
||||
return f"@{component.target_user_nickname}"
|
||||
else:
|
||||
return f"@{component.target_user_id}"
|
||||
|
||||
async def process_voice_component(self, component: VoiceComponent) -> str:
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.common.utils.utils_voice import get_voice_text
|
||||
|
||||
text = await get_voice_text(component.binary_data)
|
||||
content = "[语音消息,转录失败]" if text is None else f"[语音: {text}]"
|
||||
component.content = content
|
||||
return content
|
||||
|
||||
async def process_reply_component(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: Optional[float] = None,
|
||||
):
|
||||
# 调用父类初始化,传递时间戳
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
timestamp=timestamp,
|
||||
chat_stream=chat_stream,
|
||||
user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply,
|
||||
)
|
||||
component: ReplyComponent,
|
||||
id_content_map: MsgIDMapping,
|
||||
) -> str:
|
||||
if component.target_message_content:
|
||||
return component.target_message_content
|
||||
if result_item := id_content_map.mapping.get(component.target_message_id):
|
||||
content, sender_info = result_item
|
||||
if isinstance(content, Task):
|
||||
content = await content
|
||||
id_content_map.mapping[component.target_message_id] = (content, sender_info) # 更新为实际内容
|
||||
component.target_message_content = content
|
||||
tgt_msg_s_name = sender_info.user_cardname or sender_info.user_nickname or sender_info.user_id
|
||||
component.target_message_sender_cardname = sender_info.user_cardname
|
||||
component.target_message_sender_nickname = sender_info.user_nickname
|
||||
component.target_message_sender_id = sender_info.user_id
|
||||
return f"[回复了{tgt_msg_s_name}的消息: {content}]"
|
||||
else:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Messages).filter_by(message_id=component.target_message_id).limit(1)
|
||||
if db_msg := session.exec(statement).first():
|
||||
component.target_message_content = db_msg.processed_plain_text
|
||||
component.target_message_sender_cardname = db_msg.user_cardname
|
||||
component.target_message_sender_nickname = db_msg.user_nickname
|
||||
component.target_message_sender_id = db_msg.user_id
|
||||
tgt_msg_s_name = db_msg.user_cardname or db_msg.user_nickname or db_msg.user_id
|
||||
return f"[回复了{tgt_msg_s_name}的消息: {db_msg.processed_plain_text}]"
|
||||
except Exception as e:
|
||||
logger.error(f"查询回复消息时发生错误: {e}")
|
||||
|
||||
# 处理状态相关属性
|
||||
self.thinking_start_time = thinking_start_time
|
||||
self.thinking_time = 0
|
||||
return "[回复了一条消息,但原消息已无法访问]"
|
||||
|
||||
def update_thinking_time(self) -> float:
|
||||
"""更新思考时间"""
|
||||
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
|
||||
return self.thinking_time
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_image_description(segment.data)
|
||||
return "[图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_tag(segment.data)
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "at":
|
||||
return f"[@{segment.data}]"
|
||||
elif segment.type == "reply":
|
||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||
# print(f"reply: {self.reply}")
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return ""
|
||||
else:
|
||||
return f"[{segment.type}:{str(segment.data)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSending(MessageProcessBase):
|
||||
"""发送状态的消息类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
sender_info: UserInfo | None, # 用来记录发送者信息
|
||||
message_segment: Seg,
|
||||
display_message: str = "",
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: Optional[str] = None,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply,
|
||||
thinking_start_time=thinking_start_time,
|
||||
)
|
||||
|
||||
# 发送状态特有属性
|
||||
self.sender_info = sender_info
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
self.apply_set_reply_logic = apply_set_reply_logic
|
||||
|
||||
self.reply_to = reply_to
|
||||
|
||||
# 用于显示发送内容与显示不一致的情况
|
||||
self.display_message = display_message
|
||||
|
||||
self.interest_value = 0.0
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
|
||||
def build_reply(self):
|
||||
"""设置回复消息"""
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||
self.message_segment,
|
||||
],
|
||||
async def process_forward_component(
|
||||
self, component: ForwardNodeComponent, id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
) -> str:
|
||||
task_list: List[Task] = []
|
||||
node_user_info_list: List[UserInfo] = []
|
||||
for node in component.forward_components:
|
||||
task = asyncio.create_task(
|
||||
self._process_multiple_components(node.content, id_content_map, recursion_depth + 1)
|
||||
)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
if self.message_segment:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
def to_dict(self):
|
||||
ret = super().to_dict()
|
||||
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
|
||||
return ret
|
||||
|
||||
def is_private_message(self) -> bool:
|
||||
"""判断是否为私聊消息"""
|
||||
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
|
||||
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: list[MessageSending] = []
|
||||
self.time = round(time.time(), 3) # 保留3位小数
|
||||
|
||||
def add_message(self, message: MessageSending) -> None:
|
||||
"""添加消息到集合"""
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||
left = mid + 1
|
||||
node_user_info = UserInfo(node.user_id or "未知用户", node.user_nickname, node.user_cardname)
|
||||
id_content_map.mapping[node.message_id] = (task, node_user_info)
|
||||
task_list.append(task)
|
||||
node_user_info_list.append(node_user_info)
|
||||
results = await asyncio.gather(*task_list, return_exceptions=True)
|
||||
forward_texts = []
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"处理转发消息组件时发生错误: {result}")
|
||||
else:
|
||||
right = mid
|
||||
usr_info = node_user_info_list[idx]
|
||||
msg_sender_name = usr_info.user_cardname or usr_info.user_nickname or usr_info.user_id or "未知用户"
|
||||
forward_texts.append(f"{'-' * recursion_depth * 2} 【{msg_sender_name}】: {result}")
|
||||
return "【合并转发消息: \n" + "\n".join(forward_texts) + "\n】"
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: MessageSending) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||
return MessageRecv(message_dict)
|
||||
|
||||
|
||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"""从数据库字典创建MessageRecv实例"""
|
||||
# 转换扁平的数据库字典为嵌套结构
|
||||
message_info_dict = {
|
||||
"platform": db_dict.get("chat_info_platform"),
|
||||
"message_id": db_dict.get("message_id"),
|
||||
"time": db_dict.get("time"),
|
||||
"group_info": {
|
||||
"platform": db_dict.get("chat_info_group_platform"),
|
||||
"group_id": db_dict.get("chat_info_group_id"),
|
||||
"group_name": db_dict.get("chat_info_group_name"),
|
||||
},
|
||||
"user_info": {
|
||||
"platform": db_dict.get("user_platform"),
|
||||
"user_id": db_dict.get("user_id"),
|
||||
"user_nickname": db_dict.get("user_nickname"),
|
||||
"user_cardname": db_dict.get("user_cardname"),
|
||||
},
|
||||
}
|
||||
|
||||
processed_text = db_dict.get("processed_plain_text", "")
|
||||
|
||||
# 构建 MessageRecv 需要的字典
|
||||
recv_dict = {
|
||||
"message_info": message_info_dict,
|
||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
||||
"raw_message": None, # 数据库中未存储原始消息
|
||||
"processed_plain_text": processed_text,
|
||||
}
|
||||
|
||||
# 创建 MessageRecv 实例
|
||||
msg = MessageRecv(recv_dict)
|
||||
|
||||
# 从数据库字典中填充其他可选字段
|
||||
msg.interest_value = db_dict.get("interest_value", 0.0)
|
||||
msg.is_mentioned = db_dict.get("is_mentioned")
|
||||
msg.priority_mode = db_dict.get("priority_mode", "interest")
|
||||
msg.priority_info = db_dict.get("priority_info")
|
||||
msg.is_emoji = db_dict.get("is_emoji", False)
|
||||
msg.is_picid = db_dict.get("is_picid", False)
|
||||
|
||||
return msg
|
||||
async def _process_multiple_components(
|
||||
self, components: Sequence[StandardMessageComponents], id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
) -> str:
|
||||
tasks = [
|
||||
self.process_single_component(component, id_content_map, recursion_depth=recursion_depth)
|
||||
for component in components
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
processed_texts: List[str] = []
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"处理消息组件时发生错误: {result}")
|
||||
else:
|
||||
processed_texts.append(result)
|
||||
return " ".join(processed_texts)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from typing import Optional, List, Union, Dict, Any, Sequence
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
|
|
@ -142,9 +142,9 @@ class AtComponent(BaseMessageComponentModel):
|
|||
) -> None:
|
||||
self.target_user_id = target_user_id
|
||||
"""目标用户ID"""
|
||||
self.target_user_nickname = target_user_nickname
|
||||
self.target_user_nickname: Optional[str] = target_user_nickname
|
||||
"""目标用户昵称"""
|
||||
self.target_user_cardname = target_user_cardname
|
||||
self.target_user_cardname: Optional[str] = target_user_cardname
|
||||
"""目标用户备注名"""
|
||||
assert isinstance(target_user_id, str), "AtComponent 的 target_user_id 必须是字符串类型"
|
||||
|
||||
|
|
@ -159,10 +159,25 @@ class ReplyComponent(BaseMessageComponentModel):
|
|||
def format_name(self) -> str:
|
||||
return "reply"
|
||||
|
||||
def __init__(self, target_message_id: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
target_message_id: str,
|
||||
target_message_content: Optional[str] = None,
|
||||
target_message_sender_id: Optional[str] = None,
|
||||
target_message_sender_nickname: Optional[str] = None,
|
||||
target_message_sender_cardname: Optional[str] = None,
|
||||
) -> None:
|
||||
assert isinstance(target_message_id, str), "ReplyComponent 的 target_message_id 必须是字符串类型"
|
||||
self.target_message_id = target_message_id
|
||||
"""目标消息ID"""
|
||||
self.target_message_content: Optional[str] = target_message_content
|
||||
"""目标消息内容"""
|
||||
self.target_message_sender_id: Optional[str] = target_message_sender_id
|
||||
"""目标消息发送者ID"""
|
||||
self.target_message_sender_nickname: Optional[str] = target_message_sender_nickname
|
||||
"""目标消息发送者昵称"""
|
||||
self.target_message_sender_cardname: Optional[str] = target_message_sender_cardname
|
||||
"""目标消息发送者群昵称"""
|
||||
|
||||
async def to_seg(self) -> Seg:
|
||||
return Seg(type="reply", data=self.target_message_id)
|
||||
|
|
@ -224,7 +239,7 @@ class ForwardComponent(BaseMessageComponentModel):
|
|||
self,
|
||||
user_nickname: str,
|
||||
message_id: str,
|
||||
content: List[StandardMessageComponents],
|
||||
content: Sequence[StandardMessageComponents],
|
||||
user_id: Optional[str] = None,
|
||||
user_cardname: Optional[str] = None,
|
||||
):
|
||||
|
|
@ -232,7 +247,7 @@ class ForwardComponent(BaseMessageComponentModel):
|
|||
"""转发节点的发送者昵称"""
|
||||
self.message_id: str = message_id
|
||||
"""转发节点的消息ID"""
|
||||
self.content: List[StandardMessageComponents] = content
|
||||
self.content: Sequence[StandardMessageComponents] = content
|
||||
"""消息内容"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
"""转发节点的发送者ID,可能为 None"""
|
||||
|
|
@ -249,7 +264,7 @@ class ForwardComponent(BaseMessageComponentModel):
|
|||
class MessageSequence:
|
||||
"""消息组件序列,包含一个消息中的所有组件,按照顺序排列"""
|
||||
|
||||
def __init__(self, components: List[StandardMessageComponents]):
|
||||
def __init__(self, components: Sequence[StandardMessageComponents]):
|
||||
"""
|
||||
创建一个消息组件序列
|
||||
|
||||
|
|
@ -259,16 +274,16 @@ class MessageSequence:
|
|||
因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。
|
||||
如果需要对组件进行去重或校验,还请在使用时自行处理。
|
||||
"""
|
||||
self.components: List[StandardMessageComponents] = components
|
||||
self.components: Sequence[StandardMessageComponents] = components
|
||||
|
||||
def to_dict(self) -> List[Dict[str, Any]]:
|
||||
"""将消息序列转换为字典列表格式,便于存储或传输"""
|
||||
return [self._item_2_dict(comp) for comp in self.components]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
|
||||
def from_dict(cls, data: List[Dict[str, Any]]):
|
||||
"""从字典列表格式创建消息序列实例"""
|
||||
components: List[StandardMessageComponents] = []
|
||||
components: Sequence[StandardMessageComponents] = []
|
||||
components.extend(cls._dict_2_item(item) for item in data)
|
||||
return cls(components=components)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from maim_message import MessageBase, Seg
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Tuple, Optional, Sequence
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
|
@ -35,7 +35,7 @@ class MessageUtils:
|
|||
def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence:
|
||||
"""从maim_message.MessageBase.message_segment转换为MessageSequence"""
|
||||
raw_msg_seq = message.message_segment
|
||||
components: List[StandardMessageComponents] = []
|
||||
components: Sequence[StandardMessageComponents] = []
|
||||
if not raw_msg_seq:
|
||||
return MessageSequence(components)
|
||||
if raw_msg_seq.type == "seglist":
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class PersonUtils:
|
|||
"""根据person_id获取用户信息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).filter_by(person_id=person_id)
|
||||
statement = select(PersonInfo).filter_by(person_id=person_id).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
return MaiPersonInfo.from_db_instance(result)
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue