合并到远程

pull/1496/head
UnCLAS-Prommer 2026-02-13 13:41:58 +08:00
parent 60f76e4d4e
commit b9f3c17e14
No known key found for this signature in database
15 changed files with 420 additions and 450 deletions

2
bot.py
View File

@ -1,4 +1,4 @@
# raise RuntimeError("System Not Ready")
raise RuntimeError("System Not Ready")
import asyncio
import hashlib
import os

View File

@ -106,7 +106,7 @@ def _install_stub_modules(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
class _Result:
def scalars(self):
return self
@ -231,8 +231,8 @@ def _install_stub_modules(monkeypatch):
def import_emoji_manager_new(monkeypatch):
_install_stub_modules(monkeypatch)
file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager_new.py"
spec = importlib.util.spec_from_file_location("emoji_manager_new", file_path)
file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager.py"
spec = importlib.util.spec_from_file_location("emoji_manager", file_path)
module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
spec.loader.exec_module(module)
@ -446,7 +446,7 @@ def test_load_emojis_from_db_empty(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def _get_db_session():
@ -487,7 +487,7 @@ def test_load_emojis_from_db_partial_bad_records(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def _get_db_session():
@ -524,7 +524,7 @@ def test_load_emojis_from_db_execute_error(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
raise RuntimeError("execute failed")
def _get_db_session():
@ -581,7 +581,7 @@ def test_load_emojis_from_db_scalars_all_error(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def _get_db_session():
@ -799,6 +799,8 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -817,7 +819,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def _get_db_session():
@ -887,6 +889,9 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -898,7 +903,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
raise RuntimeError("db delete failed")
def _get_db_session():
@ -942,6 +947,8 @@ def test_delete_emoji_success(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -966,7 +973,7 @@ def test_delete_emoji_success(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def delete(self, _record):
@ -998,6 +1005,8 @@ def test_update_emoji_usage_success(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -1021,7 +1030,7 @@ def test_update_emoji_usage_success(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def add(self, _record):
@ -1051,6 +1060,9 @@ def test_update_emoji_usage_missing_record(monkeypatch):
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -1068,7 +1080,7 @@ def test_update_emoji_usage_missing_record(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
return _Result()
def _get_db_session():
@ -1094,6 +1106,8 @@ def test_update_emoji_usage_execute_error(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _select(_model):
return _Select()
@ -1105,7 +1119,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _statement):
def exec(self, _statement):
raise RuntimeError("execute failed")
def _get_db_session():

View File

@ -62,7 +62,7 @@ class EmojiManager:
try:
with get_db_session() as session:
statement = select(Images)
results = session.execute(statement).scalars().all()
results = session.exec(statement).all()
for record in results:
try:
emoji = MaiEmoji.from_db_instance(record)
@ -144,8 +144,8 @@ class EmojiManager:
# 删除数据库记录
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI)
if image_record := session.execute(statement).scalars().first():
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
session.delete(image_record)
logger.info(f"[删除表情包] 成功删除数据库中的表情包记录: {emoji.emoji_hash}")
else:
@ -170,8 +170,8 @@ class EmojiManager:
"""
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI)
if image_record := session.execute(statement).scalars().first():
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
image_record.query_count += 1
image_record.last_used_time = datetime.now()
session.add(image_record)

View File

@ -1,54 +1,6 @@
import copy
from typing import Any
class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""
将对象或容器中的 BaseDataModel 子类类对象 BaseDataModel 实例
递归转换为普通 dict不修改原对象
- 对于类对象isinstance(value, type) issubclass(..., BaseDataModel)
读取类的 __dict__ 中非 dunder 项并递归转换
- 对于实例isinstance(value, BaseDataModel)读取 vars(instance) 并递归转换
"""
def _transform(value: Any) -> Any:
# 值是类对象且为 BaseDataModel 的子类
if isinstance(value, type) and issubclass(value, BaseDataModel):
return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)}
# 值是 BaseDataModel 的实例
if isinstance(value, BaseDataModel):
return {k: _transform(v) for k, v in vars(value).items()}
# 常见容器类型,递归处理
if isinstance(value, dict):
return {k: _transform(v) for k, v in value.items()}
if isinstance(value, list):
return [_transform(v) for v in value]
if isinstance(value, tuple):
return tuple(_transform(v) for v in value)
if isinstance(value, set):
return {_transform(v) for v in value}
# 基本类型,直接返回
return value
result = _transform(obj)
def flatten(target_dict: dict):
flat_dict = {}
for k, v in target_dict.items():
if isinstance(v, dict):
# 递归扁平化子字典
sub_flat = flatten(v)
flat_dict.update(sub_flat)
else:
flat_dict[k] = v
return flat_dict
return flatten(result) if isinstance(result, dict) else result

View File

@ -0,0 +1,226 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
from typing import Optional, List, Union, Dict, Any
import asyncio
import hashlib
import base64
from src.common.logger import get_logger
logger = get_logger("base_message_component_model")
class BaseMessageComponentModel(ABC):
@abstractmethod
async def to_seg(self) -> Seg:
"""将消息组件转换为 maim_message.Seg 对象"""
raise NotImplementedError
def clone(self):
return deepcopy(self)
class ByteComponent:
def __init__(self, *, binary_hash: str, content: Optional[str] = None, binary_data: Optional[bytes] = None) -> None:
self.content: str = content if content is not None else ""
"""处理后的内容"""
self.binary_data: bytes = binary_data if binary_data is not None else b""
"""原始二进制数据"""
self.binary_hash: str = hashlib.sha256(self.binary_data).hexdigest() if self.binary_data else binary_hash
"""二进制数据的 SHA256 哈希值,用于唯一标识该二进制数据"""
class TextComponent(BaseMessageComponentModel):
def __init__(self, text: str):
self.text = text
assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型"
async def to_seg(self) -> Seg:
return Seg(type="text", data=self.text)
class ImageComponent(BaseMessageComponentModel, ByteComponent):
async def load_image_binary(self):
if not self.binary_data:
...
async def to_seg(self) -> Seg:
if not self.binary_data:
await self.load_image_binary()
return Seg(type="image", data=base64.b64encode(self.binary_data).decode())
class EmojiComponent(BaseMessageComponentModel, ByteComponent):
async def load_emoji_binary(self) -> None:
"""
加载表情的二进制数据如果 binary_data 为空则通过 emoji_hash 从表情管理器加载
Raises:
ValueError: 如果 binary_data 为空且缺少 emoji_hash
ValueError: 如果无法通过 emoji_hash 加载表情二进制数据
"""
if not self.binary_data:
from src.chat.emoji_system.emoji_manager import emoji_manager
if not (
emoji := emoji_manager.get_emoji_by_hash(self.binary_hash)
or emoji_manager.get_emoji_by_hash_from_db(self.binary_hash)
):
raise ValueError(f"无法通过 emoji_hash 加载表情二进制数据: {self.binary_hash}")
try:
self.binary_data = await asyncio.to_thread(emoji.full_path.read_bytes)
except Exception as e:
raise ValueError(f"通过 emoji_hash 加载表情二进制数据时发生错误: {e}") from e
async def to_seg(self) -> Seg:
if not self.binary_data:
await self.load_emoji_binary()
return Seg(type="emoji", data=base64.b64encode(self.binary_data).decode())
class VoiceComponent(BaseMessageComponentModel, ByteComponent):
async def load_voice_binary(self) -> None:
if not self.binary_data:
from src.common.utils.utils_file import FileUtils
try:
file_path = FileUtils.get_file_path_by_hash(self.binary_hash)
self.binary_data = await asyncio.to_thread(file_path.read_bytes)
except Exception as e:
raise ValueError(f"通过 voice_hash 加载语音二进制数据时发生错误: {e}") from e
async def to_seg(self) -> Seg:
if not self.binary_data:
await self.load_voice_binary()
return Seg(type="voice", data=base64.b64encode(self.binary_data).decode())
class ForwardNodeComponent(BaseMessageComponentModel):
def __init__(self, forward_components: List["ForwardComponent"]):
self.forward_components = forward_components
assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型"
assert all(isinstance(comp, ForwardComponent) for comp in forward_components), (
"ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型"
)
assert forward_components, "ForwardNodeComponent 的 forward_components 不能为空列表"
async def to_seg(self) -> "Seg":
resp: List[Dict[str, Any]] = []
for comp in self.forward_components:
data = await comp.to_seg()
sender_info = UserInfo(None, comp.user_id, comp.user_nickname, comp.user_cardname)
base_message_info = BaseMessageInfo(user_info=sender_info)
base_message = MessageBase(base_message_info, data)
resp.append(base_message.to_dict())
return Seg(type="forward", data=resp) # type: ignore
class DictComponent:
def __init__(self, data: Dict[str, Any]):
self.data = data
assert isinstance(data, dict), "DictComponent 的 data 必须是字典类型"
StandardMessageComponents = Union[
TextComponent,
ImageComponent,
EmojiComponent,
VoiceComponent,
ForwardNodeComponent,
DictComponent,
]
class ForwardComponent(BaseMessageComponentModel):
def __init__(
self,
user_nickname: str,
content: List[StandardMessageComponents],
user_id: Optional[str] = None,
user_cardname: Optional[str] = None,
):
self.user_nickname: str = user_nickname
self.content: List[StandardMessageComponents] = content
self.user_id: Optional[str] = user_id
self.user_cardname: Optional[str] = user_cardname
assert self.content, "ForwardComponent 的 content 不能为空"
async def to_seg(self) -> "Seg":
return Seg(
type="seglist", data=[await comp.to_seg() for comp in self.content if not isinstance(comp, DictComponent)]
)
class MessageSequence:
def __init__(self, components: List[StandardMessageComponents]):
self.components: List[StandardMessageComponents] = components
def to_dict(self) -> List[Dict[str, Any]]:
return [self._item_2_dict(comp) for comp in self.components]
def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]:
if isinstance(item, TextComponent):
return {"type": "text", "data": item.text}
elif isinstance(item, ImageComponent):
if not item.content:
raise RuntimeError("ImageComponent content 未初始化")
return {"type": "image", "data": item.content, "hash": item.binary_hash}
elif isinstance(item, EmojiComponent):
if not item.content:
raise RuntimeError("EmojiComponent content 未初始化")
return {"type": "emoji", "data": item.content, "hash": item.binary_hash}
elif isinstance(item, VoiceComponent):
if not item.content:
raise RuntimeError("VoiceComponent content 未初始化")
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
elif isinstance(item, ForwardNodeComponent):
return {
"type": "forward",
"data": [
{
"user_id": comp.user_id,
"user_nickname": comp.user_nickname,
"user_cardname": comp.user_cardname,
"content": [self._item_2_dict(c) for c in comp.content],
}
for comp in item.forward_components
],
}
else:
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
return {"type": "dict", "data": item.data}
@classmethod
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
components: List[StandardMessageComponents] = []
components.extend(cls._dict_2_item(item) for item in data)
return cls(components=components)
@classmethod
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
item_type = item.get("type")
if item_type == "text":
return TextComponent(text=item["data"])
elif item_type == "image":
return ImageComponent(binary_hash=item["hash"], content=item["data"])
elif item_type == "emoji":
return EmojiComponent(binary_hash=item["hash"], content=item["data"])
elif item_type == "voice":
return VoiceComponent(binary_hash=item["hash"], content=item["data"])
elif item_type == "forward":
forward_components = []
for fc in item["data"]:
content = [cls._dict_2_item(c) for c in fc["content"]]
forward_component = ForwardComponent(
user_nickname=fc["user_nickname"],
user_id=fc.get("user_id"),
user_cardname=fc.get("user_cardname"),
content=content,
)
forward_components.append(forward_component)
return ForwardNodeComponent(forward_components=forward_components)
else:
logger.warning(f"Unofficial component type in dict: {item_type}, defaulting to DictComponent")
return DictComponent(data=item.get("data") or {})

View File

@ -1,210 +0,0 @@
from typing import Optional, TYPE_CHECKING, List, Tuple, Union, Dict, Any
from dataclasses import dataclass, field
from enum import Enum
from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
@dataclass
class MessageAndActionModel(BaseDataModel):
chat_id: str = field(default_factory=str)
time: float = field(default_factory=float)
user_id: str = field(default_factory=str)
user_platform: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
processed_plain_text: Optional[str] = None
display_message: Optional[str] = None
chat_info_platform: str = field(default_factory=str)
is_action_record: bool = field(default=False)
action_name: Optional[str] = None
is_command: bool = field(default=False)
intercept_message_level: int = field(default=0)
@classmethod
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
return cls(
chat_id=message.chat_id,
time=message.time,
user_id=message.user_info.user_id,
user_platform=message.user_info.platform,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname,
processed_plain_text=message.processed_plain_text,
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
is_command=message.is_command,
intercept_message_level=getattr(message, "intercept_message_level", 0),
)
class ReplyContentType(Enum):
TEXT = "text"
IMAGE = "image"
EMOJI = "emoji"
COMMAND = "command"
VOICE = "voice"
FORWARD = "forward"
HYBRID = "hybrid" # 混合类型,包含多种内容
def __repr__(self) -> str:
return self.value
@dataclass
class ForwardNode(BaseDataModel):
user_id: Optional[str] = None
user_nickname: Optional[str] = None
content: Union[List["ReplyContent"], str] = field(default_factory=list)
@classmethod
def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
return cls(user_id="", user_nickname="", content=message_id)
@classmethod
def construct_as_created_node(
cls, user_id: str, user_nickname: str, content: List["ReplyContent"]
) -> "ForwardNode":
return cls(user_id=user_id, user_nickname=user_nickname, content=content)
@dataclass
class ReplyContent(BaseDataModel):
content_type: ReplyContentType | str
content: Union[str, Dict, List[ForwardNode], List["ReplyContent"]] # 支持嵌套的 ReplyContent
@classmethod
def construct_as_text(cls, text: str):
return cls(content_type=ReplyContentType.TEXT, content=text)
@classmethod
def construct_as_image(cls, image_base64: str):
return cls(content_type=ReplyContentType.IMAGE, content=image_base64)
@classmethod
def construct_as_voice(cls, voice_base64: str):
return cls(content_type=ReplyContentType.VOICE, content=voice_base64)
@classmethod
def construct_as_emoji(cls, emoji_str: str):
return cls(content_type=ReplyContentType.EMOJI, content=emoji_str)
@classmethod
def construct_as_command(cls, command_arg: Dict):
return cls(content_type=ReplyContentType.COMMAND, content=command_arg)
@classmethod
def construct_as_hybrid(cls, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
hybrid_content_list: List[ReplyContent] = []
for content_type, content in hybrid_content:
assert content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content, str), "混合内容的每个项必须是字符串"
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
return cls(content_type=ReplyContentType.HYBRID, content=hybrid_content_list)
@classmethod
def construct_as_forward(cls, forward_nodes: List[ForwardNode]):
return cls(content_type=ReplyContentType.FORWARD, content=forward_nodes)
def __post_init__(self):
if isinstance(self.content_type, ReplyContentType):
if self.content_type not in [ReplyContentType.HYBRID, ReplyContentType.FORWARD] and isinstance(
self.content, List
):
raise ValueError(
f"非混合类型/转发类型的内容不能是列表content_type: {self.content_type}, content: {self.content}"
)
elif self.content_type in [ReplyContentType.HYBRID, ReplyContentType.FORWARD]:
if not isinstance(self.content, List):
raise ValueError(
f"混合类型/转发类型的内容必须是列表content_type: {self.content_type}, content: {self.content}"
)
@dataclass
class ReplySetModel(BaseDataModel):
"""
回复集数据模型用于多种回复类型的返回
"""
reply_data: List[ReplyContent] = field(default_factory=list)
def __len__(self):
return len(self.reply_data)
def add_text_content(self, text: str):
"""
添加文本内容
Args:
text: 文本内容
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))
def add_image_content(self, image_base64: str):
"""
添加图片内容base64编码的图片数据
Args:
image_base64: base64编码的图片数据
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.IMAGE, content=image_base64))
def add_voice_content(self, voice_base64: str):
"""
添加语音内容base64编码的音频数据
Args:
voice_base64: base64编码的音频数据
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))
def add_hybrid_content_by_raw(self, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
"""
添加混合型内容可以包含text, image, emoji的任意组合
Args:
hybrid_content: 元组 (类型, 消息内容) 构成的列表[(ReplyContentType.TEXT, "Hello"), (ReplyContentType.IMAGE, "<base64")]
"""
hybrid_content_list: List[ReplyContent] = []
for content_type, content in hybrid_content:
assert content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content, str), "混合内容的每个项必须是字符串"
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content_list))
def add_hybrid_content(self, hybrid_content: List[ReplyContent]):
"""
添加混合型内容使用已经构造好的 ReplyContent 列表
Args:
hybrid_content: ReplyContent 构成的列表[ReplyContent(ReplyContentType.TEXT, "Hello"), ReplyContent(ReplyContentType.IMAGE, "<base64")]
"""
for content in hybrid_content:
assert content.content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content.content, str), "混合内容的每个项必须是字符串"
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content))
def add_custom_content(self, content_type: str, content: Any):
"""
添加自定义类型的内容"""
self.reply_data.append(ReplyContent(content_type=content_type, content=content))
def add_forward_content(self, forward_content: List[ForwardNode]):
"""添加转发内容可以是字符串或ReplyContent嵌套的转发内容需要自己构造放入"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_content))

View File

@ -1,36 +0,0 @@
# 对于`message_data_model.py`中`class ReplyContent`的规划解读
分类讨论如下:
- `ReplyContent.TEXT`: 单独的文本,`_level = 0``content`为`str`类型。
- `ReplyContent.IMAGE`: 单独的图片,`_level = 0``content`为`str`类型图片base64
- `ReplyContent.EMOJI`: 单独的表情包,`_level = 0``content`为`str`类型图片base64
- `ReplyContent.VOICE`: 单独的语音,`_level = 0``content`为`str`类型语音base64
- `ReplyContent.HYBRID`: 混合内容,`_level = 0`
- 其应该是一个列表,列表内应该只接受`str`类型的内容(图片和文本混合体)
- `ReplyContent.FORWARD`: 转发消息,`_level = n`
- 其应该是一个列表,列表接受`str`类型(图片/文本),`ReplyContent`类型(嵌套转发,嵌套有最高层数限制)
- `ReplyContent.COMMAND`: 指令消息,`_level = 0`
- 其应该是一个列表,列表内应该只接受`Dict`类型的内容
未来规划:
- `ReplyContent.AT` 单独的艾特,`_level = 0``content`为`str`类型用户ID
内容构造方式:
- 对于`TEXT`, `IMAGE`, `EMOJI`, `VOICE`,直接传入对应类型的内容,且`content`应该为`str`。
- 对于`COMMAND`,传入一个字典,字典内的内容类型应符合上述规定。
- 对于`HYBRID`, `FORWARD`,传入一个列表,列表内的内容类型应符合上述规定。
因此,我们的类型注解应该是:
```python
from typing import Union, List, Dict
ReplyContentType = Union[
str, # TEXT, IMAGE, EMOJI, VOICE
List[Union[str, 'ReplyContent']], # HYBRID, FORWARD
Dict # COMMAND
]
```
现在`_level`被移除了,在解析的时候显式地检查内容的类型和结构即可。
`send_api`的custom_reply_set_to_stream仅在特定的类型下提供reply)message

View File

@ -1,57 +0,0 @@
# 有关转发消息和其他消息的构建类型说明
```mermaid
graph LR;
direction TB;
A[ReplySet] --- B[ReplyContent];
A --- C["ReplyContent"];
A --- K["ReplyContent"];
A --- L["ReplyContent"];
A --- N["ReplyContent"];
A --- D[...];
B --- E["Text (in str)"];
B --- F["Image (in base64)"];
C --- G["Voice (in base64)"];
B --- I["Emoji (in base64)"];
subgraph "可行内容(以下的任意组合)";
subgraph "转发消息(Forward)"
M["List[ForwardNode]"]
end
subgraph "混合消息(Hybrid)"
J["List[ReplyContent] (要求只能包含普通消息)"]
end
subgraph "命令消息(Command)"
H["Command (in Dict)"]
end
subgraph "语音消息"
G
end
subgraph "普通消息"
E
F
I
end
end
N --- H
K --- J
L --- M
subgraph ForwardNodes
O["ForwardNode"]
P["ForwardNode"]
Q["ForwardNode"]
end
M --- O
M --- P
M --- Q
subgraph "内容 (message_id引用法)"
P --- U["content: str, 引用已有消息的有效ID"];
end
subgraph "内容 (生成法)"
O --- R["user_id: str"];
O --- S["user_nickname: str"];
O --- T["content: List[ReplyContent], 为这个转发节点的消息内容"];
end
```
另外,自定义消息类型我们在这里不做讨论。
以上列出了所有可能的ReplySet构建方式下面我们来解释一下各个类型的含义。

View File

@ -1,10 +1,10 @@
from rich.traceback import install
from pathlib import Path
from contextlib import contextmanager
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy import inspect as sqlalchemy_inspect
from sqlalchemy.orm import Session, sessionmaker
from sqlmodel import create_engine, Session
from typing import TYPE_CHECKING, Generator
if TYPE_CHECKING:
@ -27,19 +27,12 @@ DATABASE_URL = f"sqlite:///{_DB_FILE}"
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
"""
为每个新的数据库连接设置 SQLite PRAGMA
这些设置优化了并发性能和数据安全性:
- journal_mode=WAL: 启用预写式日志,提高并发性能
- cache_size: 设置缓存大小为 64MB
- foreign_keys: 启用外键约束
- synchronous=NORMAL: 平衡性能和数据安全
- busy_timeout: 设置1秒超时,避免锁定冲突
"""
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA cache_size=-64000") # 负值表示KB,64000KB = 64MB
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute("PRAGMA synchronous=NORMAL") # NORMAL 模式在WAL下是安全的
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA busy_timeout=1000") # 1秒超时
cursor.close()
@ -52,11 +45,12 @@ engine = create_engine(
pool_pre_ping=True,
)
# 创建会话工厂
# 创建会话工厂(使用 sqlmodel.Session
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
class_=Session,
)
@ -96,7 +90,6 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
session = SessionLocal()
try:
yield session
# 如果启用自动提交且没有异常,则提交事务
if auto_commit:
session.commit()
except Exception:
@ -132,59 +125,3 @@ def get_db() -> Generator[Session, None, None]:
yield session
finally:
session.close()
class _AtomicContext:
def __init__(self) -> None:
self._session: Session | None = None
def __enter__(self) -> Session:
self._session = SessionLocal()
self._session.begin()
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
if self._session is None:
return
try:
if exc_type is None:
self._session.commit()
else:
self._session.rollback()
finally:
self._session.close()
class DatabaseCompat:
"""兼容旧 db 调用接口Peewee 风格),底层使用 SQLAlchemy。"""
def connect(self, reuse_if_open: bool = True) -> None:
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
_ = reuse_if_open
def create_tables(self, models: list[type], safe: bool = True) -> None:
_ = safe
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
if not tables:
return
from sqlmodel import SQLModel
SQLModel.metadata.create_all(engine, tables=tables)
def atomic(self) -> _AtomicContext:
return _AtomicContext()
def execute_sql(self, sql: str):
with engine.connect() as conn:
result = conn.execute(text(sql))
conn.commit()
return result
def table_exists(self, model: type) -> bool:
if not hasattr(model, "__tablename__"):
return False
inspector = sqlalchemy_inspect(engine)
return inspector.has_table(model.__tablename__)
db = DatabaseCompat()

View File

@ -1,6 +1,6 @@
from typing import Optional
from sqlalchemy import Column, Float, Enum as SQLEnum
from sqlmodel import SQLModel, Field
from sqlmodel import SQLModel, Field, LargeBinary
from enum import Enum
from datetime import datetime
@ -45,8 +45,8 @@ class Messages(SQLModel, table=True):
is_notify: bool = Field(default=False) # 是否为通知消息
# 消息内容
raw_content: str # base64编码的原始消息内容
processed_plain_text: str = Field(index=True) # 平面化处理后的纯文本消息
raw_content: bytes = Field(sa_column=Column(LargeBinary)) # base64编码的原始消息内容
processed_plain_text: str = Field() # 平面化处理后的纯文本消息
display_message: str # 显示的消息内容被放入Prompt
# 其他配置
@ -85,9 +85,9 @@ class Images(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
# 元信息
image_hash: str = Field(default="", max_length=255) # 图片哈希使用sha256哈希值亦作为图片唯一ID
image_hash: str = Field(index=True, max_length=255) # 图片哈希使用sha256哈希值亦作为图片唯一ID
description: str # 图片的描述
full_path: str = Field(index=True, max_length=1024) # 文件的完整路径 (包括文件名)
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
"""图片类型,例如 'emoji''image'"""
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
@ -116,7 +116,7 @@ class ActionRecord(SQLModel, table=True):
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
# 调用信息
action_name: str = Field(max_length=255) # 动作名称
action_name: str = Field(index=True, max_length=255) # 动作名称
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
action_data: Optional[str] = Field(default=None) # 动作数据JSON格式存储
@ -153,7 +153,7 @@ class OnlineTime(SQLModel, table=True):
timestamp: datetime = Field(default_factory=datetime.now, index=True) # 时间戳
duration_minutes: int = Field() # 时长,单位秒
start_timestamp: datetime = Field(default_factory=datetime.now) # 上线时间
end_timestamp: datetime = Field(index=True) # 下线时间
end_timestamp: datetime = Field() # 下线时间
class Expression(SQLModel, table=True):
@ -230,7 +230,68 @@ class ThinkingQuestion(SQLModel, table=True):
context: Optional[str] = Field(default=None, nullable=True) # 上下文
found_answer: bool = Field(default=False) # 是否找到答案
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤JSON格式存储
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
updated_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后更新时间
class BinaryData(SQLModel, table=True):
"""存储二进制数据的模型"""
__tablename__ = "binary_data" # type: ignore
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
data_hash: str = Field(index=True, max_length=255) # 数据哈希使用sha256哈希值亦作为数据唯一ID
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
class PersonInfo(SQLModel, table=True):
"""存储个人信息的模型"""
__tablename__ = "person_info" # type: ignore
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
is_known: bool = Field(default=False) # 是否为已知人
person_id: str = Field(unique=True, index=True, max_length=255) # 人员ID
person_name: Optional[str] = Field(default=None, max_length=255, nullable=True) # 人员名称
name_reason: Optional[str] = Field(default=None, nullable=True) # 名称原因
# 身份元数据
platform: str = Field(index=True, max_length=100) # 平台名称
user_id: str = Field(index=True, max_length=255) # 用户ID
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
group_nickname: Optional[str] = Field(
default=None, nullable=True
) # 群昵称 (JSON, [{"group_id": str, "group_nick_name": str}])
# 印象
memory_points: Optional[str] = Field(default=None, nullable=True) # 记忆要点JSON格式存储
# 认识次数和时间
know_counts: int = Field(default=0) # 认识次数
first_known_time: Optional[datetime] = Field(default=None, nullable=True) # 首次认识时间
last_known_time: Optional[datetime] = Field(default=None, nullable=True) # 最后认识时间
class ChatSession(SQLModel, table=True):
"""存储聊天会话的模型"""
__tablename__ = "chat_sessions" # type: ignore
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
last_active_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后活跃时间
# 身份元数据
user_id: str = Field(index=True, max_length=255) # 用户ID
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
user_cardname: Optional[str] = Field(default=None, max_length=255, nullable=True) # 用户备注名
group_id: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组id
group_name: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组名称
platform: str = Field(index=True, max_length=100) # 用户平台

View File

@ -206,8 +206,6 @@ class WebSocketLogHandler(logging.Handler):
# 如果是 JSON 格式(文件格式化器),解析它
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get("event", formatted_msg)
except (json.JSONDecodeError, ValueError):

View File

@ -0,0 +1,55 @@
from pathlib import Path
from sqlmodel import select
import hashlib
from src.common.logger import get_logger
from src.common.database.database_model import BinaryData
from src.common.database.database import get_db_session
logger = get_logger("file_utils")
class FileUtils:
@staticmethod
def save_bytes_to_file(file_path: Path, data: bytes):
"""
将字节数据保存到指定文件路径
Args:
file_path (Path): 目标文件路径
data (bytes): 要保存的字节数据
Raises:
IOError: 如果写入文件时发生错误
"""
try:
file_path = file_path.absolute().resolve()
with file_path.open("wb") as f:
f.write(data)
with get_db_session() as session:
# 计算数据哈希
data_hash = hashlib.sha256(data).hexdigest()
# 创建 BinaryData 记录
binary_data_record = BinaryData(data_hash=data_hash, full_path=str(file_path))
session.add(binary_data_record)
session.commit()
except Exception as e:
logger.error(f"保存文件 {file_path} 失败: {e}")
raise e
@staticmethod
def get_file_path_by_hash(data_hash: str) -> Path:
"""
根据数据哈希获取文件路径
Args:
data_hash (str): 数据的哈希值
Returns:
Path: 对应的数据文件路径
"""
with get_db_session() as session:
statement = select(BinaryData).filter_by(data_hash=data_hash).limit(1)
if binary_data := session.exec(statement).first():
return Path(binary_data.full_path)
else:
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")

View File

@ -0,0 +1,15 @@
import msgpack
from src.common.data_models.message_component_model import MessageSequence
class MessageUtils:
@staticmethod
def from_db_record_msg_to_MaiSeq(raw_content: bytes) -> MessageSequence:
unpacked_data = msgpack.unpackb(raw_content)
return MessageSequence.from_dict(unpacked_data)
@staticmethod
async def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes:
dict_representation = msg.to_dict()
return msgpack.packb(dict_representation) # type: ignore

View File

@ -31,6 +31,7 @@ from .official_configs import (
DebugConfig,
DreamConfig,
WebUIConfig,
DatabaseConfig,
)
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
from .config_base import ConfigBase, Field, AttributeData
@ -125,6 +126,9 @@ class Config(ConfigBase):
webui: WebUIConfig = Field(default_factory=WebUIConfig)
"""WebUI配置类"""
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
"""数据库配置类"""
class ModelConfig(ConfigBase):

View File

@ -653,4 +653,15 @@ class WebUIConfig(ConfigBase):
"""是否启用安全Cookie仅通过HTTPS传输默认false"""
enable_paragraph_content: bool = False
"""是否在知识图谱中加载段落完整内容需要加载embedding store会占用额外内存"""
"""是否在知识图谱中加载段落完整内容需要加载embedding store会占用额外内存"""
class DatabaseConfig(ConfigBase):
"""数据库配置类"""
save_binary_data: bool = False
"""
是否将消息中的二进制数据保存为独立文件
若启用消息中的语音等二进制数据将会保存为独立文件并在消息中以特殊标记替代启用会导致数据文件夹体积增大但可以实现二次识别等功能
若禁用则消息中的二进制将会在识别后删除并在消息中使用识别结果替代无法二次识别
该配置项仅影响新存储的消息已有消息不会受到影响
"""