From c14736ffca526ae2be1fd8c385b4544bd1705686 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Fri, 13 Feb 2026 16:24:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E5=A5=BD=E6=9B=B4=E8=A7=84=E8=8C=83?= =?UTF-8?q?=E7=9A=84=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3=EF=BC=9BAGENTSMD?= =?UTF-8?q?=E8=AF=95=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 12 ++++ src/common/data_models/__init__.py | 22 ++++++ .../data_models/chat_session_data_model.py | 21 ++++-- src/common/data_models/image_data_model.py | 39 ++++------- .../data_models/mai_message_data_model.py | 70 ++++++++----------- 5 files changed, 91 insertions(+), 73 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..226bff82 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,12 @@ +# import 规范 +在从外部库进行导入时候,请遵循以下顺序: +1. 对于标准库和第三方库的导入,请按照如下顺序: + - 需要使用`from ... import ...`语法的导入放在前面。 + - 直接使用`import ...`语法的导入放在后面。 + - 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。 + - 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。 +2. 对于本地模块的导入,请按照如下顺序: + - 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。 + - 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。 +3. 标准库和第三方库的导入应该放在本地模块导入的前面。 +4. 各个导入块之间应该使用一个空行进行分隔。 \ No newline at end of file diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index 8dbcdbad..a1032806 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -1,6 +1,28 @@ +from abc import ABC, abstractmethod + +from typing import Self, TypeVar, Generic, TYPE_CHECKING + import copy +if TYPE_CHECKING: + from sqlmodel import SQLModel + +T = TypeVar("T", bound="SQLModel") + class BaseDataModel: def deepcopy(self): return copy.deepcopy(self) + + +class BaseDatabaseDataModel(ABC, Generic[T]): + @abstractmethod + @classmethod + def from_db_instance(cls, db_record: T) -> Self: + """从数据库实例创建数据模型对象""" + raise NotImplementedError + + @abstractmethod + def to_db_instance(self) -> T: + """将数据模型对象转换为数据库实例""" + raise NotImplementedError diff --git a/src/common/data_models/chat_session_data_model.py b/src/common/data_models/chat_session_data_model.py index 9041c08f..fe9a04c7 100644 --- a/src/common/data_models/chat_session_data_model.py +++ b/src/common/data_models/chat_session_data_model.py @@ -2,8 +2,9 @@ from typing import Optional from src.common.database.database_model import ChatSession +from . import BaseDatabaseDataModel -class MaiChatSession: +class MaiChatSession(BaseDatabaseDataModel[ChatSession]): def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): self.session_id = session_id self.platform = platform @@ -18,10 +19,18 @@ class MaiChatSession: self.is_group_session = bool(self.group_id) @classmethod - def from_db_instance(cls, session: ChatSession): + def from_db_instance(cls, db_record: ChatSession): return cls( - session_id=session.session_id, - platform=session.platform, - user_id=session.user_id, - group_id=session.group_id, + session_id=db_record.session_id, + platform=db_record.platform, + user_id=db_record.user_id, + group_id=db_record.group_id, ) + + def to_db_instance(self) -> ChatSession: + return ChatSession( + session_id=self.session_id, + platform=self.platform, + user_id=self.user_id, + group_id=self.group_id, + ) \ No newline at end of file diff --git a/src/common/data_models/image_data_model.py b/src/common/data_models/image_data_model.py index e6062b66..03bdd83d 100644 --- a/src/common/data_models/image_data_model.py +++ b/src/common/data_models/image_data_model.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path from PIL import Image as PILImage @@ -12,6 +11,7 @@ import traceback from src.common.database.database_model import Images, ImageType from src.common.logger import get_logger +from . import BaseDatabaseDataModel install(extra_lines=3) @@ -19,20 +19,11 @@ install(extra_lines=3) logger = get_logger("emoji") -class BaseImageDataModel(ABC): - @classmethod - @abstractmethod - def from_db_instance(cls, image: "Images"): - raise NotImplementedError - - @abstractmethod - def to_db_instance(self) -> "Images": - raise NotImplementedError - +class BaseImageDataModel(BaseDatabaseDataModel[Images]): def read_image_bytes(self, path: Path) -> bytes: """ 同步读取图片文件的字节内容 - + Args: path (Path): 图片文件的完整路径 Returns: @@ -75,10 +66,6 @@ class BaseImageDataModel(ABC): raise e -class ImageDataModel(BaseImageDataModel): - pass - - class MaiEmoji(BaseImageDataModel): def __init__(self, full_path: str | Path): if not full_path: @@ -103,15 +90,15 @@ class MaiEmoji(BaseImageDataModel): self._format: str = "" # 图片格式 @classmethod - def from_db_instance(cls, image: Images): - obj = cls(image.full_path) - obj.emoji_hash = image.image_hash - obj.description = image.description - if image.emotion: - obj.emotion = image.emotion.split(",") - obj.query_count = image.query_count - obj.last_used_time = image.last_used_time - obj.register_time = image.register_time + def from_db_instance(cls, db_record: Images): + obj = cls(db_record.full_path) + obj.emoji_hash = db_record.image_hash + obj.description = db_record.description + if db_record.emotion: + obj.emotion = db_record.emotion.split(",") + obj.query_count = db_record.query_count + obj.last_used_time = db_record.last_used_time + obj.register_time = db_record.register_time return obj def to_db_instance(self) -> Images: @@ -130,7 +117,7 @@ class MaiEmoji(BaseImageDataModel): async def calculate_hash_format(self) -> bool: """ 异步计算表情包的哈希值和格式 - + Returns: return (bool): 如果成功计算哈希值和格式则返回True,否则返回False """ diff --git a/src/common/data_models/mai_message_data_model.py b/src/common/data_models/mai_message_data_model.py index d11d7a4f..0bd159c7 100644 --- a/src/common/data_models/mai_message_data_model.py +++ b/src/common/data_models/mai_message_data_model.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from dataclasses import dataclass, field from maim_message import MessageBase from typing import Optional @@ -10,28 +9,7 @@ from src.common.database.database_model import Messages from src.common.data_models.message_component_model import MessageSequence from src.common.utils.utils_message import MessageUtils - -class BaseMAIMessageModel(ABC): - @classmethod - @abstractmethod - def from_db_instance(cls, message: "Messages"): - raise NotImplementedError - - @abstractmethod - def to_db_instance(self) -> "Messages": - raise NotImplementedError - - @abstractmethod - def from_maim_message(cls, message: MessageBase): - raise NotImplementedError - - @abstractmethod - def to_maim_message(self) -> MessageBase: - raise NotImplementedError - - @abstractmethod - def parse_message_segments(self): - raise NotImplementedError +from . import BaseDatabaseDataModel @dataclass @@ -54,7 +32,7 @@ class MessageInfo: additional_config: dict = field(default_factory=dict) -class MaiMessage(BaseMAIMessageModel): +class MaiMessage(BaseDatabaseDataModel[Messages]): def __init__(self, message_id: str, timestamp: datetime): self.message_id: str = message_id self.timestamp: datetime = timestamp # 时间戳 @@ -78,31 +56,31 @@ class MaiMessage(BaseMAIMessageModel): self.raw_message: MessageSequence @classmethod - def from_db_instance(cls, message: "Messages") -> "MaiMessage": - obj = cls(message_id=message.message_id, timestamp=message.timestamp) + def from_db_instance(cls, db_record: "Messages") -> "MaiMessage": + obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp) - user_info = UserInfo(message.user_id, message.user_nickname, message.user_cardname) - if message.group_id and message.group_name: - group_info = GroupInfo(message.group_id, message.group_name) + user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname) + if db_record.group_id and db_record.group_name: + group_info = GroupInfo(db_record.group_id, db_record.group_name) else: group_info = None obj.message_info = MessageInfo( user_info=user_info, group_info=group_info, - additional_config=json.loads(message.additional_config) if message.additional_config else {}, + additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {}, ) - obj.is_mentioned = message.is_mentioned - obj.is_at = message.is_at - obj.is_emoji = message.is_emoji - obj.is_picture = message.is_picture - obj.is_command = message.is_command - obj.is_notify = message.is_notify - obj.reply_to = message.reply_to - obj.session_id = message.session_id - obj.processed_plain_text = message.processed_plain_text - obj.display_message = message.display_message - obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(message.raw_content) + obj.is_mentioned = db_record.is_mentioned + obj.is_at = db_record.is_at + obj.is_emoji = db_record.is_emoji + obj.is_picture = db_record.is_picture + obj.is_command = db_record.is_command + obj.is_notify = db_record.is_notify + obj.reply_to = db_record.reply_to + obj.session_id = db_record.session_id + obj.processed_plain_text = db_record.processed_plain_text + obj.display_message = db_record.display_message + obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content) return obj def to_db_instance(self) -> Messages: @@ -131,3 +109,13 @@ class MaiMessage(BaseMAIMessageModel): display_message=self.display_message, additional_config=additional_config, ) + + @classmethod + def from_maim_message(cls, message: MessageBase) -> "MaiMessage": + raise NotImplementedError + + def to_maim_message(self) -> MessageBase: + raise NotImplementedError + + def parse_message_segments(self): + raise NotImplementedError