更好更规范的类型注解;AGENTSMD试作

pull/1496/head
UnCLAS-Prommer 2026-02-13 16:24:03 +08:00
parent b80b5afe2a
commit c14736ffca
No known key found for this signature in database
5 changed files with 91 additions and 73 deletions

12
AGENTS.md 100644
View File

@ -0,0 +1,12 @@
# import 规范
在从外部库进行导入时候,请遵循以下顺序:
1. 对于标准库和第三方库的导入,请按照如下顺序:
- 需要使用`from ... import ...`语法的导入放在前面。
- 直接使用`import ...`语法的导入放在后面。
- 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
- 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
2. 对于本地模块的导入,请按照如下顺序:
- 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。
- 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。
3. 标准库和第三方库的导入应该放在本地模块导入的前面。
4. 各个导入块之间应该使用一个空行进行分隔。

View File

@ -1,6 +1,28 @@
from abc import ABC, abstractmethod
from typing import Self, TypeVar, Generic, TYPE_CHECKING
import copy import copy
if TYPE_CHECKING:
from sqlmodel import SQLModel
T = TypeVar("T", bound="SQLModel")
class BaseDataModel: class BaseDataModel:
def deepcopy(self): def deepcopy(self):
return copy.deepcopy(self) return copy.deepcopy(self)
class BaseDatabaseDataModel(ABC, Generic[T]):
@abstractmethod
@classmethod
def from_db_instance(cls, db_record: T) -> Self:
"""从数据库实例创建数据模型对象"""
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> T:
"""将数据模型对象转换为数据库实例"""
raise NotImplementedError

View File

@ -2,8 +2,9 @@ from typing import Optional
from src.common.database.database_model import ChatSession from src.common.database.database_model import ChatSession
from . import BaseDatabaseDataModel
class MaiChatSession: class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
self.session_id = session_id self.session_id = session_id
self.platform = platform self.platform = platform
@ -18,10 +19,18 @@ class MaiChatSession:
self.is_group_session = bool(self.group_id) self.is_group_session = bool(self.group_id)
@classmethod @classmethod
def from_db_instance(cls, session: ChatSession): def from_db_instance(cls, db_record: ChatSession):
return cls( return cls(
session_id=session.session_id, session_id=db_record.session_id,
platform=session.platform, platform=db_record.platform,
user_id=session.user_id, user_id=db_record.user_id,
group_id=session.group_id, group_id=db_record.group_id,
) )
def to_db_instance(self) -> ChatSession:
return ChatSession(
session_id=self.session_id,
platform=self.platform,
user_id=self.user_id,
group_id=self.group_id,
)

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from PIL import Image as PILImage from PIL import Image as PILImage
@ -12,6 +11,7 @@ import traceback
from src.common.database.database_model import Images, ImageType from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger from src.common.logger import get_logger
from . import BaseDatabaseDataModel
install(extra_lines=3) install(extra_lines=3)
@ -19,20 +19,11 @@ install(extra_lines=3)
logger = get_logger("emoji") logger = get_logger("emoji")
class BaseImageDataModel(ABC): class BaseImageDataModel(BaseDatabaseDataModel[Images]):
@classmethod
@abstractmethod
def from_db_instance(cls, image: "Images"):
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> "Images":
raise NotImplementedError
def read_image_bytes(self, path: Path) -> bytes: def read_image_bytes(self, path: Path) -> bytes:
""" """
同步读取图片文件的字节内容 同步读取图片文件的字节内容
Args: Args:
path (Path): 图片文件的完整路径 path (Path): 图片文件的完整路径
Returns: Returns:
@ -75,10 +66,6 @@ class BaseImageDataModel(ABC):
raise e raise e
class ImageDataModel(BaseImageDataModel):
pass
class MaiEmoji(BaseImageDataModel): class MaiEmoji(BaseImageDataModel):
def __init__(self, full_path: str | Path): def __init__(self, full_path: str | Path):
if not full_path: if not full_path:
@ -103,15 +90,15 @@ class MaiEmoji(BaseImageDataModel):
self._format: str = "" # 图片格式 self._format: str = "" # 图片格式
@classmethod @classmethod
def from_db_instance(cls, image: Images): def from_db_instance(cls, db_record: Images):
obj = cls(image.full_path) obj = cls(db_record.full_path)
obj.emoji_hash = image.image_hash obj.emoji_hash = db_record.image_hash
obj.description = image.description obj.description = db_record.description
if image.emotion: if db_record.emotion:
obj.emotion = image.emotion.split(",") obj.emotion = db_record.emotion.split(",")
obj.query_count = image.query_count obj.query_count = db_record.query_count
obj.last_used_time = image.last_used_time obj.last_used_time = db_record.last_used_time
obj.register_time = image.register_time obj.register_time = db_record.register_time
return obj return obj
def to_db_instance(self) -> Images: def to_db_instance(self) -> Images:
@ -130,7 +117,7 @@ class MaiEmoji(BaseImageDataModel):
async def calculate_hash_format(self) -> bool: async def calculate_hash_format(self) -> bool:
""" """
异步计算表情包的哈希值和格式 异步计算表情包的哈希值和格式
Returns: Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False return (bool): 如果成功计算哈希值和格式则返回True否则返回False
""" """

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from maim_message import MessageBase from maim_message import MessageBase
from typing import Optional from typing import Optional
@ -10,28 +9,7 @@ from src.common.database.database_model import Messages
from src.common.data_models.message_component_model import MessageSequence from src.common.data_models.message_component_model import MessageSequence
from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_message import MessageUtils
from . import BaseDatabaseDataModel
class BaseMAIMessageModel(ABC):
@classmethod
@abstractmethod
def from_db_instance(cls, message: "Messages"):
raise NotImplementedError
@abstractmethod
def to_db_instance(self) -> "Messages":
raise NotImplementedError
@abstractmethod
def from_maim_message(cls, message: MessageBase):
raise NotImplementedError
@abstractmethod
def to_maim_message(self) -> MessageBase:
raise NotImplementedError
@abstractmethod
def parse_message_segments(self):
raise NotImplementedError
@dataclass @dataclass
@ -54,7 +32,7 @@ class MessageInfo:
additional_config: dict = field(default_factory=dict) additional_config: dict = field(default_factory=dict)
class MaiMessage(BaseMAIMessageModel): class MaiMessage(BaseDatabaseDataModel[Messages]):
def __init__(self, message_id: str, timestamp: datetime): def __init__(self, message_id: str, timestamp: datetime):
self.message_id: str = message_id self.message_id: str = message_id
self.timestamp: datetime = timestamp # 时间戳 self.timestamp: datetime = timestamp # 时间戳
@ -78,31 +56,31 @@ class MaiMessage(BaseMAIMessageModel):
self.raw_message: MessageSequence self.raw_message: MessageSequence
@classmethod @classmethod
def from_db_instance(cls, message: "Messages") -> "MaiMessage": def from_db_instance(cls, db_record: "Messages") -> "MaiMessage":
obj = cls(message_id=message.message_id, timestamp=message.timestamp) obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp)
user_info = UserInfo(message.user_id, message.user_nickname, message.user_cardname) user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname)
if message.group_id and message.group_name: if db_record.group_id and db_record.group_name:
group_info = GroupInfo(message.group_id, message.group_name) group_info = GroupInfo(db_record.group_id, db_record.group_name)
else: else:
group_info = None group_info = None
obj.message_info = MessageInfo( obj.message_info = MessageInfo(
user_info=user_info, user_info=user_info,
group_info=group_info, group_info=group_info,
additional_config=json.loads(message.additional_config) if message.additional_config else {}, additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {},
) )
obj.is_mentioned = message.is_mentioned obj.is_mentioned = db_record.is_mentioned
obj.is_at = message.is_at obj.is_at = db_record.is_at
obj.is_emoji = message.is_emoji obj.is_emoji = db_record.is_emoji
obj.is_picture = message.is_picture obj.is_picture = db_record.is_picture
obj.is_command = message.is_command obj.is_command = db_record.is_command
obj.is_notify = message.is_notify obj.is_notify = db_record.is_notify
obj.reply_to = message.reply_to obj.reply_to = db_record.reply_to
obj.session_id = message.session_id obj.session_id = db_record.session_id
obj.processed_plain_text = message.processed_plain_text obj.processed_plain_text = db_record.processed_plain_text
obj.display_message = message.display_message obj.display_message = db_record.display_message
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(message.raw_content) obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content)
return obj return obj
def to_db_instance(self) -> Messages: def to_db_instance(self) -> Messages:
@ -131,3 +109,13 @@ class MaiMessage(BaseMAIMessageModel):
display_message=self.display_message, display_message=self.display_message,
additional_config=additional_config, additional_config=additional_config,
) )
@classmethod
def from_maim_message(cls, message: MessageBase) -> "MaiMessage":
raise NotImplementedError
def to_maim_message(self) -> MessageBase:
raise NotImplementedError
def parse_message_segments(self):
raise NotImplementedError