diff --git a/pytests/image_sys_test/emoji_manager_test.py b/pytests/image_sys_test/emoji_manager_test.py index 75d2c05f..8a788938 100644 --- a/pytests/image_sys_test/emoji_manager_test.py +++ b/pytests/image_sys_test/emoji_manager_test.py @@ -58,6 +58,7 @@ def _install_stub_modules(monkeypatch): query_count: int = 0 register_time: object | None = None image_format: str | None = None + image_bytes: bytes | None = None @staticmethod def from_db_instance(_record): @@ -128,10 +129,17 @@ def _install_stub_modules(monkeypatch): def flush(self): pass + def commit(self): + pass + def get_db_session(): return _DummySession() + def get_db_session_manual(): + return _DummySession() + db_mod.get_db_session = get_db_session + db_mod.get_db_session_manual = get_db_session_manual # src.common.utils.utils_image image_utils_mod = _stub_module("src.common.utils.utils_image") @@ -236,6 +244,15 @@ def import_emoji_manager_new(monkeypatch): module = importlib.util.module_from_spec(spec) monkeypatch.setitem(sys.modules, "emoji_manager_new", module) spec.loader.exec_module(module) + + class _Select: + def filter_by(self, **kwargs): + return self + + def limit(self, n): + return self + + module.select = lambda _model: _Select() return module @@ -763,6 +780,12 @@ def test_register_emoji_to_db_db_error(monkeypatch): def flush(self): pass + def exec(self, _statement): + return self + + def first(self): + return None + def _get_db_session(): return _Session() @@ -821,6 +844,12 @@ def test_register_emoji_to_db_success(monkeypatch, tmp_path): def flush(self): pass + def exec(self, _statement): + return self + + def first(self): + return None + def _get_db_session(): return _Session() @@ -858,6 +887,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): return self @@ -948,7 +978,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch): class _Select: def filter_by(self, **_kwargs): 
return self - + def limit(self, _num): return self @@ -1006,6 +1036,7 @@ def test_delete_emoji_success(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): return self @@ -1056,6 +1087,7 @@ def test_delete_emoji_success(monkeypatch): assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls)) assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls)) + def test_delete_emoji_no_desc_deletes_record(monkeypatch): emoji_manager_new = import_emoji_manager_new(monkeypatch) logger = emoji_manager_new.logger @@ -1133,6 +1165,7 @@ def test_update_emoji_usage_success(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): return self @@ -1143,6 +1176,7 @@ def test_update_emoji_usage_success(monkeypatch): def __init__(self): self.query_count = 2 self.last_used_time = None + record = _Record() class _Result: @@ -1237,6 +1271,7 @@ def test_update_emoji_usage_execute_error(monkeypatch): class _Select: def filter_by(self, **_kwargs): return self + def limit(self, _num): return self @@ -2254,6 +2289,38 @@ async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_pa async def calculate_hash_format(self): return False + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def exec(self, _statement): + return self + + def first(self): + return None + + class _Select: + def __init__(self) -> None: + pass + + def filter_by(self, **_kwargs): + return self + + def limit(self, _num): + return self + + def _get_db_session_manual(): + return _Session() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select()) monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) result = 
await manager.register_emoji_by_filename(file_path)
diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 23118e40..996244e1 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -5,6 +5,7 @@ from sqlmodel import select
 from typing import Optional, Tuple, List

 import asyncio
+import hashlib
 import heapq
 import Levenshtein
 import random
@@ -13,7 +14,7 @@ import re
 from src.common.logger import get_logger
 from src.common.data_models.image_data_model import MaiEmoji
 from src.common.database.database_model import Images, ImageType
-from src.common.database.database import get_db_session
+from src.common.database.database import get_db_session, get_db_session_manual
 from src.common.utils.utils_image import ImageUtils
 from src.prompt.prompt_manager import prompt_manager
 from src.config.config import global_config
@@ -43,6 +44,10 @@ emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_c


 class EmojiManager:
+    """
+    表情包管理器
+    """
+
     def __init__(self):
         _ensure_directories()

@@ -51,6 +56,80 @@ class EmojiManager:

         logger.info("启动表情包管理器")

+    async def get_emoji_description_by_bytes(self, emoji_bytes: bytes) -> Optional[Tuple[str, List[str]]]:
+        """
+        Return the description and emotion tags for an emoji given its raw bytes.
+
+        The SHA-256 digest of the bytes is looked up through ``self.get_emoji_by_hash`` and then
+        in the ``Images`` table. On a miss the bytes are written to ``EMOJI_DIR``, a description
+        and emotion tags are built, and the result is stored as an unregistered ``Images`` record.
+
+        Args:
+            emoji_bytes (bytes): Raw image content of the emoji.
+
+        Returns:
+            return (Optional[Tuple[str, List[str]]]): ``(description, emotion tags)``; None when the
+                hash/format calculation or building the description or emotion tags fails.
+
+        Raises:
+            Exception: Re-raised unchanged when writing the cache file or creating the emoji
+                object fails.
+        """
+        # NOTE(review): lookups key on the SHA-256 of the bytes; this presumes ``file_hash`` /
+        # ``image_hash`` are produced with the same digest — confirm against MaiEmoji.
+        emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
+        if emoji := self.get_emoji_by_hash(emoji_hash):
+            return emoji.description, emoji.emotion or []
+        try:
+            with get_db_session() as session:
+                statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
+                if result := session.exec(statement).first():
+                    return result.description, result.emotion.split(",") if result.emotion else []
+        except Exception as e:
+            # A failed lookup is not fatal: fall through and build the description instead.
+            logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
+
+        # Nothing cached: persist the bytes and build the description from scratch.
+        logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
+        # NOTE(review): the ".png" suffix is applied whatever the actual image format is — confirm.
+        full_path = EMOJI_DIR / f"{emoji_hash}.png"
+        try:
+            full_path.write_bytes(emoji_bytes)
+            new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
+            # calculate_hash_format() signals failure through its return value; stop here rather
+            # than describing and storing an emoji whose hash and format were never computed.
+            if not await new_emoji.calculate_hash_format():
+                logger.error("计算表情包哈希值和格式失败")
+                return None
+        except Exception as e:
+            logger.error(f"缓存表情包文件时出错: {e}")
+            raise
+        success_desc, new_emoji = await self.build_emoji_description(new_emoji)
+        if not success_desc:
+            logger.error("构建表情包描述失败")
+            return None
+        success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
+        if not success_emotion:
+            logger.error("构建表情包情感标签失败")
+            return None
+
+        # Cache the result in the database. The try wraps the whole ``with`` block so that an
+        # error raised while the session context exits (not only while the record is built) is
+        # logged instead of discarding the description that was just generated.
+        try:
+            with get_db_session() as session:
+                image_record = new_emoji.to_db_instance()
+                image_record.is_registered = False
+                image_record.is_banned = False
+                image_record.register_time = datetime.now()
+                # NOTE(review): flagged as file-less although the bytes were written to EMOJI_DIR
+                # above — confirm this is intended.
+                image_record.no_file_flag = True
+                session.add(image_record)
+        except Exception as e:
+            logger.error(f"缓存表情包描述时出错: {e}")
+        return new_emoji.description, new_emoji.emotion or []
+
     def load_emojis_from_db(self) -> None:
         """
         从数据库加载已注册的表情包
@@ -114,14 +193,33 @@ class EmojiManager:
         # 注册到数据库
         try:
             with get_db_session() as session:
-                image_record = emoji.to_db_instance()
-                image_record.is_registered = True
-                image_record.is_banned = False
-                image_record.register_time = datetime.now()
-                session.add(image_record)
-                session.flush()
-                record_id = image_record.id
-                logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
+                statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
+                if existing_record := session.exec(statement).first():
+                    if existing_record.no_file_flag:
+                        existing_record.no_file_flag = False
+                        existing_record.is_banned = False
+                        existing_record.full_path = str(emoji.full_path)
+                        existing_record.description = emoji.description
+                        existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
+                        existing_record.query_count = emoji.query_count
+                        existing_record.last_used_time = emoji.last_used_time
+                        existing_record.register_time =
emoji.register_time + session.add(existing_record) + logger.info( + f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}" + ) + else: + logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}") + return False + else: + image_record = emoji.to_db_instance() + image_record.is_registered = True + image_record.is_banned = False + image_record.register_time = datetime.now() + session.add(image_record) + session.flush() + record_id = image_record.id + logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}") except Exception as e: logger.error(f"[注册表情包] 注册到数据库时出错: {e}") return False @@ -401,7 +476,9 @@ class EmojiManager: # 调用VLM生成描述 image_format = target_emoji.image_format - image_bytes = await asyncio.to_thread(target_emoji.read_image_bytes, target_emoji.full_path) + image_bytes = target_emoji.image_bytes or await asyncio.to_thread( + target_emoji.read_image_bytes, target_emoji.full_path + ) if image_format == "gif": try: @@ -567,6 +644,29 @@ class EmojiManager: logger.error(f"[注册表情包] 创建表情包对象时出错: {e}") return False + # 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建 + try: + with get_db_session_manual() as session: + statement = ( + select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1) + ) + if image_record := session.exec(statement).first(): + if image_record.no_file_flag: + image_record.no_file_flag = False + image_record.is_banned = False + image_record.is_registered = True + image_record.full_path = str(target_emoji.full_path) + session.add(image_record) + session.commit() + logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}") + return True + else: + logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}") + return False + except Exception as e: + logger.error(f"[注册表情包] 查询数据库时出错: {e}") + return False + # 1. 
计算哈希值和格式 calc_success = await target_emoji.calculate_hash_format() if not calc_success: diff --git a/src/common/data_models/image_data_model.py b/src/common/data_models/image_data_model.py index 60fb1f31..529a6423 100644 --- a/src/common/data_models/image_data_model.py +++ b/src/common/data_models/image_data_model.py @@ -83,7 +83,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]): async def calculate_hash_format(self) -> bool: """ - 异步计算表情包的哈希值和格式 + 异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确 Returns: return (bool): 如果成功计算哈希值和格式则返回True,否则返回False