获取和注册一体化修正

pull/1496/head
UnCLAS-Prommer 2026-02-18 21:34:56 +08:00
parent ccd1be7bed
commit 537b24c24e
No known key found for this signature in database
3 changed files with 179 additions and 12 deletions

View File

@ -58,6 +58,7 @@ def _install_stub_modules(monkeypatch):
query_count: int = 0 query_count: int = 0
register_time: object | None = None register_time: object | None = None
image_format: str | None = None image_format: str | None = None
image_bytes: bytes | None = None
@staticmethod @staticmethod
def from_db_instance(_record): def from_db_instance(_record):
@ -128,10 +129,17 @@ def _install_stub_modules(monkeypatch):
def flush(self): def flush(self):
pass pass
def commit(self):
pass
def get_db_session(): def get_db_session():
return _DummySession() return _DummySession()
def get_db_session_manual():
return _DummySession()
db_mod.get_db_session = get_db_session db_mod.get_db_session = get_db_session
db_mod.get_db_session_manual = get_db_session_manual
# src.common.utils.utils_image # src.common.utils.utils_image
image_utils_mod = _stub_module("src.common.utils.utils_image") image_utils_mod = _stub_module("src.common.utils.utils_image")
@ -236,6 +244,15 @@ def import_emoji_manager_new(monkeypatch):
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "emoji_manager_new", module) monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
spec.loader.exec_module(module) spec.loader.exec_module(module)
class _Select:
def filter_by(self, **kwargs):
return self
def limit(self, n):
return self
module.select = lambda _model: _Select()
return module return module
@ -763,6 +780,12 @@ def test_register_emoji_to_db_db_error(monkeypatch):
def flush(self): def flush(self):
pass pass
def exec(self, _statement):
return self
def first(self):
return None
def _get_db_session(): def _get_db_session():
return _Session() return _Session()
@ -821,6 +844,12 @@ def test_register_emoji_to_db_success(monkeypatch, tmp_path):
def flush(self): def flush(self):
pass pass
def exec(self, _statement):
return self
def first(self):
return None
def _get_db_session(): def _get_db_session():
return _Session() return _Session()
@ -858,6 +887,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
class _Select: class _Select:
def filter_by(self, **_kwargs): def filter_by(self, **_kwargs):
return self return self
def limit(self, _num): def limit(self, _num):
return self return self
@ -1006,6 +1036,7 @@ def test_delete_emoji_success(monkeypatch):
class _Select: class _Select:
def filter_by(self, **_kwargs): def filter_by(self, **_kwargs):
return self return self
def limit(self, _num): def limit(self, _num):
return self return self
@ -1056,6 +1087,7 @@ def test_delete_emoji_success(monkeypatch):
assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls)) assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls))
assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls)) assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls))
def test_delete_emoji_no_desc_deletes_record(monkeypatch): def test_delete_emoji_no_desc_deletes_record(monkeypatch):
emoji_manager_new = import_emoji_manager_new(monkeypatch) emoji_manager_new = import_emoji_manager_new(monkeypatch)
logger = emoji_manager_new.logger logger = emoji_manager_new.logger
@ -1133,6 +1165,7 @@ def test_update_emoji_usage_success(monkeypatch):
class _Select: class _Select:
def filter_by(self, **_kwargs): def filter_by(self, **_kwargs):
return self return self
def limit(self, _num): def limit(self, _num):
return self return self
@ -1143,6 +1176,7 @@ def test_update_emoji_usage_success(monkeypatch):
def __init__(self): def __init__(self):
self.query_count = 2 self.query_count = 2
self.last_used_time = None self.last_used_time = None
record = _Record() record = _Record()
class _Result: class _Result:
@ -1237,6 +1271,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
class _Select: class _Select:
def filter_by(self, **_kwargs): def filter_by(self, **_kwargs):
return self return self
def limit(self, _num): def limit(self, _num):
return self return self
@ -2254,6 +2289,38 @@ async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_pa
async def calculate_hash_format(self): async def calculate_hash_format(self):
return False return False
class _Session:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def exec(self, _statement):
return self
def first(self):
return None
class _Select:
def __init__(self) -> None:
pass
def filter_by(self, **_kwargs):
return self
def limit(self, _num):
return self
def _get_db_session_manual():
return _Session()
def _get_db_session():
return _Session()
monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual)
monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session)
monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select())
monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji)
result = await manager.register_emoji_by_filename(file_path) result = await manager.register_emoji_by_filename(file_path)

View File

@ -5,6 +5,7 @@ from sqlmodel import select
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
import asyncio import asyncio
import hashlib
import heapq import heapq
import Levenshtein import Levenshtein
import random import random
@ -13,7 +14,7 @@ import re
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.image_data_model import MaiEmoji from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.config.config import global_config from src.config.config import global_config
@ -43,6 +44,10 @@ emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_c
class EmojiManager: class EmojiManager:
"""
表情包管理器
"""
def __init__(self): def __init__(self):
_ensure_directories() _ensure_directories()
@ -51,6 +56,57 @@ class EmojiManager:
logger.info("启动表情包管理器") logger.info("启动表情包管理器")
async def get_emoji_description_by_bytes(self, emoji_bytes: bytes) -> Optional[Tuple[str, List[str]]]:
"""
根据表情包哈希获取表情包描述的封装方法
Returns:
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包则返回包含描述和情感标签的元组若没找到则尝试构建表情包描述并返回如果构建失败则返回 None
"""
# 先查找
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
if emoji := self.get_emoji_by_hash(emoji_hash):
return emoji.description, emoji.emotion or []
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if result := session.exec(statement).first():
return result.description, result.emotion.split(",") if result.emotion else []
except Exception as e:
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
# 找不到尝试构建
logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
full_path = EMOJI_DIR / f"{emoji_hash}.png"
try:
full_path.write_bytes(emoji_bytes)
new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
await new_emoji.calculate_hash_format()
except Exception as e:
logger.error(f"缓存表情包文件时出错: {e}")
raise e
success_desc, new_emoji = await self.build_emoji_description(new_emoji)
if not success_desc:
logger.error("构建表情包描述失败")
return None
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
if not success_emotion:
logger.error("构建表情包情感标签失败")
return None
# 缓存结果到数据库
with get_db_session() as session:
try:
image_record = new_emoji.to_db_instance()
image_record.is_registered = False
image_record.is_banned = False
image_record.register_time = datetime.now()
image_record.no_file_flag = True
session.add(image_record)
except Exception as e:
logger.error(f"缓存表情包描述时出错: {e}")
return new_emoji.description, new_emoji.emotion or []
def load_emojis_from_db(self) -> None: def load_emojis_from_db(self) -> None:
""" """
从数据库加载已注册的表情包 从数据库加载已注册的表情包
@ -114,14 +170,33 @@ class EmojiManager:
# 注册到数据库 # 注册到数据库
try: try:
with get_db_session() as session: with get_db_session() as session:
image_record = emoji.to_db_instance() statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
image_record.is_registered = True if existing_record := session.exec(statement).first():
image_record.is_banned = False if existing_record.no_file_flag:
image_record.register_time = datetime.now() existing_record.no_file_flag = False
session.add(image_record) existing_record.is_banned = False
session.flush() existing_record.full_path = str(emoji.full_path)
record_id = image_record.id existing_record.description = emoji.description
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}") existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
existing_record.query_count = emoji.query_count
existing_record.last_used_time = emoji.last_used_time
existing_record.register_time = emoji.register_time
session.add(existing_record)
logger.info(
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
)
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
return False
else:
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
record_id = image_record.id
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
except Exception as e: except Exception as e:
logger.error(f"[注册表情包] 注册到数据库时出错: {e}") logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
return False return False
@ -401,7 +476,9 @@ class EmojiManager:
# 调用VLM生成描述 # 调用VLM生成描述
image_format = target_emoji.image_format image_format = target_emoji.image_format
image_bytes = await asyncio.to_thread(target_emoji.read_image_bytes, target_emoji.full_path) image_bytes = target_emoji.image_bytes or await asyncio.to_thread(
target_emoji.read_image_bytes, target_emoji.full_path
)
if image_format == "gif": if image_format == "gif":
try: try:
@ -567,6 +644,29 @@ class EmojiManager:
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}") logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
return False return False
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
try:
with get_db_session_manual() as session:
statement = (
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
image_record.no_file_flag = False
image_record.is_banned = False
image_record.is_registered = True
image_record.full_path = str(target_emoji.full_path)
session.add(image_record)
session.commit()
logger.info(f"表情包注册成功Hash: {target_emoji.file_hash}")
return True
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
return False
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
# 1. 计算哈希值和格式 # 1. 计算哈希值和格式
calc_success = await target_emoji.calculate_hash_format() calc_success = await target_emoji.calculate_hash_format()
if not calc_success: if not calc_success:

View File

@ -83,7 +83,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
async def calculate_hash_format(self) -> bool: async def calculate_hash_format(self) -> bool:
""" """
异步计算表情包的哈希值和格式 异步计算表情包的哈希值和格式初始化后应该执行此方法来确保对象的哈希值和格式正确
Returns: Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False return (bool): 如果成功计算哈希值和格式则返回True否则返回False