mirror of https://github.com/Mai-with-u/MaiBot.git
方法名调整;确保公共属性被定义
parent
275608abea
commit
75e154741d
|
|
@ -58,7 +58,7 @@ def _install_stub_modules(monkeypatch):
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
query_count: int = 0
|
query_count: int = 0
|
||||||
register_time: object | None = None
|
register_time: object | None = None
|
||||||
_format: str | None = None
|
image_format: str | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db_instance(_record):
|
def from_db_instance(_record):
|
||||||
|
|
@ -1431,7 +1431,7 @@ async def test_build_emoji_description_calls_hash_and_sets_description(monkeypat
|
||||||
|
|
||||||
emoji = emoji_manager_new.MaiEmoji()
|
emoji = emoji_manager_new.MaiEmoji()
|
||||||
emoji.file_hash = None
|
emoji.file_hash = None
|
||||||
emoji._format = "png"
|
emoji.image_format = "png"
|
||||||
emoji.full_path = Path("/tmp/a.png")
|
emoji.full_path = Path("/tmp/a.png")
|
||||||
|
|
||||||
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
||||||
|
|
@ -1459,7 +1459,7 @@ async def test_build_emoji_description_gif_conversion_error(monkeypatch):
|
||||||
|
|
||||||
emoji = emoji_manager_new.MaiEmoji()
|
emoji = emoji_manager_new.MaiEmoji()
|
||||||
emoji.file_hash = "hash"
|
emoji.file_hash = "hash"
|
||||||
emoji._format = "gif"
|
emoji.image_format = "gif"
|
||||||
emoji.full_path = Path("/tmp/a.gif")
|
emoji.full_path = Path("/tmp/a.gif")
|
||||||
|
|
||||||
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
||||||
|
|
@ -1497,7 +1497,7 @@ async def test_build_emoji_description_content_filtration_reject(monkeypatch):
|
||||||
|
|
||||||
emoji = emoji_manager_new.MaiEmoji()
|
emoji = emoji_manager_new.MaiEmoji()
|
||||||
emoji.file_hash = "hash"
|
emoji.file_hash = "hash"
|
||||||
emoji._format = "png"
|
emoji.image_format = "png"
|
||||||
emoji.full_path = Path("/tmp/a.png")
|
emoji.full_path = Path("/tmp/a.png")
|
||||||
|
|
||||||
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
||||||
|
|
@ -1532,7 +1532,7 @@ async def test_build_emoji_description_content_filtration_pass(monkeypatch):
|
||||||
|
|
||||||
emoji = emoji_manager_new.MaiEmoji()
|
emoji = emoji_manager_new.MaiEmoji()
|
||||||
emoji.file_hash = "hash"
|
emoji.file_hash = "hash"
|
||||||
emoji._format = "png"
|
emoji.image_format = "png"
|
||||||
emoji.full_path = Path("/tmp/a.png")
|
emoji.full_path = Path("/tmp/a.png")
|
||||||
|
|
||||||
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji)
|
||||||
|
|
@ -1561,7 +1561,7 @@ async def test_build_emoji_description_vlm_exception_propagates(monkeypatch):
|
||||||
|
|
||||||
emoji = emoji_manager_new.MaiEmoji()
|
emoji = emoji_manager_new.MaiEmoji()
|
||||||
emoji.file_hash = "hash"
|
emoji.file_hash = "hash"
|
||||||
emoji._format = "png"
|
emoji.image_format = "png"
|
||||||
emoji.full_path = Path("/tmp/a.png")
|
emoji.full_path = Path("/tmp/a.png")
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
|
|
|
||||||
|
|
@ -360,7 +360,7 @@ class EmojiManager:
|
||||||
await target_emoji.calculate_hash_format()
|
await target_emoji.calculate_hash_format()
|
||||||
|
|
||||||
# 调用VLM生成描述
|
# 调用VLM生成描述
|
||||||
image_format = target_emoji._format
|
image_format = target_emoji.image_format
|
||||||
image_bytes = target_emoji.read_image_bytes(target_emoji.full_path)
|
image_bytes = target_emoji.read_image_bytes(target_emoji.full_path)
|
||||||
|
|
||||||
if image_format == "gif":
|
if image_format == "gif":
|
||||||
|
|
@ -415,7 +415,7 @@ class EmojiManager:
|
||||||
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
||||||
# 调用LLM生成情感标签
|
# 调用LLM生成情感标签
|
||||||
emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
||||||
emotion_prompt, temperature=0.7, max_tokens=200
|
emotion_prompt, temperature=0.3, max_tokens=200
|
||||||
)
|
)
|
||||||
|
|
||||||
# 解析情感标签结果
|
# 解析情感标签结果
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,9 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||||
self.file_hash: str = None # type: ignore
|
self.file_hash: str = None # type: ignore
|
||||||
|
|
||||||
self.image_bytes: Optional[bytes] = image_bytes
|
self.image_bytes: Optional[bytes] = image_bytes
|
||||||
|
|
||||||
|
self.image_format: str = "" # 图片格式
|
||||||
|
self.is_deleted: bool = False # 是否已被标记为删除
|
||||||
|
|
||||||
def read_image_bytes(self, path: Path) -> bytes:
|
def read_image_bytes(self, path: Path) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
|
@ -100,15 +103,15 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||||
|
|
||||||
# 用PIL读取图片格式
|
# 用PIL读取图片格式
|
||||||
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
|
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
|
||||||
self._format = await asyncio.to_thread(self.get_image_format, image_bytes)
|
self.image_format = await asyncio.to_thread(self.get_image_format, image_bytes)
|
||||||
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self._format}")
|
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self.image_format}")
|
||||||
|
|
||||||
# 比对文件扩展名和实际格式
|
# 比对文件扩展名和实际格式
|
||||||
file_ext = self.file_name.split(".")[-1].lower()
|
file_ext = self.file_name.split(".")[-1].lower()
|
||||||
if file_ext != self._format:
|
if file_ext != self.image_format:
|
||||||
logger.warning(f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self._format}`")
|
logger.warning(f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`")
|
||||||
# 重命名文件以匹配实际格式
|
# 重命名文件以匹配实际格式
|
||||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self._format])
|
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||||
new_full_path = self.dir_path / new_file_name
|
new_full_path = self.dir_path / new_file_name
|
||||||
self.full_path.rename(new_full_path)
|
self.full_path.rename(new_full_path)
|
||||||
self.full_path = new_full_path
|
self.full_path = new_full_path
|
||||||
|
|
@ -124,15 +127,11 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||||
class MaiEmoji(BaseImageDataModel):
|
class MaiEmoji(BaseImageDataModel):
|
||||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||||
# self.embedding = []
|
# self.embedding = []
|
||||||
self.description = ""
|
self.description: str = ""
|
||||||
self.emotion: List[str] = []
|
self.emotion: List[str] = []
|
||||||
self.query_count = 0
|
self.query_count = 0
|
||||||
self.register_time: Optional[datetime] = None
|
self.register_time: Optional[datetime] = None
|
||||||
self.last_used_time: Optional[datetime] = None
|
self.last_used_time: Optional[datetime] = None
|
||||||
|
|
||||||
# 私有属性
|
|
||||||
self.is_deleted = False
|
|
||||||
self._format: str = "" # 图片格式
|
|
||||||
super().__init__(full_path, image_bytes)
|
super().__init__(full_path, image_bytes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -159,3 +158,28 @@ class MaiEmoji(BaseImageDataModel):
|
||||||
last_used_time=self.last_used_time,
|
last_used_time=self.last_used_time,
|
||||||
register_time=self.register_time,
|
register_time=self.register_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MaiImage(BaseImageDataModel):
|
||||||
|
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||||
|
self.description: str = ""
|
||||||
|
self.vlm_processed: bool = False
|
||||||
|
super().__init__(full_path, image_bytes)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db_instance(cls, db_record: Images):
|
||||||
|
obj = cls(db_record.full_path)
|
||||||
|
obj.file_hash = db_record.image_hash
|
||||||
|
obj.full_path = Path(db_record.full_path)
|
||||||
|
obj.description = db_record.description
|
||||||
|
obj.vlm_processed = db_record.vlm_processed
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def to_db_instance(self) -> Images:
|
||||||
|
return Images(
|
||||||
|
image_hash=self.file_hash,
|
||||||
|
description=self.description,
|
||||||
|
full_path=str(self.full_path),
|
||||||
|
image_type=ImageType.IMAGE,
|
||||||
|
vlm_processed=self.vlm_processed,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ logger = get_logger("file_utils")
|
||||||
|
|
||||||
class FileUtils:
|
class FileUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_bytes_to_file(file_path: Path, data: bytes):
|
def save_binary_to_file(file_path: Path, data: bytes):
|
||||||
"""
|
"""
|
||||||
将字节数据保存到指定文件路径
|
将字节数据保存到指定文件路径
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue