更好更规范的类型注解;AGENTSMD试作

pull/1496/head
UnCLAS-Prommer 2026-02-13 16:24:03 +08:00
parent b80b5afe2a
commit c14736ffca
No known key found for this signature in database
5 changed files with 91 additions and 73 deletions

12
AGENTS.md 100644
View File

@ -0,0 +1,12 @@
# import 规范
在从外部库进行导入时候,请遵循以下顺序:
1. 对于标准库和第三方库的导入,请按照如下顺序:
- 需要使用`from ... import ...`语法的导入放在前面。
- 直接使用`import ...`语法的导入放在后面。
- 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
- 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
2. 对于本地模块的导入,请按照如下顺序:
- 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。
- 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。
3. 标准库和第三方库的导入应该放在本地模块导入的前面。
4. 各个导入块之间应该使用一个空行进行分隔。

View File

@ -1,6 +1,28 @@
from abc import ABC, abstractmethod
from typing import Self, TypeVar, Generic, TYPE_CHECKING
import copy
if TYPE_CHECKING:
from sqlmodel import SQLModel
T = TypeVar("T", bound="SQLModel")
class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
class BaseDatabaseDataModel(ABC, Generic[T]):
@abstractmethod
@classmethod
def from_db_instance(cls, db_record: T) -> Self:
"""从数据库实例创建数据模型对象"""
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> T:
"""将数据模型对象转换为数据库实例"""
raise NotImplementedError

View File

@ -2,8 +2,9 @@ from typing import Optional
from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel
class MaiChatSession:
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id
self.platform = platform
@ -18,10 +19,18 @@ class MaiChatSession:
self.is_group_session = bool(self.group_id)
@classmethod
def from_db_instance(cls, session: ChatSession):
def from_db_instance(cls, db_record: ChatSession):
return cls(
session_id=session.session_id,
platform=session.platform,
user_id=session.user_id,
group_id=session.group_id,
session_id=db_record.session_id,
platform=db_record.platform,
user_id=db_record.user_id,
group_id=db_record.group_id,
)
def to_db_instance(self) -> ChatSession:
return ChatSession(
session_id=self.session_id,
platform=self.platform,
user_id=self.user_id,
group_id=self.group_id,
)

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from PIL import Image as PILImage
@ -12,6 +11,7 @@ import traceback
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from . import BaseDatabaseDataModel
install(extra_lines=3)
@ -19,20 +19,11 @@ install(extra_lines=3)
logger = get_logger("emoji")
class BaseImageDataModel(ABC):
@classmethod
@abstractmethod
def from_db_instance(cls, image: "Images"):
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> "Images":
raise NotImplementedError
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
def read_image_bytes(self, path: Path) -> bytes:
"""
同步读取图片文件的字节内容
Args:
path (Path): 图片文件的完整路径
Returns:
@ -75,10 +66,6 @@ class BaseImageDataModel(ABC):
raise e
class ImageDataModel(BaseImageDataModel):
pass
class MaiEmoji(BaseImageDataModel):
def __init__(self, full_path: str | Path):
if not full_path:
@ -103,15 +90,15 @@ class MaiEmoji(BaseImageDataModel):
self._format: str = "" # 图片格式
@classmethod
def from_db_instance(cls, image: Images):
obj = cls(image.full_path)
obj.emoji_hash = image.image_hash
obj.description = image.description
if image.emotion:
obj.emotion = image.emotion.split(",")
obj.query_count = image.query_count
obj.last_used_time = image.last_used_time
obj.register_time = image.register_time
def from_db_instance(cls, db_record: Images):
obj = cls(db_record.full_path)
obj.emoji_hash = db_record.image_hash
obj.description = db_record.description
if db_record.emotion:
obj.emotion = db_record.emotion.split(",")
obj.query_count = db_record.query_count
obj.last_used_time = db_record.last_used_time
obj.register_time = db_record.register_time
return obj
def to_db_instance(self) -> Images:
@ -130,7 +117,7 @@ class MaiEmoji(BaseImageDataModel):
async def calculate_hash_format(self) -> bool:
"""
异步计算表情包的哈希值和格式
Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False
"""

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from maim_message import MessageBase
from typing import Optional
@ -10,28 +9,7 @@ from src.common.database.database_model import Messages
from src.common.data_models.message_component_model import MessageSequence
from src.common.utils.utils_message import MessageUtils
class BaseMAIMessageModel(ABC):
@classmethod
@abstractmethod
def from_db_instance(cls, message: "Messages"):
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> "Messages":
raise NotImplementedError
@abstractmethod
def from_maim_message(cls, message: MessageBase):
raise NotImplementedError
@abstractmethod
def to_maim_message(self) -> MessageBase:
raise NotImplementedError
@abstractmethod
def parse_message_segments(self):
raise NotImplementedError
from . import BaseDatabaseDataModel
@dataclass
@ -54,7 +32,7 @@ class MessageInfo:
additional_config: dict = field(default_factory=dict)
class MaiMessage(BaseMAIMessageModel):
class MaiMessage(BaseDatabaseDataModel[Messages]):
def __init__(self, message_id: str, timestamp: datetime):
self.message_id: str = message_id
self.timestamp: datetime = timestamp # 时间戳
@ -78,31 +56,31 @@ class MaiMessage(BaseMAIMessageModel):
self.raw_message: MessageSequence
@classmethod
def from_db_instance(cls, message: "Messages") -> "MaiMessage":
obj = cls(message_id=message.message_id, timestamp=message.timestamp)
def from_db_instance(cls, db_record: "Messages") -> "MaiMessage":
obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp)
user_info = UserInfo(message.user_id, message.user_nickname, message.user_cardname)
if message.group_id and message.group_name:
group_info = GroupInfo(message.group_id, message.group_name)
user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname)
if db_record.group_id and db_record.group_name:
group_info = GroupInfo(db_record.group_id, db_record.group_name)
else:
group_info = None
obj.message_info = MessageInfo(
user_info=user_info,
group_info=group_info,
additional_config=json.loads(message.additional_config) if message.additional_config else {},
additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {},
)
obj.is_mentioned = message.is_mentioned
obj.is_at = message.is_at
obj.is_emoji = message.is_emoji
obj.is_picture = message.is_picture
obj.is_command = message.is_command
obj.is_notify = message.is_notify
obj.reply_to = message.reply_to
obj.session_id = message.session_id
obj.processed_plain_text = message.processed_plain_text
obj.display_message = message.display_message
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(message.raw_content)
obj.is_mentioned = db_record.is_mentioned
obj.is_at = db_record.is_at
obj.is_emoji = db_record.is_emoji
obj.is_picture = db_record.is_picture
obj.is_command = db_record.is_command
obj.is_notify = db_record.is_notify
obj.reply_to = db_record.reply_to
obj.session_id = db_record.session_id
obj.processed_plain_text = db_record.processed_plain_text
obj.display_message = db_record.display_message
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content)
return obj
def to_db_instance(self) -> Messages:
@ -131,3 +109,13 @@ class MaiMessage(BaseMAIMessageModel):
display_message=self.display_message,
additional_config=additional_config,
)
@classmethod
def from_maim_message(cls, message: MessageBase) -> "MaiMessage":
raise NotImplementedError
def to_maim_message(self) -> MessageBase:
raise NotImplementedError
def parse_message_segments(self):
raise NotImplementedError