mirror of https://github.com/Mai-with-u/MaiBot.git
获取和注册一体化修正
parent
ccd1be7bed
commit
537b24c24e
|
|
@ -58,6 +58,7 @@ def _install_stub_modules(monkeypatch):
|
|||
query_count: int = 0
|
||||
register_time: object | None = None
|
||||
image_format: str | None = None
|
||||
image_bytes: bytes | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db_instance(_record):
|
||||
|
|
@ -128,10 +129,17 @@ def _install_stub_modules(monkeypatch):
|
|||
def flush(self):
|
||||
pass
|
||||
|
||||
def commit(self):
|
||||
pass
|
||||
|
||||
def get_db_session():
|
||||
return _DummySession()
|
||||
|
||||
def get_db_session_manual():
|
||||
return _DummySession()
|
||||
|
||||
db_mod.get_db_session = get_db_session
|
||||
db_mod.get_db_session_manual = get_db_session_manual
|
||||
|
||||
# src.common.utils.utils_image
|
||||
image_utils_mod = _stub_module("src.common.utils.utils_image")
|
||||
|
|
@ -236,6 +244,15 @@ def import_emoji_manager_new(monkeypatch):
|
|||
module = importlib.util.module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
class _Select:
|
||||
def filter_by(self, **kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, n):
|
||||
return self
|
||||
|
||||
module.select = lambda _model: _Select()
|
||||
return module
|
||||
|
||||
|
||||
|
|
@ -763,6 +780,12 @@ def test_register_emoji_to_db_db_error(monkeypatch):
|
|||
def flush(self):
|
||||
pass
|
||||
|
||||
def exec(self, _statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
def _get_db_session():
|
||||
return _Session()
|
||||
|
||||
|
|
@ -821,6 +844,12 @@ def test_register_emoji_to_db_success(monkeypatch, tmp_path):
|
|||
def flush(self):
|
||||
pass
|
||||
|
||||
def exec(self, _statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
def _get_db_session():
|
||||
return _Session()
|
||||
|
||||
|
|
@ -858,6 +887,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
|
|||
class _Select:
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
|
|
@ -948,7 +978,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
|
|||
class _Select:
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
|
|
@ -1006,6 +1036,7 @@ def test_delete_emoji_success(monkeypatch):
|
|||
class _Select:
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
|
|
@ -1056,6 +1087,7 @@ def test_delete_emoji_success(monkeypatch):
|
|||
assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls))
|
||||
assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls))
|
||||
|
||||
|
||||
def test_delete_emoji_no_desc_deletes_record(monkeypatch):
|
||||
emoji_manager_new = import_emoji_manager_new(monkeypatch)
|
||||
logger = emoji_manager_new.logger
|
||||
|
|
@ -1133,6 +1165,7 @@ def test_update_emoji_usage_success(monkeypatch):
|
|||
class _Select:
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
|
|
@ -1143,6 +1176,7 @@ def test_update_emoji_usage_success(monkeypatch):
|
|||
def __init__(self):
|
||||
self.query_count = 2
|
||||
self.last_used_time = None
|
||||
|
||||
record = _Record()
|
||||
|
||||
class _Result:
|
||||
|
|
@ -1237,6 +1271,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
|
|||
class _Select:
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
|
|
@ -2254,6 +2289,38 @@ async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_pa
|
|||
async def calculate_hash_format(self):
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def exec(self, _statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
class _Select:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def filter_by(self, **_kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, _num):
|
||||
return self
|
||||
|
||||
def _get_db_session_manual():
|
||||
return _Session()
|
||||
|
||||
def _get_db_session():
|
||||
return _Session()
|
||||
|
||||
monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual)
|
||||
monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session)
|
||||
monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select())
|
||||
monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji)
|
||||
|
||||
result = await manager.register_emoji_by_filename(file_path)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlmodel import select
|
|||
from typing import Optional, Tuple, List
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import heapq
|
||||
import Levenshtein
|
||||
import random
|
||||
|
|
@ -13,7 +14,7 @@ import re
|
|||
from src.common.logger import get_logger
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import global_config
|
||||
|
|
@ -43,6 +44,10 @@ emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_c
|
|||
|
||||
|
||||
class EmojiManager:
|
||||
"""
|
||||
表情包管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
_ensure_directories()
|
||||
|
||||
|
|
@ -51,6 +56,57 @@ class EmojiManager:
|
|||
|
||||
logger.info("启动表情包管理器")
|
||||
|
||||
async def get_emoji_description_by_bytes(self, emoji_bytes: bytes) -> Optional[Tuple[str, List[str]]]:
|
||||
"""
|
||||
根据表情包哈希获取表情包描述的封装方法
|
||||
|
||||
Returns:
|
||||
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None
|
||||
"""
|
||||
# 先查找
|
||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||
return emoji.description, emoji.emotion or []
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
return result.description, result.emotion.split(",") if result.emotion else []
|
||||
except Exception as e:
|
||||
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
|
||||
|
||||
# 找不到尝试构建
|
||||
logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
|
||||
full_path = EMOJI_DIR / f"{emoji_hash}.png"
|
||||
try:
|
||||
full_path.write_bytes(emoji_bytes)
|
||||
new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
|
||||
await new_emoji.calculate_hash_format()
|
||||
except Exception as e:
|
||||
logger.error(f"缓存表情包文件时出错: {e}")
|
||||
raise e
|
||||
success_desc, new_emoji = await self.build_emoji_description(new_emoji)
|
||||
if not success_desc:
|
||||
logger.error("构建表情包描述失败")
|
||||
return None
|
||||
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
|
||||
if not success_emotion:
|
||||
logger.error("构建表情包情感标签失败")
|
||||
return None
|
||||
|
||||
# 缓存结果到数据库
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
image_record = new_emoji.to_db_instance()
|
||||
image_record.is_registered = False
|
||||
image_record.is_banned = False
|
||||
image_record.register_time = datetime.now()
|
||||
image_record.no_file_flag = True
|
||||
session.add(image_record)
|
||||
except Exception as e:
|
||||
logger.error(f"缓存表情包描述时出错: {e}")
|
||||
return new_emoji.description, new_emoji.emotion or []
|
||||
|
||||
def load_emojis_from_db(self) -> None:
|
||||
"""
|
||||
从数据库加载已注册的表情包
|
||||
|
|
@ -114,14 +170,33 @@ class EmojiManager:
|
|||
# 注册到数据库
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
record_id = image_record.id
|
||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.no_file_flag:
|
||||
existing_record.no_file_flag = False
|
||||
existing_record.is_banned = False
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = emoji.description
|
||||
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||
existing_record.query_count = emoji.query_count
|
||||
existing_record.last_used_time = emoji.last_used_time
|
||||
existing_record.register_time = emoji.register_time
|
||||
session.add(existing_record)
|
||||
logger.info(
|
||||
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
return False
|
||||
else:
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
record_id = image_record.id
|
||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
||||
return False
|
||||
|
|
@ -401,7 +476,9 @@ class EmojiManager:
|
|||
|
||||
# 调用VLM生成描述
|
||||
image_format = target_emoji.image_format
|
||||
image_bytes = await asyncio.to_thread(target_emoji.read_image_bytes, target_emoji.full_path)
|
||||
image_bytes = target_emoji.image_bytes or await asyncio.to_thread(
|
||||
target_emoji.read_image_bytes, target_emoji.full_path
|
||||
)
|
||||
|
||||
if image_format == "gif":
|
||||
try:
|
||||
|
|
@ -567,6 +644,29 @@ class EmojiManager:
|
|||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
||||
return False
|
||||
|
||||
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||
try:
|
||||
with get_db_session_manual() as session:
|
||||
statement = (
|
||||
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
)
|
||||
if image_record := session.exec(statement).first():
|
||||
if image_record.no_file_flag:
|
||||
image_record.no_file_flag = False
|
||||
image_record.is_banned = False
|
||||
image_record.is_registered = True
|
||||
image_record.full_path = str(target_emoji.full_path)
|
||||
session.add(image_record)
|
||||
session.commit()
|
||||
logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
|
||||
# 1. 计算哈希值和格式
|
||||
calc_success = await target_emoji.calculate_hash_format()
|
||||
if not calc_success:
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
|||
|
||||
async def calculate_hash_format(self) -> bool:
|
||||
"""
|
||||
异步计算表情包的哈希值和格式
|
||||
异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确
|
||||
|
||||
Returns:
|
||||
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
||||
|
|
|
|||
Loading…
Reference in New Issue