mirror of https://github.com/Mai-with-u/MaiBot.git
更好更规范的类型注解;AGENTSMD试作
parent
b80b5afe2a
commit
c14736ffca
|
|
@ -0,0 +1,12 @@
|
|||
# import 规范
|
||||
在从外部库进行导入时候,请遵循以下顺序:
|
||||
1. 对于标准库和第三方库的导入,请按照如下顺序:
|
||||
- 需要使用`from ... import ...`语法的导入放在前面。
|
||||
- 直接使用`import ...`语法的导入放在后面。
|
||||
- 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
|
||||
- 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
|
||||
2. 对于本地模块的导入,请按照如下顺序:
|
||||
- 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。
|
||||
- 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。
|
||||
3. 标准库和第三方库的导入应该放在本地模块导入的前面。
|
||||
4. 各个导入块之间应该使用一个空行进行分隔。
|
||||
|
|
@ -1,6 +1,28 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Self, TypeVar, Generic, TYPE_CHECKING
|
||||
|
||||
import copy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
T = TypeVar("T", bound="SQLModel")
|
||||
|
||||
|
||||
class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
class BaseDatabaseDataModel(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: T) -> Self:
|
||||
"""从数据库实例创建数据模型对象"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_db_instance(self) -> T:
|
||||
"""将数据模型对象转换为数据库实例"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ from typing import Optional
|
|||
|
||||
from src.common.database.database_model import ChatSession
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
class MaiChatSession:
|
||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
self.session_id = session_id
|
||||
self.platform = platform
|
||||
|
|
@ -18,10 +19,18 @@ class MaiChatSession:
|
|||
self.is_group_session = bool(self.group_id)
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, session: ChatSession):
|
||||
def from_db_instance(cls, db_record: ChatSession):
|
||||
return cls(
|
||||
session_id=session.session_id,
|
||||
platform=session.platform,
|
||||
user_id=session.user_id,
|
||||
group_id=session.group_id,
|
||||
session_id=db_record.session_id,
|
||||
platform=db_record.platform,
|
||||
user_id=db_record.user_id,
|
||||
group_id=db_record.group_id,
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> ChatSession:
|
||||
return ChatSession(
|
||||
session_id=self.session_id,
|
||||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from PIL import Image as PILImage
|
||||
|
|
@ -12,6 +11,7 @@ import traceback
|
|||
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
|
@ -19,20 +19,11 @@ install(extra_lines=3)
|
|||
logger = get_logger("emoji")
|
||||
|
||||
|
||||
class BaseImageDataModel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_db_instance(cls, image: "Images"):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_db_instance(self) -> "Images":
|
||||
raise NotImplementedError
|
||||
|
||||
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
def read_image_bytes(self, path: Path) -> bytes:
|
||||
"""
|
||||
同步读取图片文件的字节内容
|
||||
|
||||
|
||||
Args:
|
||||
path (Path): 图片文件的完整路径
|
||||
Returns:
|
||||
|
|
@ -75,10 +66,6 @@ class BaseImageDataModel(ABC):
|
|||
raise e
|
||||
|
||||
|
||||
class ImageDataModel(BaseImageDataModel):
|
||||
pass
|
||||
|
||||
|
||||
class MaiEmoji(BaseImageDataModel):
|
||||
def __init__(self, full_path: str | Path):
|
||||
if not full_path:
|
||||
|
|
@ -103,15 +90,15 @@ class MaiEmoji(BaseImageDataModel):
|
|||
self._format: str = "" # 图片格式
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, image: Images):
|
||||
obj = cls(image.full_path)
|
||||
obj.emoji_hash = image.image_hash
|
||||
obj.description = image.description
|
||||
if image.emotion:
|
||||
obj.emotion = image.emotion.split(",")
|
||||
obj.query_count = image.query_count
|
||||
obj.last_used_time = image.last_used_time
|
||||
obj.register_time = image.register_time
|
||||
def from_db_instance(cls, db_record: Images):
|
||||
obj = cls(db_record.full_path)
|
||||
obj.emoji_hash = db_record.image_hash
|
||||
obj.description = db_record.description
|
||||
if db_record.emotion:
|
||||
obj.emotion = db_record.emotion.split(",")
|
||||
obj.query_count = db_record.query_count
|
||||
obj.last_used_time = db_record.last_used_time
|
||||
obj.register_time = db_record.register_time
|
||||
return obj
|
||||
|
||||
def to_db_instance(self) -> Images:
|
||||
|
|
@ -130,7 +117,7 @@ class MaiEmoji(BaseImageDataModel):
|
|||
async def calculate_hash_format(self) -> bool:
|
||||
"""
|
||||
异步计算表情包的哈希值和格式
|
||||
|
||||
|
||||
Returns:
|
||||
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import MessageBase
|
||||
from typing import Optional
|
||||
|
|
@ -10,28 +9,7 @@ from src.common.database.database_model import Messages
|
|||
from src.common.data_models.message_component_model import MessageSequence
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
|
||||
class BaseMAIMessageModel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_db_instance(cls, message: "Messages"):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_db_instance(self) -> "Messages":
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def from_maim_message(cls, message: MessageBase):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_maim_message(self) -> MessageBase:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_message_segments(self):
|
||||
raise NotImplementedError
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -54,7 +32,7 @@ class MessageInfo:
|
|||
additional_config: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class MaiMessage(BaseMAIMessageModel):
|
||||
class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
def __init__(self, message_id: str, timestamp: datetime):
|
||||
self.message_id: str = message_id
|
||||
self.timestamp: datetime = timestamp # 时间戳
|
||||
|
|
@ -78,31 +56,31 @@ class MaiMessage(BaseMAIMessageModel):
|
|||
self.raw_message: MessageSequence
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, message: "Messages") -> "MaiMessage":
|
||||
obj = cls(message_id=message.message_id, timestamp=message.timestamp)
|
||||
def from_db_instance(cls, db_record: "Messages") -> "MaiMessage":
|
||||
obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp)
|
||||
|
||||
user_info = UserInfo(message.user_id, message.user_nickname, message.user_cardname)
|
||||
if message.group_id and message.group_name:
|
||||
group_info = GroupInfo(message.group_id, message.group_name)
|
||||
user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname)
|
||||
if db_record.group_id and db_record.group_name:
|
||||
group_info = GroupInfo(db_record.group_id, db_record.group_name)
|
||||
else:
|
||||
group_info = None
|
||||
obj.message_info = MessageInfo(
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
additional_config=json.loads(message.additional_config) if message.additional_config else {},
|
||||
additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {},
|
||||
)
|
||||
|
||||
obj.is_mentioned = message.is_mentioned
|
||||
obj.is_at = message.is_at
|
||||
obj.is_emoji = message.is_emoji
|
||||
obj.is_picture = message.is_picture
|
||||
obj.is_command = message.is_command
|
||||
obj.is_notify = message.is_notify
|
||||
obj.reply_to = message.reply_to
|
||||
obj.session_id = message.session_id
|
||||
obj.processed_plain_text = message.processed_plain_text
|
||||
obj.display_message = message.display_message
|
||||
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(message.raw_content)
|
||||
obj.is_mentioned = db_record.is_mentioned
|
||||
obj.is_at = db_record.is_at
|
||||
obj.is_emoji = db_record.is_emoji
|
||||
obj.is_picture = db_record.is_picture
|
||||
obj.is_command = db_record.is_command
|
||||
obj.is_notify = db_record.is_notify
|
||||
obj.reply_to = db_record.reply_to
|
||||
obj.session_id = db_record.session_id
|
||||
obj.processed_plain_text = db_record.processed_plain_text
|
||||
obj.display_message = db_record.display_message
|
||||
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content)
|
||||
return obj
|
||||
|
||||
def to_db_instance(self) -> Messages:
|
||||
|
|
@ -131,3 +109,13 @@ class MaiMessage(BaseMAIMessageModel):
|
|||
display_message=self.display_message,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_maim_message(cls, message: MessageBase) -> "MaiMessage":
|
||||
raise NotImplementedError
|
||||
|
||||
def to_maim_message(self) -> MessageBase:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_message_segments(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
Loading…
Reference in New Issue