MaiBot/src/common/data_models/image_data_model.py

162 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from datetime import datetime
from pathlib import Path
from PIL import Image as PILImage
from rich.traceback import install
from typing import Optional, List
import asyncio
import hashlib
import io
import traceback
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from . import BaseDatabaseDataModel
install(extra_lines=3)
logger = get_logger("emoji")
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
if not full_path:
# 创建时候即检测文件路径合法性
raise ValueError("表情包路径不能为空")
if Path(full_path).is_dir() or not Path(full_path).exists():
raise FileNotFoundError(f"表情包路径无效: {full_path}")
resolved_path = Path(full_path).absolute().resolve()
self.full_path: Path = resolved_path
self.dir_path: Path = resolved_path.parent.resolve()
self.file_name: str = resolved_path.name
self.file_hash: str = None # type: ignore
self.image_bytes: Optional[bytes] = image_bytes
def read_image_bytes(self, path: Path) -> bytes:
"""
同步读取图片文件的字节内容
Args:
path (Path): 图片文件的完整路径
Returns:
return (bytes): 图片文件的字节内容
Raises:
FileNotFoundError: 如果文件不存在则抛出该异常
Exception: 其他读取文件时发生的异常
"""
try:
with open(path, "rb") as f:
return f.read()
except FileNotFoundError as e:
logger.error(f"[读取图片文件] 文件未找到: {path}")
raise e
except Exception as e:
logger.error(f"[读取图片文件] 读取文件时发生错误: {e}")
raise e
def get_image_format(self, image_bytes: bytes) -> str:
"""
获取图片的格式
Args:
image_bytes (bytes): 图片的字节内容
Returns:
return (str): 图片的格式(小写)
Raises:
ValueError: 如果无法识别图片格式
Exception: 其他读取图片格式时发生的异常
"""
try:
with PILImage.open(io.BytesIO(image_bytes)) as img:
if not img.format:
raise ValueError("无法识别图片格式")
return img.format.lower()
except Exception as e:
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {e}")
raise e
async def calculate_hash_format(self) -> bool:
"""
异步计算表情包的哈希值和格式
Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False
"""
try:
# 计算哈希值
logger.debug(f"[初始化] 计算 {self.file_name} 的哈希值...")
if not self.image_bytes:
logger.debug(f"[初始化] 正在读取文件: {self.full_path}")
image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path)
else:
image_bytes = self.image_bytes
self.file_hash = hashlib.sha256(image_bytes).hexdigest()
logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.file_hash}")
# 用PIL读取图片格式
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
self._format = await asyncio.to_thread(self.get_image_format, image_bytes)
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self._format}")
# 比对文件扩展名和实际格式
file_ext = self.file_name.split(".")[-1].lower()
if file_ext != self._format:
logger.warning(f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self._format}`")
# 重命名文件以匹配实际格式
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self._format])
new_full_path = self.dir_path / new_file_name
self.full_path.rename(new_full_path)
self.full_path = new_full_path
return True
except Exception as e:
logger.error(f"[初始化] 初始化图片时发生错误: {e}")
logger.error(traceback.format_exc())
self.is_deleted = True
return False
class MaiEmoji(BaseImageDataModel):
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
# self.embedding = []
self.description = ""
self.emotion: List[str] = []
self.query_count = 0
self.register_time: Optional[datetime] = None
self.last_used_time: Optional[datetime] = None
# 私有属性
self.is_deleted = False
self._format: str = "" # 图片格式
super().__init__(full_path, image_bytes)
@classmethod
def from_db_instance(cls, db_record: Images):
obj = cls(db_record.full_path)
obj.file_hash = db_record.image_hash
obj.description = db_record.description
if db_record.emotion:
obj.emotion = db_record.emotion.split(",")
obj.query_count = db_record.query_count
obj.last_used_time = db_record.last_used_time
obj.register_time = db_record.register_time
return obj
def to_db_instance(self) -> Images:
emotion_str = ",".join(self.emotion) if self.emotion else None
return Images(
image_hash=self.file_hash,
description=self.description,
full_path=str(self.full_path),
image_type=ImageType.EMOJI,
emotion=emotion_str,
query_count=self.query_count,
last_used_time=self.last_used_time,
register_time=self.register_time,
)