mirror of https://github.com/Mai-with-u/MaiBot.git
合并到远程
parent
60f76e4d4e
commit
b9f3c17e14
2
bot.py
2
bot.py
|
|
@ -1,4 +1,4 @@
|
||||||
# raise RuntimeError("System Not Ready")
|
raise RuntimeError("System Not Ready")
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ def _install_stub_modules(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
class _Result:
|
class _Result:
|
||||||
def scalars(self):
|
def scalars(self):
|
||||||
return self
|
return self
|
||||||
|
|
@ -231,8 +231,8 @@ def _install_stub_modules(monkeypatch):
|
||||||
|
|
||||||
def import_emoji_manager_new(monkeypatch):
|
def import_emoji_manager_new(monkeypatch):
|
||||||
_install_stub_modules(monkeypatch)
|
_install_stub_modules(monkeypatch)
|
||||||
file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager_new.py"
|
file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager.py"
|
||||||
spec = importlib.util.spec_from_file_location("emoji_manager_new", file_path)
|
spec = importlib.util.spec_from_file_location("emoji_manager", file_path)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
|
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
|
@ -446,7 +446,7 @@ def test_load_emojis_from_db_empty(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -487,7 +487,7 @@ def test_load_emojis_from_db_partial_bad_records(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -524,7 +524,7 @@ def test_load_emojis_from_db_execute_error(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
raise RuntimeError("execute failed")
|
raise RuntimeError("execute failed")
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -581,7 +581,7 @@ def test_load_emojis_from_db_scalars_all_error(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -799,6 +799,8 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
@ -817,7 +819,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -887,6 +889,9 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
@ -898,7 +903,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
raise RuntimeError("db delete failed")
|
raise RuntimeError("db delete failed")
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -942,6 +947,8 @@ def test_delete_emoji_success(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
@ -966,7 +973,7 @@ def test_delete_emoji_success(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def delete(self, _record):
|
def delete(self, _record):
|
||||||
|
|
@ -998,6 +1005,8 @@ def test_update_emoji_usage_success(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
@ -1021,7 +1030,7 @@ def test_update_emoji_usage_success(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def add(self, _record):
|
def add(self, _record):
|
||||||
|
|
@ -1051,6 +1060,9 @@ def test_update_emoji_usage_missing_record(monkeypatch):
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
||||||
|
|
@ -1068,7 +1080,7 @@ def test_update_emoji_usage_missing_record(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
return _Result()
|
return _Result()
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
@ -1094,6 +1106,8 @@ def test_update_emoji_usage_execute_error(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
def _select(_model):
|
def _select(_model):
|
||||||
return _Select()
|
return _Select()
|
||||||
|
|
@ -1105,7 +1119,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute(self, _statement):
|
def exec(self, _statement):
|
||||||
raise RuntimeError("execute failed")
|
raise RuntimeError("execute failed")
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ class EmojiManager:
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images)
|
statement = select(Images)
|
||||||
results = session.execute(statement).scalars().all()
|
results = session.exec(statement).all()
|
||||||
for record in results:
|
for record in results:
|
||||||
try:
|
try:
|
||||||
emoji = MaiEmoji.from_db_instance(record)
|
emoji = MaiEmoji.from_db_instance(record)
|
||||||
|
|
@ -144,8 +144,8 @@ class EmojiManager:
|
||||||
# 删除数据库记录
|
# 删除数据库记录
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI)
|
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
if image_record := session.execute(statement).scalars().first():
|
if image_record := session.exec(statement).first():
|
||||||
session.delete(image_record)
|
session.delete(image_record)
|
||||||
logger.info(f"[删除表情包] 成功删除数据库中的表情包记录: {emoji.emoji_hash}")
|
logger.info(f"[删除表情包] 成功删除数据库中的表情包记录: {emoji.emoji_hash}")
|
||||||
else:
|
else:
|
||||||
|
|
@ -170,8 +170,8 @@ class EmojiManager:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI)
|
statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
if image_record := session.execute(statement).scalars().first():
|
if image_record := session.exec(statement).first():
|
||||||
image_record.query_count += 1
|
image_record.query_count += 1
|
||||||
image_record.last_used_time = datetime.now()
|
image_record.last_used_time = datetime.now()
|
||||||
session.add(image_record)
|
session.add(image_record)
|
||||||
|
|
|
||||||
|
|
@ -1,54 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataModel:
|
class BaseDataModel:
|
||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
|
||||||
def transform_class_to_dict(obj: Any) -> Any:
|
|
||||||
# sourcery skip: assign-if-exp, reintroduce-else
|
|
||||||
"""
|
|
||||||
将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例
|
|
||||||
递归转换为普通 dict,不修改原对象。
|
|
||||||
- 对于类对象(isinstance(value, type) 且 issubclass(..., BaseDataModel)),
|
|
||||||
读取类的 __dict__ 中非 dunder 项并递归转换。
|
|
||||||
- 对于实例(isinstance(value, BaseDataModel)),读取 vars(instance) 并递归转换。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _transform(value: Any) -> Any:
|
|
||||||
# 值是类对象且为 BaseDataModel 的子类
|
|
||||||
if isinstance(value, type) and issubclass(value, BaseDataModel):
|
|
||||||
return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)}
|
|
||||||
|
|
||||||
# 值是 BaseDataModel 的实例
|
|
||||||
if isinstance(value, BaseDataModel):
|
|
||||||
return {k: _transform(v) for k, v in vars(value).items()}
|
|
||||||
|
|
||||||
# 常见容器类型,递归处理
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return {k: _transform(v) for k, v in value.items()}
|
|
||||||
if isinstance(value, list):
|
|
||||||
return [_transform(v) for v in value]
|
|
||||||
if isinstance(value, tuple):
|
|
||||||
return tuple(_transform(v) for v in value)
|
|
||||||
if isinstance(value, set):
|
|
||||||
return {_transform(v) for v in value}
|
|
||||||
# 基本类型,直接返回
|
|
||||||
return value
|
|
||||||
|
|
||||||
result = _transform(obj)
|
|
||||||
|
|
||||||
def flatten(target_dict: dict):
|
|
||||||
flat_dict = {}
|
|
||||||
for k, v in target_dict.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
# 递归扁平化子字典
|
|
||||||
sub_flat = flatten(v)
|
|
||||||
flat_dict.update(sub_flat)
|
|
||||||
else:
|
|
||||||
flat_dict[k] = v
|
|
||||||
return flat_dict
|
|
||||||
|
|
||||||
return flatten(result) if isinstance(result, dict) else result
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,226 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from copy import deepcopy
|
||||||
|
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
|
||||||
|
from typing import Optional, List, Union, Dict, Any
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("base_message_component_model")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessageComponentModel(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def to_seg(self) -> Seg:
|
||||||
|
"""将消息组件转换为 maim_message.Seg 对象"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return deepcopy(self)
|
||||||
|
|
||||||
|
|
||||||
|
class ByteComponent:
|
||||||
|
def __init__(self, *, binary_hash: str, content: Optional[str] = None, binary_data: Optional[bytes] = None) -> None:
|
||||||
|
self.content: str = content if content is not None else ""
|
||||||
|
"""处理后的内容"""
|
||||||
|
self.binary_data: bytes = binary_data if binary_data is not None else b""
|
||||||
|
"""原始二进制数据"""
|
||||||
|
self.binary_hash: str = hashlib.sha256(self.binary_data).hexdigest() if self.binary_data else binary_hash
|
||||||
|
"""二进制数据的 SHA256 哈希值,用于唯一标识该二进制数据"""
|
||||||
|
|
||||||
|
|
||||||
|
class TextComponent(BaseMessageComponentModel):
|
||||||
|
def __init__(self, text: str):
|
||||||
|
self.text = text
|
||||||
|
assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型"
|
||||||
|
|
||||||
|
async def to_seg(self) -> Seg:
|
||||||
|
return Seg(type="text", data=self.text)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
async def load_image_binary(self):
|
||||||
|
if not self.binary_data:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def to_seg(self) -> Seg:
|
||||||
|
if not self.binary_data:
|
||||||
|
await self.load_image_binary()
|
||||||
|
return Seg(type="image", data=base64.b64encode(self.binary_data).decode())
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
async def load_emoji_binary(self) -> None:
|
||||||
|
"""
|
||||||
|
加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果 binary_data 为空且缺少 emoji_hash
|
||||||
|
ValueError: 如果无法通过 emoji_hash 加载表情二进制数据
|
||||||
|
"""
|
||||||
|
if not self.binary_data:
|
||||||
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
|
if not (
|
||||||
|
emoji := emoji_manager.get_emoji_by_hash(self.binary_hash)
|
||||||
|
or emoji_manager.get_emoji_by_hash_from_db(self.binary_hash)
|
||||||
|
):
|
||||||
|
raise ValueError(f"无法通过 emoji_hash 加载表情二进制数据: {self.binary_hash}")
|
||||||
|
try:
|
||||||
|
self.binary_data = await asyncio.to_thread(emoji.full_path.read_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"通过 emoji_hash 加载表情二进制数据时发生错误: {e}") from e
|
||||||
|
|
||||||
|
async def to_seg(self) -> Seg:
|
||||||
|
if not self.binary_data:
|
||||||
|
await self.load_emoji_binary()
|
||||||
|
return Seg(type="emoji", data=base64.b64encode(self.binary_data).decode())
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceComponent(BaseMessageComponentModel, ByteComponent):
|
||||||
|
async def load_voice_binary(self) -> None:
|
||||||
|
if not self.binary_data:
|
||||||
|
from src.common.utils.utils_file import FileUtils
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_path = FileUtils.get_file_path_by_hash(self.binary_hash)
|
||||||
|
self.binary_data = await asyncio.to_thread(file_path.read_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"通过 voice_hash 加载语音二进制数据时发生错误: {e}") from e
|
||||||
|
|
||||||
|
async def to_seg(self) -> Seg:
|
||||||
|
if not self.binary_data:
|
||||||
|
await self.load_voice_binary()
|
||||||
|
return Seg(type="voice", data=base64.b64encode(self.binary_data).decode())
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardNodeComponent(BaseMessageComponentModel):
|
||||||
|
def __init__(self, forward_components: List["ForwardComponent"]):
|
||||||
|
self.forward_components = forward_components
|
||||||
|
assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型"
|
||||||
|
assert all(isinstance(comp, ForwardComponent) for comp in forward_components), (
|
||||||
|
"ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型"
|
||||||
|
)
|
||||||
|
assert forward_components, "ForwardNodeComponent 的 forward_components 不能为空列表"
|
||||||
|
|
||||||
|
async def to_seg(self) -> "Seg":
|
||||||
|
resp: List[Dict[str, Any]] = []
|
||||||
|
for comp in self.forward_components:
|
||||||
|
data = await comp.to_seg()
|
||||||
|
sender_info = UserInfo(None, comp.user_id, comp.user_nickname, comp.user_cardname)
|
||||||
|
base_message_info = BaseMessageInfo(user_info=sender_info)
|
||||||
|
base_message = MessageBase(base_message_info, data)
|
||||||
|
resp.append(base_message.to_dict())
|
||||||
|
return Seg(type="forward", data=resp) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class DictComponent:
|
||||||
|
def __init__(self, data: Dict[str, Any]):
|
||||||
|
self.data = data
|
||||||
|
assert isinstance(data, dict), "DictComponent 的 data 必须是字典类型"
|
||||||
|
|
||||||
|
|
||||||
|
StandardMessageComponents = Union[
|
||||||
|
TextComponent,
|
||||||
|
ImageComponent,
|
||||||
|
EmojiComponent,
|
||||||
|
VoiceComponent,
|
||||||
|
ForwardNodeComponent,
|
||||||
|
DictComponent,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardComponent(BaseMessageComponentModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_nickname: str,
|
||||||
|
content: List[StandardMessageComponents],
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
user_cardname: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.user_nickname: str = user_nickname
|
||||||
|
self.content: List[StandardMessageComponents] = content
|
||||||
|
self.user_id: Optional[str] = user_id
|
||||||
|
self.user_cardname: Optional[str] = user_cardname
|
||||||
|
assert self.content, "ForwardComponent 的 content 不能为空"
|
||||||
|
|
||||||
|
async def to_seg(self) -> "Seg":
|
||||||
|
return Seg(
|
||||||
|
type="seglist", data=[await comp.to_seg() for comp in self.content if not isinstance(comp, DictComponent)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSequence:
|
||||||
|
def __init__(self, components: List[StandardMessageComponents]):
|
||||||
|
self.components: List[StandardMessageComponents] = components
|
||||||
|
|
||||||
|
def to_dict(self) -> List[Dict[str, Any]]:
|
||||||
|
return [self._item_2_dict(comp) for comp in self.components]
|
||||||
|
|
||||||
|
def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]:
|
||||||
|
if isinstance(item, TextComponent):
|
||||||
|
return {"type": "text", "data": item.text}
|
||||||
|
elif isinstance(item, ImageComponent):
|
||||||
|
if not item.content:
|
||||||
|
raise RuntimeError("ImageComponent content 未初始化")
|
||||||
|
return {"type": "image", "data": item.content, "hash": item.binary_hash}
|
||||||
|
elif isinstance(item, EmojiComponent):
|
||||||
|
if not item.content:
|
||||||
|
raise RuntimeError("EmojiComponent content 未初始化")
|
||||||
|
return {"type": "emoji", "data": item.content, "hash": item.binary_hash}
|
||||||
|
elif isinstance(item, VoiceComponent):
|
||||||
|
if not item.content:
|
||||||
|
raise RuntimeError("VoiceComponent content 未初始化")
|
||||||
|
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
||||||
|
elif isinstance(item, ForwardNodeComponent):
|
||||||
|
return {
|
||||||
|
"type": "forward",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"user_id": comp.user_id,
|
||||||
|
"user_nickname": comp.user_nickname,
|
||||||
|
"user_cardname": comp.user_cardname,
|
||||||
|
"content": [self._item_2_dict(c) for c in comp.content],
|
||||||
|
}
|
||||||
|
for comp in item.forward_components
|
||||||
|
],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
||||||
|
return {"type": "dict", "data": item.data}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence":
|
||||||
|
components: List[StandardMessageComponents] = []
|
||||||
|
components.extend(cls._dict_2_item(item) for item in data)
|
||||||
|
return cls(components=components)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type == "text":
|
||||||
|
return TextComponent(text=item["data"])
|
||||||
|
elif item_type == "image":
|
||||||
|
return ImageComponent(binary_hash=item["hash"], content=item["data"])
|
||||||
|
elif item_type == "emoji":
|
||||||
|
return EmojiComponent(binary_hash=item["hash"], content=item["data"])
|
||||||
|
elif item_type == "voice":
|
||||||
|
return VoiceComponent(binary_hash=item["hash"], content=item["data"])
|
||||||
|
elif item_type == "forward":
|
||||||
|
forward_components = []
|
||||||
|
for fc in item["data"]:
|
||||||
|
content = [cls._dict_2_item(c) for c in fc["content"]]
|
||||||
|
forward_component = ForwardComponent(
|
||||||
|
user_nickname=fc["user_nickname"],
|
||||||
|
user_id=fc.get("user_id"),
|
||||||
|
user_cardname=fc.get("user_cardname"),
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
forward_components.append(forward_component)
|
||||||
|
return ForwardNodeComponent(forward_components=forward_components)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unofficial component type in dict: {item_type}, defaulting to DictComponent")
|
||||||
|
return DictComponent(data=item.get("data") or {})
|
||||||
|
|
@ -1,210 +0,0 @@
|
||||||
from typing import Optional, TYPE_CHECKING, List, Tuple, Union, Dict, Any
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from . import BaseDataModel
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageAndActionModel(BaseDataModel):
|
|
||||||
chat_id: str = field(default_factory=str)
|
|
||||||
time: float = field(default_factory=float)
|
|
||||||
user_id: str = field(default_factory=str)
|
|
||||||
user_platform: str = field(default_factory=str)
|
|
||||||
user_nickname: str = field(default_factory=str)
|
|
||||||
user_cardname: Optional[str] = None
|
|
||||||
processed_plain_text: Optional[str] = None
|
|
||||||
display_message: Optional[str] = None
|
|
||||||
chat_info_platform: str = field(default_factory=str)
|
|
||||||
is_action_record: bool = field(default=False)
|
|
||||||
action_name: Optional[str] = None
|
|
||||||
is_command: bool = field(default=False)
|
|
||||||
intercept_message_level: int = field(default=0)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
|
||||||
return cls(
|
|
||||||
chat_id=message.chat_id,
|
|
||||||
time=message.time,
|
|
||||||
user_id=message.user_info.user_id,
|
|
||||||
user_platform=message.user_info.platform,
|
|
||||||
user_nickname=message.user_info.user_nickname,
|
|
||||||
user_cardname=message.user_info.user_cardname,
|
|
||||||
processed_plain_text=message.processed_plain_text,
|
|
||||||
display_message=message.display_message,
|
|
||||||
chat_info_platform=message.chat_info.platform,
|
|
||||||
is_command=message.is_command,
|
|
||||||
intercept_message_level=getattr(message, "intercept_message_level", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplyContentType(Enum):
|
|
||||||
TEXT = "text"
|
|
||||||
IMAGE = "image"
|
|
||||||
EMOJI = "emoji"
|
|
||||||
COMMAND = "command"
|
|
||||||
VOICE = "voice"
|
|
||||||
FORWARD = "forward"
|
|
||||||
HYBRID = "hybrid" # 混合类型,包含多种内容
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ForwardNode(BaseDataModel):
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
user_nickname: Optional[str] = None
|
|
||||||
content: Union[List["ReplyContent"], str] = field(default_factory=list)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
|
|
||||||
return cls(user_id="", user_nickname="", content=message_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_created_node(
|
|
||||||
cls, user_id: str, user_nickname: str, content: List["ReplyContent"]
|
|
||||||
) -> "ForwardNode":
|
|
||||||
return cls(user_id=user_id, user_nickname=user_nickname, content=content)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReplyContent(BaseDataModel):
|
|
||||||
content_type: ReplyContentType | str
|
|
||||||
content: Union[str, Dict, List[ForwardNode], List["ReplyContent"]] # 支持嵌套的 ReplyContent
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_text(cls, text: str):
|
|
||||||
return cls(content_type=ReplyContentType.TEXT, content=text)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_image(cls, image_base64: str):
|
|
||||||
return cls(content_type=ReplyContentType.IMAGE, content=image_base64)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_voice(cls, voice_base64: str):
|
|
||||||
return cls(content_type=ReplyContentType.VOICE, content=voice_base64)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_emoji(cls, emoji_str: str):
|
|
||||||
return cls(content_type=ReplyContentType.EMOJI, content=emoji_str)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_command(cls, command_arg: Dict):
|
|
||||||
return cls(content_type=ReplyContentType.COMMAND, content=command_arg)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_hybrid(cls, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
|
|
||||||
hybrid_content_list: List[ReplyContent] = []
|
|
||||||
for content_type, content in hybrid_content:
|
|
||||||
assert content_type not in [
|
|
||||||
ReplyContentType.HYBRID,
|
|
||||||
ReplyContentType.FORWARD,
|
|
||||||
ReplyContentType.VOICE,
|
|
||||||
ReplyContentType.COMMAND,
|
|
||||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
|
||||||
assert isinstance(content, str), "混合内容的每个项必须是字符串"
|
|
||||||
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
|
|
||||||
return cls(content_type=ReplyContentType.HYBRID, content=hybrid_content_list)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct_as_forward(cls, forward_nodes: List[ForwardNode]):
|
|
||||||
return cls(content_type=ReplyContentType.FORWARD, content=forward_nodes)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if isinstance(self.content_type, ReplyContentType):
|
|
||||||
if self.content_type not in [ReplyContentType.HYBRID, ReplyContentType.FORWARD] and isinstance(
|
|
||||||
self.content, List
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"非混合类型/转发类型的内容不能是列表,content_type: {self.content_type}, content: {self.content}"
|
|
||||||
)
|
|
||||||
elif self.content_type in [ReplyContentType.HYBRID, ReplyContentType.FORWARD]:
|
|
||||||
if not isinstance(self.content, List):
|
|
||||||
raise ValueError(
|
|
||||||
f"混合类型/转发类型的内容必须是列表,content_type: {self.content_type}, content: {self.content}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReplySetModel(BaseDataModel):
|
|
||||||
"""
|
|
||||||
回复集数据模型,用于多种回复类型的返回
|
|
||||||
"""
|
|
||||||
|
|
||||||
reply_data: List[ReplyContent] = field(default_factory=list)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.reply_data)
|
|
||||||
|
|
||||||
def add_text_content(self, text: str):
|
|
||||||
"""
|
|
||||||
添加文本内容
|
|
||||||
Args:
|
|
||||||
text: 文本内容
|
|
||||||
"""
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))
|
|
||||||
|
|
||||||
def add_image_content(self, image_base64: str):
|
|
||||||
"""
|
|
||||||
添加图片内容,base64编码的图片数据
|
|
||||||
Args:
|
|
||||||
image_base64: base64编码的图片数据
|
|
||||||
"""
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.IMAGE, content=image_base64))
|
|
||||||
|
|
||||||
def add_voice_content(self, voice_base64: str):
|
|
||||||
"""
|
|
||||||
添加语音内容,base64编码的音频数据
|
|
||||||
Args:
|
|
||||||
voice_base64: base64编码的音频数据
|
|
||||||
"""
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))
|
|
||||||
|
|
||||||
def add_hybrid_content_by_raw(self, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
|
|
||||||
"""
|
|
||||||
添加混合型内容,可以包含text, image, emoji的任意组合
|
|
||||||
Args:
|
|
||||||
hybrid_content: 元组 (类型, 消息内容) 构成的列表,如[(ReplyContentType.TEXT, "Hello"), (ReplyContentType.IMAGE, "<base64")]
|
|
||||||
"""
|
|
||||||
hybrid_content_list: List[ReplyContent] = []
|
|
||||||
for content_type, content in hybrid_content:
|
|
||||||
assert content_type not in [
|
|
||||||
ReplyContentType.HYBRID,
|
|
||||||
ReplyContentType.FORWARD,
|
|
||||||
ReplyContentType.VOICE,
|
|
||||||
ReplyContentType.COMMAND,
|
|
||||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
|
||||||
assert isinstance(content, str), "混合内容的每个项必须是字符串"
|
|
||||||
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
|
|
||||||
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content_list))
|
|
||||||
|
|
||||||
def add_hybrid_content(self, hybrid_content: List[ReplyContent]):
|
|
||||||
"""
|
|
||||||
添加混合型内容,使用已经构造好的 ReplyContent 列表
|
|
||||||
Args:
|
|
||||||
hybrid_content: ReplyContent 构成的列表,如[ReplyContent(ReplyContentType.TEXT, "Hello"), ReplyContent(ReplyContentType.IMAGE, "<base64")]
|
|
||||||
"""
|
|
||||||
for content in hybrid_content:
|
|
||||||
assert content.content_type not in [
|
|
||||||
ReplyContentType.HYBRID,
|
|
||||||
ReplyContentType.FORWARD,
|
|
||||||
ReplyContentType.VOICE,
|
|
||||||
ReplyContentType.COMMAND,
|
|
||||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
|
||||||
assert isinstance(content.content, str), "混合内容的每个项必须是字符串"
|
|
||||||
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content))
|
|
||||||
|
|
||||||
def add_custom_content(self, content_type: str, content: Any):
|
|
||||||
"""
|
|
||||||
添加自定义类型的内容"""
|
|
||||||
self.reply_data.append(ReplyContent(content_type=content_type, content=content))
|
|
||||||
|
|
||||||
def add_forward_content(self, forward_content: List[ForwardNode]):
|
|
||||||
"""添加转发内容,可以是字符串或ReplyContent,嵌套的转发内容需要自己构造放入"""
|
|
||||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_content))
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
# 对于`message_data_model.py`中`class ReplyContent`的规划解读
|
|
||||||
|
|
||||||
分类讨论如下:
|
|
||||||
- `ReplyContent.TEXT`: 单独的文本,`_level = 0`,`content`为`str`类型。
|
|
||||||
- `ReplyContent.IMAGE`: 单独的图片,`_level = 0`,`content`为`str`类型(图片base64)。
|
|
||||||
- `ReplyContent.EMOJI`: 单独的表情包,`_level = 0`,`content`为`str`类型(图片base64)。
|
|
||||||
- `ReplyContent.VOICE`: 单独的语音,`_level = 0`,`content`为`str`类型(语音base64)。
|
|
||||||
- `ReplyContent.HYBRID`: 混合内容,`_level = 0`
|
|
||||||
- 其应该是一个列表,列表内应该只接受`str`类型的内容(图片和文本混合体)
|
|
||||||
- `ReplyContent.FORWARD`: 转发消息,`_level = n`
|
|
||||||
- 其应该是一个列表,列表接受`str`类型(图片/文本),`ReplyContent`类型(嵌套转发,嵌套有最高层数限制)
|
|
||||||
- `ReplyContent.COMMAND`: 指令消息,`_level = 0`
|
|
||||||
- 其应该是一个列表,列表内应该只接受`Dict`类型的内容
|
|
||||||
|
|
||||||
未来规划:
|
|
||||||
- `ReplyContent.AT`: 单独的艾特,`_level = 0`,`content`为`str`类型(用户ID)。
|
|
||||||
|
|
||||||
内容构造方式:
|
|
||||||
- 对于`TEXT`, `IMAGE`, `EMOJI`, `VOICE`,直接传入对应类型的内容,且`content`应该为`str`。
|
|
||||||
- 对于`COMMAND`,传入一个字典,字典内的内容类型应符合上述规定。
|
|
||||||
- 对于`HYBRID`, `FORWARD`,传入一个列表,列表内的内容类型应符合上述规定。
|
|
||||||
|
|
||||||
因此,我们的类型注解应该是:
|
|
||||||
```python
|
|
||||||
from typing import Union, List, Dict
|
|
||||||
|
|
||||||
ReplyContentType = Union[
|
|
||||||
str, # TEXT, IMAGE, EMOJI, VOICE
|
|
||||||
List[Union[str, 'ReplyContent']], # HYBRID, FORWARD
|
|
||||||
Dict # COMMAND
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
现在`_level`被移除了,在解析的时候显式地检查内容的类型和结构即可。
|
|
||||||
|
|
||||||
`send_api`的custom_reply_set_to_stream仅在特定的类型下提供reply)message
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
# 有关转发消息和其他消息的构建类型说明
|
|
||||||
```mermaid
|
|
||||||
graph LR;
|
|
||||||
direction TB;
|
|
||||||
A[ReplySet] --- B[ReplyContent];
|
|
||||||
A --- C["ReplyContent"];
|
|
||||||
A --- K["ReplyContent"];
|
|
||||||
A --- L["ReplyContent"];
|
|
||||||
A --- N["ReplyContent"];
|
|
||||||
A --- D[...];
|
|
||||||
B --- E["Text (in str)"];
|
|
||||||
B --- F["Image (in base64)"];
|
|
||||||
C --- G["Voice (in base64)"];
|
|
||||||
B --- I["Emoji (in base64)"];
|
|
||||||
subgraph "可行内容(以下的任意组合)";
|
|
||||||
subgraph "转发消息(Forward)"
|
|
||||||
M["List[ForwardNode]"]
|
|
||||||
end
|
|
||||||
subgraph "混合消息(Hybrid)"
|
|
||||||
J["List[ReplyContent] (要求只能包含普通消息)"]
|
|
||||||
end
|
|
||||||
subgraph "命令消息(Command)"
|
|
||||||
H["Command (in Dict)"]
|
|
||||||
end
|
|
||||||
subgraph "语音消息"
|
|
||||||
G
|
|
||||||
end
|
|
||||||
subgraph "普通消息"
|
|
||||||
E
|
|
||||||
F
|
|
||||||
I
|
|
||||||
end
|
|
||||||
end
|
|
||||||
N --- H
|
|
||||||
K --- J
|
|
||||||
L --- M
|
|
||||||
subgraph ForwardNodes
|
|
||||||
O["ForwardNode"]
|
|
||||||
P["ForwardNode"]
|
|
||||||
Q["ForwardNode"]
|
|
||||||
end
|
|
||||||
M --- O
|
|
||||||
M --- P
|
|
||||||
M --- Q
|
|
||||||
subgraph "内容 (message_id引用法)"
|
|
||||||
P --- U["content: str, 引用已有消息的有效ID"];
|
|
||||||
end
|
|
||||||
subgraph "内容 (生成法)"
|
|
||||||
O --- R["user_id: str"];
|
|
||||||
O --- S["user_nickname: str"];
|
|
||||||
O --- T["content: List[ReplyContent], 为这个转发节点的消息内容"];
|
|
||||||
end
|
|
||||||
```
|
|
||||||
|
|
||||||
另外,自定义消息类型我们在这里不做讨论。
|
|
||||||
|
|
||||||
以上列出了所有可能的ReplySet构建方式,下面我们来解释一下各个类型的含义。
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from sqlalchemy import create_engine, event, text
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy import event
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
from sqlmodel import create_engine, Session
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
from typing import TYPE_CHECKING, Generator
|
from typing import TYPE_CHECKING, Generator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -27,19 +27,12 @@ DATABASE_URL = f"sqlite:///{_DB_FILE}"
|
||||||
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
|
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
|
||||||
"""
|
"""
|
||||||
为每个新的数据库连接设置 SQLite PRAGMA。
|
为每个新的数据库连接设置 SQLite PRAGMA。
|
||||||
|
|
||||||
这些设置优化了并发性能和数据安全性:
|
|
||||||
- journal_mode=WAL: 启用预写式日志,提高并发性能
|
|
||||||
- cache_size: 设置缓存大小为 64MB
|
|
||||||
- foreign_keys: 启用外键约束
|
|
||||||
- synchronous=NORMAL: 平衡性能和数据安全
|
|
||||||
- busy_timeout: 设置1秒超时,避免锁定冲突
|
|
||||||
"""
|
"""
|
||||||
cursor = dbapi_connection.cursor()
|
cursor = dbapi_connection.cursor()
|
||||||
cursor.execute("PRAGMA journal_mode=WAL")
|
cursor.execute("PRAGMA journal_mode=WAL")
|
||||||
cursor.execute("PRAGMA cache_size=-64000") # 负值表示KB,64000KB = 64MB
|
cursor.execute("PRAGMA cache_size=-64000") # 负值表示KB,64000KB = 64MB
|
||||||
cursor.execute("PRAGMA foreign_keys=ON")
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
cursor.execute("PRAGMA synchronous=NORMAL") # NORMAL 模式在WAL下是安全的
|
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||||
cursor.execute("PRAGMA busy_timeout=1000") # 1秒超时
|
cursor.execute("PRAGMA busy_timeout=1000") # 1秒超时
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
|
@ -52,11 +45,12 @@ engine = create_engine(
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建会话工厂
|
# 创建会话工厂(使用 sqlmodel.Session)
|
||||||
SessionLocal = sessionmaker(
|
SessionLocal = sessionmaker(
|
||||||
autocommit=False,
|
autocommit=False,
|
||||||
autoflush=False,
|
autoflush=False,
|
||||||
bind=engine,
|
bind=engine,
|
||||||
|
class_=Session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,7 +90,6 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
# 如果启用自动提交且没有异常,则提交事务
|
|
||||||
if auto_commit:
|
if auto_commit:
|
||||||
session.commit()
|
session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -132,59 +125,3 @@ def get_db() -> Generator[Session, None, None]:
|
||||||
yield session
|
yield session
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
class _AtomicContext:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._session: Session | None = None
|
|
||||||
|
|
||||||
def __enter__(self) -> Session:
|
|
||||||
self._session = SessionLocal()
|
|
||||||
self._session.begin()
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb) -> None:
|
|
||||||
if self._session is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if exc_type is None:
|
|
||||||
self._session.commit()
|
|
||||||
else:
|
|
||||||
self._session.rollback()
|
|
||||||
finally:
|
|
||||||
self._session.close()
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseCompat:
|
|
||||||
"""兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。"""
|
|
||||||
|
|
||||||
def connect(self, reuse_if_open: bool = True) -> None:
|
|
||||||
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
|
|
||||||
_ = reuse_if_open
|
|
||||||
|
|
||||||
def create_tables(self, models: list[type], safe: bool = True) -> None:
|
|
||||||
_ = safe
|
|
||||||
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
|
|
||||||
if not tables:
|
|
||||||
return
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
SQLModel.metadata.create_all(engine, tables=tables)
|
|
||||||
|
|
||||||
def atomic(self) -> _AtomicContext:
|
|
||||||
return _AtomicContext()
|
|
||||||
|
|
||||||
def execute_sql(self, sql: str):
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(text(sql))
|
|
||||||
conn.commit()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def table_exists(self, model: type) -> bool:
|
|
||||||
if not hasattr(model, "__tablename__"):
|
|
||||||
return False
|
|
||||||
inspector = sqlalchemy_inspect(engine)
|
|
||||||
return inspector.has_table(model.__tablename__)
|
|
||||||
|
|
||||||
|
|
||||||
db = DatabaseCompat()
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from sqlalchemy import Column, Float, Enum as SQLEnum
|
from sqlalchemy import Column, Float, Enum as SQLEnum
|
||||||
from sqlmodel import SQLModel, Field
|
from sqlmodel import SQLModel, Field, LargeBinary
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
@ -45,8 +45,8 @@ class Messages(SQLModel, table=True):
|
||||||
is_notify: bool = Field(default=False) # 是否为通知消息
|
is_notify: bool = Field(default=False) # 是否为通知消息
|
||||||
|
|
||||||
# 消息内容
|
# 消息内容
|
||||||
raw_content: str # base64编码的原始消息内容
|
raw_content: bytes = Field(sa_column=Column(LargeBinary)) # base64编码的原始消息内容
|
||||||
processed_plain_text: str = Field(index=True) # 平面化处理后的纯文本消息
|
processed_plain_text: str = Field() # 平面化处理后的纯文本消息
|
||||||
display_message: str # 显示的消息内容(被放入Prompt)
|
display_message: str # 显示的消息内容(被放入Prompt)
|
||||||
|
|
||||||
# 其他配置
|
# 其他配置
|
||||||
|
|
@ -85,9 +85,9 @@ class Images(SQLModel, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||||
|
|
||||||
# 元信息
|
# 元信息
|
||||||
image_hash: str = Field(default="", max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID
|
image_hash: str = Field(index=True, max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID
|
||||||
description: str # 图片的描述
|
description: str # 图片的描述
|
||||||
full_path: str = Field(index=True, max_length=1024) # 文件的完整路径 (包括文件名)
|
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||||
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
||||||
"""图片类型,例如 'emoji' 或 'image'"""
|
"""图片类型,例如 'emoji' 或 'image'"""
|
||||||
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
|
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
|
||||||
|
|
@ -116,7 +116,7 @@ class ActionRecord(SQLModel, table=True):
|
||||||
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
||||||
|
|
||||||
# 调用信息
|
# 调用信息
|
||||||
action_name: str = Field(max_length=255) # 动作名称
|
action_name: str = Field(index=True, max_length=255) # 动作名称
|
||||||
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
|
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
|
||||||
action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储
|
action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储
|
||||||
|
|
||||||
|
|
@ -153,7 +153,7 @@ class OnlineTime(SQLModel, table=True):
|
||||||
timestamp: datetime = Field(default_factory=datetime.now, index=True) # 时间戳
|
timestamp: datetime = Field(default_factory=datetime.now, index=True) # 时间戳
|
||||||
duration_minutes: int = Field() # 时长,单位秒
|
duration_minutes: int = Field() # 时长,单位秒
|
||||||
start_timestamp: datetime = Field(default_factory=datetime.now) # 上线时间
|
start_timestamp: datetime = Field(default_factory=datetime.now) # 上线时间
|
||||||
end_timestamp: datetime = Field(index=True) # 下线时间
|
end_timestamp: datetime = Field() # 下线时间
|
||||||
|
|
||||||
|
|
||||||
class Expression(SQLModel, table=True):
|
class Expression(SQLModel, table=True):
|
||||||
|
|
@ -230,7 +230,68 @@ class ThinkingQuestion(SQLModel, table=True):
|
||||||
context: Optional[str] = Field(default=None, nullable=True) # 上下文
|
context: Optional[str] = Field(default=None, nullable=True) # 上下文
|
||||||
found_answer: bool = Field(default=False) # 是否找到答案
|
found_answer: bool = Field(default=False) # 是否找到答案
|
||||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||||
|
|
||||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||||
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
|
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
|
||||||
updated_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后更新时间
|
updated_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后更新时间
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryData(SQLModel, table=True):
|
||||||
|
"""存储二进制数据的模型"""
|
||||||
|
|
||||||
|
__tablename__ = "binary_data" # type: ignore
|
||||||
|
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||||
|
|
||||||
|
data_hash: str = Field(index=True, max_length=255) # 数据哈希,使用sha256哈希值,亦作为数据唯一ID
|
||||||
|
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonInfo(SQLModel, table=True):
|
||||||
|
"""存储个人信息的模型"""
|
||||||
|
|
||||||
|
__tablename__ = "person_info" # type: ignore
|
||||||
|
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||||
|
|
||||||
|
is_known: bool = Field(default=False) # 是否为已知人
|
||||||
|
person_id: str = Field(unique=True, index=True, max_length=255) # 人员ID
|
||||||
|
person_name: Optional[str] = Field(default=None, max_length=255, nullable=True) # 人员名称
|
||||||
|
name_reason: Optional[str] = Field(default=None, nullable=True) # 名称原因
|
||||||
|
|
||||||
|
# 身份元数据
|
||||||
|
platform: str = Field(index=True, max_length=100) # 平台名称
|
||||||
|
user_id: str = Field(index=True, max_length=255) # 用户ID
|
||||||
|
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
|
||||||
|
group_nickname: Optional[str] = Field(
|
||||||
|
default=None, nullable=True
|
||||||
|
) # 群昵称 (JSON, [{"group_id": str, "group_nick_name": str}])
|
||||||
|
|
||||||
|
# 印象
|
||||||
|
memory_points: Optional[str] = Field(default=None, nullable=True) # 记忆要点,JSON格式存储
|
||||||
|
|
||||||
|
# 认识次数和时间
|
||||||
|
know_counts: int = Field(default=0) # 认识次数
|
||||||
|
first_known_time: Optional[datetime] = Field(default=None, nullable=True) # 首次认识时间
|
||||||
|
last_known_time: Optional[datetime] = Field(default=None, nullable=True) # 最后认识时间
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSession(SQLModel, table=True):
|
||||||
|
"""存储聊天会话的模型"""
|
||||||
|
|
||||||
|
__tablename__ = "chat_sessions" # type: ignore
|
||||||
|
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||||
|
|
||||||
|
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
||||||
|
|
||||||
|
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
|
||||||
|
last_active_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后活跃时间
|
||||||
|
|
||||||
|
# 身份元数据
|
||||||
|
user_id: str = Field(index=True, max_length=255) # 用户ID
|
||||||
|
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
|
||||||
|
user_cardname: Optional[str] = Field(default=None, max_length=255, nullable=True) # 用户备注名
|
||||||
|
group_id: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组id
|
||||||
|
group_name: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组名称
|
||||||
|
platform: str = Field(index=True, max_length=100) # 用户平台
|
||||||
|
|
|
||||||
|
|
@ -206,8 +206,6 @@ class WebSocketLogHandler(logging.Handler):
|
||||||
# 如果是 JSON 格式(文件格式化器),解析它
|
# 如果是 JSON 格式(文件格式化器),解析它
|
||||||
message = formatted_msg
|
message = formatted_msg
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
|
|
||||||
log_dict = json.loads(formatted_msg)
|
log_dict = json.loads(formatted_msg)
|
||||||
message = log_dict.get("event", formatted_msg)
|
message = log_dict.get("event", formatted_msg)
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database_model import BinaryData
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
|
||||||
|
logger = get_logger("file_utils")
|
||||||
|
|
||||||
|
class FileUtils:
|
||||||
|
@staticmethod
|
||||||
|
def save_bytes_to_file(file_path: Path, data: bytes):
|
||||||
|
"""
|
||||||
|
将字节数据保存到指定文件路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (Path): 目标文件路径
|
||||||
|
data (bytes): 要保存的字节数据
|
||||||
|
Raises:
|
||||||
|
IOError: 如果写入文件时发生错误
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file_path = file_path.absolute().resolve()
|
||||||
|
with file_path.open("wb") as f:
|
||||||
|
f.write(data)
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 计算数据哈希
|
||||||
|
data_hash = hashlib.sha256(data).hexdigest()
|
||||||
|
# 创建 BinaryData 记录
|
||||||
|
binary_data_record = BinaryData(data_hash=data_hash, full_path=str(file_path))
|
||||||
|
session.add(binary_data_record)
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存文件 {file_path} 失败: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_path_by_hash(data_hash: str) -> Path:
|
||||||
|
"""
|
||||||
|
根据数据哈希获取文件路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_hash (str): 数据的哈希值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: 对应的数据文件路径
|
||||||
|
"""
|
||||||
|
with get_db_session() as session:
|
||||||
|
statement = select(BinaryData).filter_by(data_hash=data_hash).limit(1)
|
||||||
|
if binary_data := session.exec(statement).first():
|
||||||
|
return Path(binary_data.full_path)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
from src.common.data_models.message_component_model import MessageSequence
|
||||||
|
|
||||||
|
|
||||||
|
class MessageUtils:
|
||||||
|
@staticmethod
|
||||||
|
def from_db_record_msg_to_MaiSeq(raw_content: bytes) -> MessageSequence:
|
||||||
|
unpacked_data = msgpack.unpackb(raw_content)
|
||||||
|
return MessageSequence.from_dict(unpacked_data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes:
|
||||||
|
dict_representation = msg.to_dict()
|
||||||
|
return msgpack.packb(dict_representation) # type: ignore
|
||||||
|
|
@ -31,6 +31,7 @@ from .official_configs import (
|
||||||
DebugConfig,
|
DebugConfig,
|
||||||
DreamConfig,
|
DreamConfig,
|
||||||
WebUIConfig,
|
WebUIConfig,
|
||||||
|
DatabaseConfig,
|
||||||
)
|
)
|
||||||
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
|
from .model_configs import ModelInfo, ModelTaskConfig, APIProvider
|
||||||
from .config_base import ConfigBase, Field, AttributeData
|
from .config_base import ConfigBase, Field, AttributeData
|
||||||
|
|
@ -125,6 +126,9 @@ class Config(ConfigBase):
|
||||||
|
|
||||||
webui: WebUIConfig = Field(default_factory=WebUIConfig)
|
webui: WebUIConfig = Field(default_factory=WebUIConfig)
|
||||||
"""WebUI配置类"""
|
"""WebUI配置类"""
|
||||||
|
|
||||||
|
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||||
|
"""数据库配置类"""
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(ConfigBase):
|
class ModelConfig(ConfigBase):
|
||||||
|
|
|
||||||
|
|
@ -653,4 +653,15 @@ class WebUIConfig(ConfigBase):
|
||||||
"""是否启用安全Cookie(仅通过HTTPS传输,默认false)"""
|
"""是否启用安全Cookie(仅通过HTTPS传输,默认false)"""
|
||||||
|
|
||||||
enable_paragraph_content: bool = False
|
enable_paragraph_content: bool = False
|
||||||
"""是否在知识图谱中加载段落完整内容(需要加载embedding store,会占用额外内存)"""
|
"""是否在知识图谱中加载段落完整内容(需要加载embedding store,会占用额外内存)"""
|
||||||
|
|
||||||
|
class DatabaseConfig(ConfigBase):
|
||||||
|
"""数据库配置类"""
|
||||||
|
|
||||||
|
save_binary_data: bool = False
|
||||||
|
"""
|
||||||
|
是否将消息中的二进制数据保存为独立文件
|
||||||
|
若启用,消息中的语音等二进制数据将会保存为独立文件,并在消息中以特殊标记替代。启用会导致数据文件夹体积增大,但可以实现二次识别等功能。
|
||||||
|
若禁用,则消息中的二进制将会在识别后删除,并在消息中使用识别结果替代,无法二次识别
|
||||||
|
该配置项仅影响新存储的消息,已有消息不会受到影响
|
||||||
|
"""
|
||||||
Loading…
Reference in New Issue