diff --git a/bot.py b/bot.py index 3f3a4e9c..33fcbdd1 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,4 @@ -# raise RuntimeError("System Not Ready") +raise RuntimeError("System Not Ready") import asyncio import hashlib import os diff --git a/pytests/image_sys_test/emoji_manager_test.py b/pytests/image_sys_test/emoji_manager_test.py index 24e68c75..f9877558 100644 --- a/pytests/image_sys_test/emoji_manager_test.py +++ b/pytests/image_sys_test/emoji_manager_test.py @@ -106,7 +106,7 @@ def _install_stub_modules(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): class _Result: def scalars(self): return self @@ -231,8 +231,8 @@ def _install_stub_modules(monkeypatch): def import_emoji_manager_new(monkeypatch): _install_stub_modules(monkeypatch) - file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager_new.py" - spec = importlib.util.spec_from_file_location("emoji_manager_new", file_path) + file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager.py" + spec = importlib.util.spec_from_file_location("emoji_manager", file_path) module = importlib.util.module_from_spec(spec) monkeypatch.setitem(sys.modules, "emoji_manager_new", module) spec.loader.exec_module(module) @@ -446,7 +446,7 @@ def test_load_emojis_from_db_empty(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def _get_db_session(): @@ -487,7 +487,7 @@ def test_load_emojis_from_db_partial_bad_records(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def _get_db_session(): @@ -524,7 +524,7 @@ def test_load_emojis_from_db_execute_error(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): raise RuntimeError("execute failed") def _get_db_session(): @@ -581,7 +581,7 @@ def test_load_emojis_from_db_scalars_all_error(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def _get_db_session(): @@ -799,6 +799,8 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): + return self def _select(_model): return _Select() @@ -817,7 +819,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def _get_db_session(): @@ -887,6 +889,9 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + + def limit(self, _num): + return self def _select(_model): return _Select() @@ -898,7 +903,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): raise RuntimeError("db delete failed") def _get_db_session(): @@ -942,6 +947,8 @@ def test_delete_emoji_success(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): + return self def _select(_model): return _Select() @@ -966,7 +973,7 @@ def test_delete_emoji_success(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def delete(self, _record): @@ -998,6 +1005,8 @@ def test_update_emoji_usage_success(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): + return self def _select(_model): return _Select() @@ -1021,7 +1030,7 @@ def test_update_emoji_usage_success(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def add(self, _record): @@ -1051,6 +1060,9 @@ def test_update_emoji_usage_missing_record(monkeypatch): def filter_by(self, **_kwargs): return self + def limit(self, _num): + return self + def _select(_model): return _Select() @@ -1068,7 +1080,7 @@ def test_update_emoji_usage_missing_record(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): return _Result() def _get_db_session(): @@ -1094,6 +1106,8 @@ def test_update_emoji_usage_execute_error(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): + return self def _select(_model): return _Select() @@ -1105,7 +1119,7 @@ def test_update_emoji_usage_execute_error(monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def execute(self, _statement): + def exec(self, _statement): raise RuntimeError("execute failed") def _get_db_session(): diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index c7cb1dd1..c95c2fef 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -62,7 +62,7 @@ class EmojiManager: try: with get_db_session() as session: statement = select(Images) - results = session.execute(statement).scalars().all() + results = session.exec(statement).all() for record in results: try: emoji = MaiEmoji.from_db_instance(record) @@ -144,8 +144,8 @@ class EmojiManager: # 删除数据库记录 try: with get_db_session() as session: - statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI) - if image_record := session.execute(statement).scalars().first(): + statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1) + if image_record := session.exec(statement).first(): session.delete(image_record) logger.info(f"[删除表情包] 成功删除数据库中的表情包记录: {emoji.emoji_hash}") else: @@ -170,8 +170,8 @@ class EmojiManager: """ try: with get_db_session() as session: - statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI) - if image_record := session.execute(statement).scalars().first(): + statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI).limit(1) + if image_record := session.exec(statement).first(): image_record.query_count += 1 image_record.last_used_time = datetime.now() session.add(image_record) diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index d1303dc2..8dbcdbad 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -1,54 +1,6 @@ import copy -from typing import Any class BaseDataModel: def deepcopy(self): return copy.deepcopy(self) - - -def transform_class_to_dict(obj: Any) -> Any: - # sourcery skip: assign-if-exp, reintroduce-else - """ - 将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例 - 递归转换为普通 dict,不修改原对象。 - - 对于类对象(isinstance(value, type) 且 issubclass(..., BaseDataModel)), - 读取类的 __dict__ 中非 dunder 项并递归转换。 - - 对于实例(isinstance(value, BaseDataModel)),读取 vars(instance) 并递归转换。 - """ - - def _transform(value: Any) -> Any: - # 值是类对象且为 BaseDataModel 的子类 - if isinstance(value, type) and issubclass(value, BaseDataModel): - return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)} - - # 值是 BaseDataModel 的实例 - if isinstance(value, BaseDataModel): - return {k: _transform(v) for k, v in vars(value).items()} - - # 常见容器类型,递归处理 - if isinstance(value, dict): - return {k: _transform(v) for k, v in value.items()} - if isinstance(value, list): - return [_transform(v) for v in value] - if isinstance(value, tuple): - return tuple(_transform(v) for v in value) - if isinstance(value, set): - return {_transform(v) for v in value} - # 基本类型,直接返回 - return value - - result = _transform(obj) - - def flatten(target_dict: dict): - flat_dict = {} - for k, v in target_dict.items(): - if isinstance(v, dict): - # 递归扁平化子字典 - sub_flat = flatten(v) - flat_dict.update(sub_flat) - else: - flat_dict[k] = v - return flat_dict - - return flatten(result) if isinstance(result, dict) else result diff --git a/src/common/data_models/message_component_model.py b/src/common/data_models/message_component_model.py new file mode 100644 index 00000000..7733ddce --- /dev/null +++ b/src/common/data_models/message_component_model.py @@ -0,0 +1,226 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo +from typing import Optional, List, Union, Dict, Any + +import asyncio +import hashlib +import base64 + +from src.common.logger import get_logger + +logger = get_logger("base_message_component_model") + + +class BaseMessageComponentModel(ABC): + @abstractmethod + async def to_seg(self) -> Seg: + """将消息组件转换为 maim_message.Seg 对象""" + raise NotImplementedError + + def clone(self): + return deepcopy(self) + + +class ByteComponent: + def __init__(self, *, binary_hash: str, content: Optional[str] = None, binary_data: Optional[bytes] = None) -> None: + self.content: str = content if content is not None else "" + """处理后的内容""" + self.binary_data: bytes = binary_data if binary_data is not None else b"" + """原始二进制数据""" + self.binary_hash: str = hashlib.sha256(self.binary_data).hexdigest() if self.binary_data else binary_hash + """二进制数据的 SHA256 哈希值,用于唯一标识该二进制数据""" + + +class TextComponent(BaseMessageComponentModel): + def __init__(self, text: str): + self.text = text + assert isinstance(text, str), "TextComponent 的 text 必须是字符串类型" + + async def to_seg(self) -> Seg: + return Seg(type="text", data=self.text) + + +class ImageComponent(BaseMessageComponentModel, ByteComponent): + async def load_image_binary(self): + if not self.binary_data: + ... + + async def to_seg(self) -> Seg: + if not self.binary_data: + await self.load_image_binary() + return Seg(type="image", data=base64.b64encode(self.binary_data).decode()) + + +class EmojiComponent(BaseMessageComponentModel, ByteComponent): + async def load_emoji_binary(self) -> None: + """ + 加载表情的二进制数据,如果 binary_data 为空,则通过 emoji_hash 从表情管理器加载 + + Raises: + ValueError: 如果 binary_data 为空且缺少 emoji_hash + ValueError: 如果无法通过 emoji_hash 加载表情二进制数据 + """ + if not self.binary_data: + from src.chat.emoji_system.emoji_manager import emoji_manager + + if not ( + emoji := emoji_manager.get_emoji_by_hash(self.binary_hash) + or emoji_manager.get_emoji_by_hash_from_db(self.binary_hash) + ): + raise ValueError(f"无法通过 emoji_hash 加载表情二进制数据: {self.binary_hash}") + try: + self.binary_data = await asyncio.to_thread(emoji.full_path.read_bytes) + except Exception as e: + raise ValueError(f"通过 emoji_hash 加载表情二进制数据时发生错误: {e}") from e + + async def to_seg(self) -> Seg: + if not self.binary_data: + await self.load_emoji_binary() + return Seg(type="emoji", data=base64.b64encode(self.binary_data).decode()) + + +class VoiceComponent(BaseMessageComponentModel, ByteComponent): + async def load_voice_binary(self) -> None: + if not self.binary_data: + from src.common.utils.utils_file import FileUtils + + try: + file_path = FileUtils.get_file_path_by_hash(self.binary_hash) + self.binary_data = await asyncio.to_thread(file_path.read_bytes) + except Exception as e: + raise ValueError(f"通过 voice_hash 加载语音二进制数据时发生错误: {e}") from e + + async def to_seg(self) -> Seg: + if not self.binary_data: + await self.load_voice_binary() + return Seg(type="voice", data=base64.b64encode(self.binary_data).decode()) + + +class ForwardNodeComponent(BaseMessageComponentModel): + def __init__(self, forward_components: List["ForwardComponent"]): + self.forward_components = forward_components + assert isinstance(forward_components, list), "ForwardNodeComponent 的 forward_components 必须是列表类型" + assert all(isinstance(comp, ForwardComponent) for comp in forward_components), ( + "ForwardNodeComponent 的 forward_components 列表中必须全部是 ForwardComponent 类型" + ) + assert forward_components, "ForwardNodeComponent 的 forward_components 不能为空列表" + + async def to_seg(self) -> "Seg": + resp: List[Dict[str, Any]] = [] + for comp in self.forward_components: + data = await comp.to_seg() + sender_info = UserInfo(None, comp.user_id, comp.user_nickname, comp.user_cardname) + base_message_info = BaseMessageInfo(user_info=sender_info) + base_message = MessageBase(base_message_info, data) + resp.append(base_message.to_dict()) + return Seg(type="forward", data=resp) # type: ignore + + +class DictComponent: + def __init__(self, data: Dict[str, Any]): + self.data = data + assert isinstance(data, dict), "DictComponent 的 data 必须是字典类型" + + +StandardMessageComponents = Union[ + TextComponent, + ImageComponent, + EmojiComponent, + VoiceComponent, + ForwardNodeComponent, + DictComponent, +] + + +class ForwardComponent(BaseMessageComponentModel): + def __init__( + self, + user_nickname: str, + content: List[StandardMessageComponents], + user_id: Optional[str] = None, + user_cardname: Optional[str] = None, + ): + self.user_nickname: str = user_nickname + self.content: List[StandardMessageComponents] = content + self.user_id: Optional[str] = user_id + self.user_cardname: Optional[str] = user_cardname + assert self.content, "ForwardComponent 的 content 不能为空" + + async def to_seg(self) -> "Seg": + return Seg( + type="seglist", data=[await comp.to_seg() for comp in self.content if not isinstance(comp, DictComponent)] + ) + + +class MessageSequence: + def __init__(self, components: List[StandardMessageComponents]): + self.components: List[StandardMessageComponents] = components + + def to_dict(self) -> List[Dict[str, Any]]: + return [self._item_2_dict(comp) for comp in self.components] + + def _item_2_dict(self, item: StandardMessageComponents) -> Dict[str, Any]: + if isinstance(item, TextComponent): + return {"type": "text", "data": item.text} + elif isinstance(item, ImageComponent): + if not item.content: + raise RuntimeError("ImageComponent content 未初始化") + return {"type": "image", "data": item.content, "hash": item.binary_hash} + elif isinstance(item, EmojiComponent): + if not item.content: + raise RuntimeError("EmojiComponent content 未初始化") + return {"type": "emoji", "data": item.content, "hash": item.binary_hash} + elif isinstance(item, VoiceComponent): + if not item.content: + raise RuntimeError("VoiceComponent content 未初始化") + return {"type": "voice", "data": item.content, "hash": item.binary_hash} + elif isinstance(item, ForwardNodeComponent): + return { + "type": "forward", + "data": [ + { + "user_id": comp.user_id, + "user_nickname": comp.user_nickname, + "user_cardname": comp.user_cardname, + "content": [self._item_2_dict(c) for c in comp.content], + } + for comp in item.forward_components + ], + } + else: + logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent") + return {"type": "dict", "data": item.data} + + @classmethod + def from_dict(cls, data: List[Dict[str, Any]]) -> "MessageSequence": + components: List[StandardMessageComponents] = [] + components.extend(cls._dict_2_item(item) for item in data) + return cls(components=components) + + @classmethod + def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents: + item_type = item.get("type") + if item_type == "text": + return TextComponent(text=item["data"]) + elif item_type == "image": + return ImageComponent(binary_hash=item["hash"], content=item["data"]) + elif item_type == "emoji": + return EmojiComponent(binary_hash=item["hash"], content=item["data"]) + elif item_type == "voice": + return VoiceComponent(binary_hash=item["hash"], content=item["data"]) + elif item_type == "forward": + forward_components = [] + for fc in item["data"]: + content = [cls._dict_2_item(c) for c in fc["content"]] + forward_component = ForwardComponent( + user_nickname=fc["user_nickname"], + user_id=fc.get("user_id"), + user_cardname=fc.get("user_cardname"), + content=content, + ) + forward_components.append(forward_component) + return ForwardNodeComponent(forward_components=forward_components) + else: + logger.warning(f"Unofficial component type in dict: {item_type}, defaulting to DictComponent") + return DictComponent(data=item.get("data") or {}) diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py deleted file mode 100644 index dc4ad951..00000000 --- a/src/common/data_models/message_data_model.py +++ /dev/null @@ -1,210 +0,0 @@ -from typing import Optional, TYPE_CHECKING, List, Tuple, Union, Dict, Any -from dataclasses import dataclass, field -from enum import Enum - -from . import BaseDataModel - -if TYPE_CHECKING: - from .database_data_model import DatabaseMessages - - -@dataclass -class MessageAndActionModel(BaseDataModel): - chat_id: str = field(default_factory=str) - time: float = field(default_factory=float) - user_id: str = field(default_factory=str) - user_platform: str = field(default_factory=str) - user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None - processed_plain_text: Optional[str] = None - display_message: Optional[str] = None - chat_info_platform: str = field(default_factory=str) - is_action_record: bool = field(default=False) - action_name: Optional[str] = None - is_command: bool = field(default=False) - intercept_message_level: int = field(default=0) - - @classmethod - def from_DatabaseMessages(cls, message: "DatabaseMessages"): - return cls( - chat_id=message.chat_id, - time=message.time, - user_id=message.user_info.user_id, - user_platform=message.user_info.platform, - user_nickname=message.user_info.user_nickname, - user_cardname=message.user_info.user_cardname, - processed_plain_text=message.processed_plain_text, - display_message=message.display_message, - chat_info_platform=message.chat_info.platform, - is_command=message.is_command, - intercept_message_level=getattr(message, "intercept_message_level", 0), - ) - - -class ReplyContentType(Enum): - TEXT = "text" - IMAGE = "image" - EMOJI = "emoji" - COMMAND = "command" - VOICE = "voice" - FORWARD = "forward" - HYBRID = "hybrid" # 混合类型,包含多种内容 - - def __repr__(self) -> str: - return self.value - - -@dataclass -class ForwardNode(BaseDataModel): - user_id: Optional[str] = None - user_nickname: Optional[str] = None - content: Union[List["ReplyContent"], str] = field(default_factory=list) - - @classmethod - def construct_as_id_reference(cls, message_id: str) -> "ForwardNode": - return cls(user_id="", user_nickname="", content=message_id) - - @classmethod - def construct_as_created_node( - cls, user_id: str, user_nickname: str, content: List["ReplyContent"] - ) -> "ForwardNode": - return cls(user_id=user_id, user_nickname=user_nickname, content=content) - - -@dataclass -class ReplyContent(BaseDataModel): - content_type: ReplyContentType | str - content: Union[str, Dict, List[ForwardNode], List["ReplyContent"]] # 支持嵌套的 ReplyContent - - @classmethod - def construct_as_text(cls, text: str): - return cls(content_type=ReplyContentType.TEXT, content=text) - - @classmethod - def construct_as_image(cls, image_base64: str): - return cls(content_type=ReplyContentType.IMAGE, content=image_base64) - - @classmethod - def construct_as_voice(cls, voice_base64: str): - return cls(content_type=ReplyContentType.VOICE, content=voice_base64) - - @classmethod - def construct_as_emoji(cls, emoji_str: str): - return cls(content_type=ReplyContentType.EMOJI, content=emoji_str) - - @classmethod - def construct_as_command(cls, command_arg: Dict): - return cls(content_type=ReplyContentType.COMMAND, content=command_arg) - - @classmethod - def construct_as_hybrid(cls, hybrid_content: List[Tuple[ReplyContentType | str, str]]): - hybrid_content_list: List[ReplyContent] = [] - for content_type, content in hybrid_content: - assert content_type not in [ - ReplyContentType.HYBRID, - ReplyContentType.FORWARD, - ReplyContentType.VOICE, - ReplyContentType.COMMAND, - ], "混合内容的每个项不能是混合、转发、语音或命令类型" - assert isinstance(content, str), "混合内容的每个项必须是字符串" - hybrid_content_list.append(ReplyContent(content_type=content_type, content=content)) - return cls(content_type=ReplyContentType.HYBRID, content=hybrid_content_list) - - @classmethod - def construct_as_forward(cls, forward_nodes: List[ForwardNode]): - return cls(content_type=ReplyContentType.FORWARD, content=forward_nodes) - - def __post_init__(self): - if isinstance(self.content_type, ReplyContentType): - if self.content_type not in [ReplyContentType.HYBRID, ReplyContentType.FORWARD] and isinstance( - self.content, List - ): - raise ValueError( - f"非混合类型/转发类型的内容不能是列表,content_type: {self.content_type}, content: {self.content}" - ) - elif self.content_type in [ReplyContentType.HYBRID, ReplyContentType.FORWARD]: - if not isinstance(self.content, List): - raise ValueError( - f"混合类型/转发类型的内容必须是列表,content_type: {self.content_type}, content: {self.content}" - ) - - -@dataclass -class ReplySetModel(BaseDataModel): - """ - 回复集数据模型,用于多种回复类型的返回 - """ - - reply_data: List[ReplyContent] = field(default_factory=list) - - def __len__(self): - return len(self.reply_data) - - def add_text_content(self, text: str): - """ - 添加文本内容 - Args: - text: 文本内容 - """ - self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text)) - - def add_image_content(self, image_base64: str): - """ - 添加图片内容,base64编码的图片数据 - Args: - image_base64: base64编码的图片数据 - """ - self.reply_data.append(ReplyContent(content_type=ReplyContentType.IMAGE, content=image_base64)) - - def add_voice_content(self, voice_base64: str): - """ - 添加语音内容,base64编码的音频数据 - Args: - voice_base64: base64编码的音频数据 - """ - self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64)) - - def add_hybrid_content_by_raw(self, hybrid_content: List[Tuple[ReplyContentType | str, str]]): - """ - 添加混合型内容,可以包含text, image, emoji的任意组合 - Args: - hybrid_content: 元组 (类型, 消息内容) 构成的列表,如[(ReplyContentType.TEXT, "Hello"), (ReplyContentType.IMAGE, " Generator[Session, None, None]: session = SessionLocal() try: yield session - # 如果启用自动提交且没有异常,则提交事务 if auto_commit: session.commit() except Exception: @@ -132,59 +125,3 @@ def get_db() -> Generator[Session, None, None]: yield session finally: session.close() - - -class _AtomicContext: - def __init__(self) -> None: - self._session: Session | None = None - - def __enter__(self) -> Session: - self._session = SessionLocal() - self._session.begin() - return self._session - - def __exit__(self, exc_type, exc, tb) -> None: - if self._session is None: - return - try: - if exc_type is None: - self._session.commit() - else: - self._session.rollback() - finally: - self._session.close() - - -class DatabaseCompat: - """兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。""" - - def connect(self, reuse_if_open: bool = True) -> None: - # SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。 - _ = reuse_if_open - - def create_tables(self, models: list[type], safe: bool = True) -> None: - _ = safe - tables = [model.__table__ for model in models if hasattr(model, "__table__")] - if not tables: - return - from sqlmodel import SQLModel - - SQLModel.metadata.create_all(engine, tables=tables) - - def atomic(self) -> _AtomicContext: - return _AtomicContext() - - def execute_sql(self, sql: str): - with engine.connect() as conn: - result = conn.execute(text(sql)) - conn.commit() - return result - - def table_exists(self, model: type) -> bool: - if not hasattr(model, "__tablename__"): - return False - inspector = sqlalchemy_inspect(engine) - return inspector.has_table(model.__tablename__) - - -db = DatabaseCompat() diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 2e029b42..05bf5c21 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,6 +1,6 @@ from typing import Optional from sqlalchemy import Column, Float, Enum as SQLEnum -from sqlmodel import SQLModel, Field +from sqlmodel import SQLModel, Field, LargeBinary from enum import Enum from datetime import datetime @@ -45,8 +45,8 @@ class Messages(SQLModel, table=True): is_notify: bool = Field(default=False) # 是否为通知消息 # 消息内容 - raw_content: str # base64编码的原始消息内容 - processed_plain_text: str = Field(index=True) # 平面化处理后的纯文本消息 + raw_content: bytes = Field(sa_column=Column(LargeBinary)) # base64编码的原始消息内容 + processed_plain_text: str = Field() # 平面化处理后的纯文本消息 display_message: str # 显示的消息内容(被放入Prompt) # 其他配置 @@ -85,9 +85,9 @@ class Images(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 # 元信息 - image_hash: str = Field(default="", max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID + image_hash: str = Field(index=True, max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID description: str # 图片的描述 - full_path: str = Field(index=True, max_length=1024) # 文件的完整路径 (包括文件名) + full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名) image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI) """图片类型,例如 'emoji' 或 'image'""" emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔 @@ -116,7 +116,7 @@ class ActionRecord(SQLModel, table=True): session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id # 调用信息 - action_name: str = Field(max_length=255) # 动作名称 + action_name: str = Field(index=True, max_length=255) # 动作名称 action_reasoning: Optional[str] = Field(default=None) # 动作推理过程 action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储 @@ -153,7 +153,7 @@ class OnlineTime(SQLModel, table=True): timestamp: datetime = Field(default_factory=datetime.now, index=True) # 时间戳 duration_minutes: int = Field() # 时长,单位秒 start_timestamp: datetime = Field(default_factory=datetime.now) # 上线时间 - end_timestamp: datetime = Field(index=True) # 下线时间 + end_timestamp: datetime = Field() # 下线时间 class Expression(SQLModel, table=True): @@ -230,7 +230,68 @@ class ThinkingQuestion(SQLModel, table=True): context: Optional[str] = Field(default=None, nullable=True) # 上下文 found_answer: bool = Field(default=False) # 是否找到答案 answer: Optional[str] = Field(default=None, nullable=True) # 问题答案 - + thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储 created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间 updated_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后更新时间 + + +class BinaryData(SQLModel, table=True): + """存储二进制数据的模型""" + + __tablename__ = "binary_data" # type: ignore + + id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 + + data_hash: str = Field(index=True, max_length=255) # 数据哈希,使用sha256哈希值,亦作为数据唯一ID + full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名) + + +class PersonInfo(SQLModel, table=True): + """存储个人信息的模型""" + + __tablename__ = "person_info" # type: ignore + + id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 + + is_known: bool = Field(default=False) # 是否为已知人 + person_id: str = Field(unique=True, index=True, max_length=255) # 人员ID + person_name: Optional[str] = Field(default=None, max_length=255, nullable=True) # 人员名称 + name_reason: Optional[str] = Field(default=None, nullable=True) # 名称原因 + + # 身份元数据 + platform: str = Field(index=True, max_length=100) # 平台名称 + user_id: str = Field(index=True, max_length=255) # 用户ID + user_nickname: str = Field(index=True, max_length=255) # 用户昵称 + group_nickname: Optional[str] = Field( + default=None, nullable=True + ) # 群昵称 (JSON, [{"group_id": str, "group_nick_name": str}]) + + # 印象 + memory_points: Optional[str] = Field(default=None, nullable=True) # 记忆要点,JSON格式存储 + + # 认识次数和时间 + know_counts: int = Field(default=0) # 认识次数 + first_known_time: Optional[datetime] = Field(default=None, nullable=True) # 首次认识时间 + last_known_time: Optional[datetime] = Field(default=None, nullable=True) # 最后认识时间 + + +class ChatSession(SQLModel, table=True): + """存储聊天会话的模型""" + + __tablename__ = "chat_sessions" # type: ignore + + id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 + + session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID + + created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间 + last_active_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后活跃时间 + + # 身份元数据 + user_id: str = Field(index=True, max_length=255) # 用户ID + user_nickname: str = Field(index=True, max_length=255) # 用户昵称 + user_cardname: Optional[str] = Field(default=None, max_length=255, nullable=True) # 用户备注名 + group_id: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组id + group_name: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组名称 + platform: str = Field(index=True, max_length=100) # 用户平台 diff --git a/src/common/logger.py b/src/common/logger.py index 92306f6e..7acc2896 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -206,8 +206,6 @@ class WebSocketLogHandler(logging.Handler): # 如果是 JSON 格式(文件格式化器),解析它 message = formatted_msg try: - import json - log_dict = json.loads(formatted_msg) message = log_dict.get("event", formatted_msg) except (json.JSONDecodeError, ValueError): diff --git a/src/common/utils/utils_file.py b/src/common/utils/utils_file.py new file mode 100644 index 00000000..ea446760 --- /dev/null +++ b/src/common/utils/utils_file.py @@ -0,0 +1,55 @@ +from pathlib import Path +from sqlmodel import select + +import hashlib + +from src.common.logger import get_logger +from src.common.database.database_model import BinaryData +from src.common.database.database import get_db_session + +logger = get_logger("file_utils") + +class FileUtils: + @staticmethod + def save_bytes_to_file(file_path: Path, data: bytes): + """ + 将字节数据保存到指定文件路径 + + Args: + file_path (Path): 目标文件路径 + data (bytes): 要保存的字节数据 + Raises: + IOError: 如果写入文件时发生错误 + """ + try: + file_path = file_path.absolute().resolve() + with file_path.open("wb") as f: + f.write(data) + with get_db_session() as session: + # 计算数据哈希 + data_hash = hashlib.sha256(data).hexdigest() + # 创建 BinaryData 记录 + binary_data_record = BinaryData(data_hash=data_hash, full_path=str(file_path)) + session.add(binary_data_record) + session.commit() + except Exception as e: + logger.error(f"保存文件 {file_path} 失败: {e}") + raise e + + @staticmethod + def get_file_path_by_hash(data_hash: str) -> Path: + """ + 根据数据哈希获取文件路径 + + Args: + data_hash (str): 数据的哈希值 + + Returns: + Path: 对应的数据文件路径 + """ + with get_db_session() as session: + statement = select(BinaryData).filter_by(data_hash=data_hash).limit(1) + if binary_data := session.exec(statement).first(): + return Path(binary_data.full_path) + else: + raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录") \ No newline at end of file diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py new file mode 100644 index 00000000..36d4c826 --- /dev/null +++ b/src/common/utils/utils_message.py @@ -0,0 +1,15 @@ +import msgpack + +from src.common.data_models.message_component_model import MessageSequence + + +class MessageUtils: + @staticmethod + def from_db_record_msg_to_MaiSeq(raw_content: bytes) -> MessageSequence: + unpacked_data = msgpack.unpackb(raw_content) + return MessageSequence.from_dict(unpacked_data) + + @staticmethod + async def from_MaiSeq_to_db_record_msg(msg: MessageSequence) -> bytes: + dict_representation = msg.to_dict() + return msgpack.packb(dict_representation) # type: ignore diff --git a/src/config/config.py b/src/config/config.py index 26749315..5495d424 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -31,6 +31,7 @@ from .official_configs import ( DebugConfig, DreamConfig, WebUIConfig, + DatabaseConfig, ) from .model_configs import ModelInfo, ModelTaskConfig, APIProvider from .config_base import ConfigBase, Field, AttributeData @@ -125,6 +126,9 @@ class Config(ConfigBase): webui: WebUIConfig = Field(default_factory=WebUIConfig) """WebUI配置类""" + + database: DatabaseConfig = Field(default_factory=DatabaseConfig) + """数据库配置类""" class ModelConfig(ConfigBase): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index d4c2b0bd..deb3e518 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -653,4 +653,15 @@ class WebUIConfig(ConfigBase): """是否启用安全Cookie(仅通过HTTPS传输,默认false)""" enable_paragraph_content: bool = False - """是否在知识图谱中加载段落完整内容(需要加载embedding store,会占用额外内存)""" \ No newline at end of file + """是否在知识图谱中加载段落完整内容(需要加载embedding store,会占用额外内存)""" + +class DatabaseConfig(ConfigBase): + """数据库配置类""" + + save_binary_data: bool = False + """ + 是否将消息中的二进制数据保存为独立文件 + 若启用,消息中的语音等二进制数据将会保存为独立文件,并在消息中以特殊标记替代。启用会导致数据文件夹体积增大,但可以实现二次识别等功能。 + 若禁用,则消息中的二进制将会在识别后删除,并在消息中使用识别结果替代,无法二次识别 + 该配置项仅影响新存储的消息,已有消息不会受到影响 + """ \ No newline at end of file