获取和注册一体化修正

pull/1496/head
UnCLAS-Prommer 2026-02-18 21:34:56 +08:00
parent ccd1be7bed
commit 537b24c24e
No known key found for this signature in database
3 changed files with 179 additions and 12 deletions

View File

@ -58,6 +58,7 @@ def _install_stub_modules(monkeypatch):
query_count: int = 0
register_time: object | None = None
image_format: str | None = None
image_bytes: bytes | None = None
@staticmethod
def from_db_instance(_record):
@ -128,10 +129,17 @@ def _install_stub_modules(monkeypatch):
def flush(self):
pass
def commit(self):
pass
def get_db_session():
return _DummySession()
def get_db_session_manual():
return _DummySession()
db_mod.get_db_session = get_db_session
db_mod.get_db_session_manual = get_db_session_manual
# src.common.utils.utils_image
image_utils_mod = _stub_module("src.common.utils.utils_image")
@ -236,6 +244,15 @@ def import_emoji_manager_new(monkeypatch):
module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
spec.loader.exec_module(module)
class _Select:
def filter_by(self, **kwargs):
return self
def limit(self, n):
return self
module.select = lambda _model: _Select()
return module
@ -763,6 +780,12 @@ def test_register_emoji_to_db_db_error(monkeypatch):
def flush(self):
pass
def exec(self, _statement):
return self
def first(self):
return None
def _get_db_session():
return _Session()
@ -821,6 +844,12 @@ def test_register_emoji_to_db_success(monkeypatch, tmp_path):
def flush(self):
pass
def exec(self, _statement):
return self
def first(self):
return None
def _get_db_session():
return _Session()
@ -858,6 +887,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
@ -948,7 +978,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
@ -1006,6 +1036,7 @@ def test_delete_emoji_success(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
@ -1056,6 +1087,7 @@ def test_delete_emoji_success(monkeypatch):
assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls))
assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls))
def test_delete_emoji_no_desc_deletes_record(monkeypatch):
emoji_manager_new = import_emoji_manager_new(monkeypatch)
logger = emoji_manager_new.logger
@ -1133,6 +1165,7 @@ def test_update_emoji_usage_success(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
@ -1143,6 +1176,7 @@ def test_update_emoji_usage_success(monkeypatch):
def __init__(self):
self.query_count = 2
self.last_used_time = None
record = _Record()
class _Result:
@ -1237,6 +1271,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
class _Select:
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
@ -2254,6 +2289,38 @@ async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_pa
async def calculate_hash_format(self):
return False
class _Session:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def exec(self, _statement):
return self
def first(self):
return None
class _Select:
def __init__(self) -> None:
pass
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _get_db_session_manual():
return _Session()
def _get_db_session():
return _Session()
monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual)
monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session)
monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select())
monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji)
result = await manager.register_emoji_by_filename(file_path)

View File

@ -5,6 +5,7 @@ from sqlmodel import select
from typing import Optional, Tuple, List
import asyncio
import hashlib
import heapq
import Levenshtein
import random
@ -13,7 +14,7 @@ import re
from src.common.logger import get_logger
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import global_config
@ -43,6 +44,10 @@ emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_c
class EmojiManager:
"""
表情包管理器
"""
def __init__(self):
_ensure_directories()
@ -51,6 +56,57 @@ class EmojiManager:
logger.info("启动表情包管理器")
async def get_emoji_description_by_bytes(self, emoji_bytes: bytes) -> Optional[Tuple[str, List[str]]]:
"""
根据表情包哈希获取表情包描述的封装方法
Returns:
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包则返回包含描述和情感标签的元组若没找到则尝试构建表情包描述并返回如果构建失败则返回 None
"""
# 先查找
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
if emoji := self.get_emoji_by_hash(emoji_hash):
return emoji.description, emoji.emotion or []
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if result := session.exec(statement).first():
return result.description, result.emotion.split(",") if result.emotion else []
except Exception as e:
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
# 找不到尝试构建
logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
full_path = EMOJI_DIR / f"{emoji_hash}.png"
try:
full_path.write_bytes(emoji_bytes)
new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
await new_emoji.calculate_hash_format()
except Exception as e:
logger.error(f"缓存表情包文件时出错: {e}")
raise e
success_desc, new_emoji = await self.build_emoji_description(new_emoji)
if not success_desc:
logger.error("构建表情包描述失败")
return None
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
if not success_emotion:
logger.error("构建表情包情感标签失败")
return None
# 缓存结果到数据库
with get_db_session() as session:
try:
image_record = new_emoji.to_db_instance()
image_record.is_registered = False
image_record.is_banned = False
image_record.register_time = datetime.now()
image_record.no_file_flag = True
session.add(image_record)
except Exception as e:
logger.error(f"缓存表情包描述时出错: {e}")
return new_emoji.description, new_emoji.emotion or []
def load_emojis_from_db(self) -> None:
"""
从数据库加载已注册的表情包
@ -114,14 +170,33 @@ class EmojiManager:
# 注册到数据库
try:
with get_db_session() as session:
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
record_id = image_record.id
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if existing_record := session.exec(statement).first():
if existing_record.no_file_flag:
existing_record.no_file_flag = False
existing_record.is_banned = False
existing_record.full_path = str(emoji.full_path)
existing_record.description = emoji.description
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
existing_record.query_count = emoji.query_count
existing_record.last_used_time = emoji.last_used_time
existing_record.register_time = emoji.register_time
session.add(existing_record)
logger.info(
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
)
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
return False
else:
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
record_id = image_record.id
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
except Exception as e:
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
return False
@ -401,7 +476,9 @@ class EmojiManager:
# 调用VLM生成描述
image_format = target_emoji.image_format
image_bytes = await asyncio.to_thread(target_emoji.read_image_bytes, target_emoji.full_path)
image_bytes = target_emoji.image_bytes or await asyncio.to_thread(
target_emoji.read_image_bytes, target_emoji.full_path
)
if image_format == "gif":
try:
@ -567,6 +644,29 @@ class EmojiManager:
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
return False
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
try:
with get_db_session_manual() as session:
statement = (
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
image_record.no_file_flag = False
image_record.is_banned = False
image_record.is_registered = True
image_record.full_path = str(target_emoji.full_path)
session.add(image_record)
session.commit()
logger.info(f"表情包注册成功Hash: {target_emoji.file_hash}")
return True
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
return False
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
# 1. 计算哈希值和格式
calc_success = await target_emoji.calculate_hash_format()
if not calc_success:

View File

@ -83,7 +83,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
async def calculate_hash_format(self) -> bool:
"""
异步计算表情包的哈希值和格式
异步计算表情包的哈希值和格式初始化后应该执行此方法来确保对象的哈希值和格式正确
Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False