mirror of https://github.com/Mai-with-u/MaiBot.git
获取和注册一体化修正
parent
ccd1be7bed
commit
537b24c24e
|
|
@ -58,6 +58,7 @@ def _install_stub_modules(monkeypatch):
|
||||||
query_count: int = 0
|
query_count: int = 0
|
||||||
register_time: object | None = None
|
register_time: object | None = None
|
||||||
image_format: str | None = None
|
image_format: str | None = None
|
||||||
|
image_bytes: bytes | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db_instance(_record):
|
def from_db_instance(_record):
|
||||||
|
|
@ -128,10 +129,17 @@ def _install_stub_modules(monkeypatch):
|
||||||
def flush(self):
|
def flush(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def get_db_session():
|
def get_db_session():
|
||||||
return _DummySession()
|
return _DummySession()
|
||||||
|
|
||||||
|
def get_db_session_manual():
|
||||||
|
return _DummySession()
|
||||||
|
|
||||||
db_mod.get_db_session = get_db_session
|
db_mod.get_db_session = get_db_session
|
||||||
|
db_mod.get_db_session_manual = get_db_session_manual
|
||||||
|
|
||||||
# src.common.utils.utils_image
|
# src.common.utils.utils_image
|
||||||
image_utils_mod = _stub_module("src.common.utils.utils_image")
|
image_utils_mod = _stub_module("src.common.utils.utils_image")
|
||||||
|
|
@ -236,6 +244,15 @@ def import_emoji_manager_new(monkeypatch):
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
|
monkeypatch.setitem(sys.modules, "emoji_manager_new", module)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
class _Select:
|
||||||
|
def filter_by(self, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, n):
|
||||||
|
return self
|
||||||
|
|
||||||
|
module.select = lambda _model: _Select()
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -763,6 +780,12 @@ def test_register_emoji_to_db_db_error(monkeypatch):
|
||||||
def flush(self):
|
def flush(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def exec(self, _statement):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return None
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
return _Session()
|
return _Session()
|
||||||
|
|
||||||
|
|
@ -821,6 +844,12 @@ def test_register_emoji_to_db_success(monkeypatch, tmp_path):
|
||||||
def flush(self):
|
def flush(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def exec(self, _statement):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return None
|
||||||
|
|
||||||
def _get_db_session():
|
def _get_db_session():
|
||||||
return _Session()
|
return _Session()
|
||||||
|
|
||||||
|
|
@ -858,6 +887,7 @@ def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, _num):
|
def limit(self, _num):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -948,7 +978,7 @@ def test_delete_emoji_db_error_file_still_exists(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, _num):
|
def limit(self, _num):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -1006,6 +1036,7 @@ def test_delete_emoji_success(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, _num):
|
def limit(self, _num):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -1056,6 +1087,7 @@ def test_delete_emoji_success(monkeypatch):
|
||||||
assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls))
|
assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls))
|
||||||
assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls))
|
assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls))
|
||||||
|
|
||||||
|
|
||||||
def test_delete_emoji_no_desc_deletes_record(monkeypatch):
|
def test_delete_emoji_no_desc_deletes_record(monkeypatch):
|
||||||
emoji_manager_new = import_emoji_manager_new(monkeypatch)
|
emoji_manager_new = import_emoji_manager_new(monkeypatch)
|
||||||
logger = emoji_manager_new.logger
|
logger = emoji_manager_new.logger
|
||||||
|
|
@ -1133,6 +1165,7 @@ def test_update_emoji_usage_success(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, _num):
|
def limit(self, _num):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -1143,6 +1176,7 @@ def test_update_emoji_usage_success(monkeypatch):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.query_count = 2
|
self.query_count = 2
|
||||||
self.last_used_time = None
|
self.last_used_time = None
|
||||||
|
|
||||||
record = _Record()
|
record = _Record()
|
||||||
|
|
||||||
class _Result:
|
class _Result:
|
||||||
|
|
@ -1237,6 +1271,7 @@ def test_update_emoji_usage_execute_error(monkeypatch):
|
||||||
class _Select:
|
class _Select:
|
||||||
def filter_by(self, **_kwargs):
|
def filter_by(self, **_kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, _num):
|
def limit(self, _num):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -2254,6 +2289,38 @@ async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_pa
|
||||||
async def calculate_hash_format(self):
|
async def calculate_hash_format(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class _Session:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def exec(self, _statement):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _Select:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def filter_by(self, **_kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, _num):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _get_db_session_manual():
|
||||||
|
return _Session()
|
||||||
|
|
||||||
|
def _get_db_session():
|
||||||
|
return _Session()
|
||||||
|
|
||||||
|
monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual)
|
||||||
|
monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session)
|
||||||
|
monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select())
|
||||||
monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji)
|
monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji)
|
||||||
|
|
||||||
result = await manager.register_emoji_by_filename(file_path)
|
result = await manager.register_emoji_by_filename(file_path)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from sqlmodel import select
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import heapq
|
import heapq
|
||||||
import Levenshtein
|
import Levenshtein
|
||||||
import random
|
import random
|
||||||
|
|
@ -13,7 +14,7 @@ import re
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.image_data_model import MaiEmoji
|
from src.common.data_models.image_data_model import MaiEmoji
|
||||||
from src.common.database.database_model import Images, ImageType
|
from src.common.database.database_model import Images, ImageType
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session, get_db_session_manual
|
||||||
from src.common.utils.utils_image import ImageUtils
|
from src.common.utils.utils_image import ImageUtils
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -43,6 +44,10 @@ emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_c
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
|
"""
|
||||||
|
表情包管理器
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
_ensure_directories()
|
_ensure_directories()
|
||||||
|
|
||||||
|
|
@ -51,6 +56,57 @@ class EmojiManager:
|
||||||
|
|
||||||
logger.info("启动表情包管理器")
|
logger.info("启动表情包管理器")
|
||||||
|
|
||||||
|
async def get_emoji_description_by_bytes(self, emoji_bytes: bytes) -> Optional[Tuple[str, List[str]]]:
|
||||||
|
"""
|
||||||
|
根据表情包哈希获取表情包描述的封装方法
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None
|
||||||
|
"""
|
||||||
|
# 先查找
|
||||||
|
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||||
|
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||||
|
return emoji.description, emoji.emotion or []
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
|
if result := session.exec(statement).first():
|
||||||
|
return result.description, result.emotion.split(",") if result.emotion else []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
|
||||||
|
|
||||||
|
# 找不到尝试构建
|
||||||
|
logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
|
||||||
|
full_path = EMOJI_DIR / f"{emoji_hash}.png"
|
||||||
|
try:
|
||||||
|
full_path.write_bytes(emoji_bytes)
|
||||||
|
new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
|
||||||
|
await new_emoji.calculate_hash_format()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"缓存表情包文件时出错: {e}")
|
||||||
|
raise e
|
||||||
|
success_desc, new_emoji = await self.build_emoji_description(new_emoji)
|
||||||
|
if not success_desc:
|
||||||
|
logger.error("构建表情包描述失败")
|
||||||
|
return None
|
||||||
|
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
|
||||||
|
if not success_emotion:
|
||||||
|
logger.error("构建表情包情感标签失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 缓存结果到数据库
|
||||||
|
with get_db_session() as session:
|
||||||
|
try:
|
||||||
|
image_record = new_emoji.to_db_instance()
|
||||||
|
image_record.is_registered = False
|
||||||
|
image_record.is_banned = False
|
||||||
|
image_record.register_time = datetime.now()
|
||||||
|
image_record.no_file_flag = True
|
||||||
|
session.add(image_record)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"缓存表情包描述时出错: {e}")
|
||||||
|
return new_emoji.description, new_emoji.emotion or []
|
||||||
|
|
||||||
def load_emojis_from_db(self) -> None:
|
def load_emojis_from_db(self) -> None:
|
||||||
"""
|
"""
|
||||||
从数据库加载已注册的表情包
|
从数据库加载已注册的表情包
|
||||||
|
|
@ -114,14 +170,33 @@ class EmojiManager:
|
||||||
# 注册到数据库
|
# 注册到数据库
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
image_record = emoji.to_db_instance()
|
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
image_record.is_registered = True
|
if existing_record := session.exec(statement).first():
|
||||||
image_record.is_banned = False
|
if existing_record.no_file_flag:
|
||||||
image_record.register_time = datetime.now()
|
existing_record.no_file_flag = False
|
||||||
session.add(image_record)
|
existing_record.is_banned = False
|
||||||
session.flush()
|
existing_record.full_path = str(emoji.full_path)
|
||||||
record_id = image_record.id
|
existing_record.description = emoji.description
|
||||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||||
|
existing_record.query_count = emoji.query_count
|
||||||
|
existing_record.last_used_time = emoji.last_used_time
|
||||||
|
existing_record.register_time = emoji.register_time
|
||||||
|
session.add(existing_record)
|
||||||
|
logger.info(
|
||||||
|
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
image_record = emoji.to_db_instance()
|
||||||
|
image_record.is_registered = True
|
||||||
|
image_record.is_banned = False
|
||||||
|
image_record.register_time = datetime.now()
|
||||||
|
session.add(image_record)
|
||||||
|
session.flush()
|
||||||
|
record_id = image_record.id
|
||||||
|
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
@ -401,7 +476,9 @@ class EmojiManager:
|
||||||
|
|
||||||
# 调用VLM生成描述
|
# 调用VLM生成描述
|
||||||
image_format = target_emoji.image_format
|
image_format = target_emoji.image_format
|
||||||
image_bytes = await asyncio.to_thread(target_emoji.read_image_bytes, target_emoji.full_path)
|
image_bytes = target_emoji.image_bytes or await asyncio.to_thread(
|
||||||
|
target_emoji.read_image_bytes, target_emoji.full_path
|
||||||
|
)
|
||||||
|
|
||||||
if image_format == "gif":
|
if image_format == "gif":
|
||||||
try:
|
try:
|
||||||
|
|
@ -567,6 +644,29 @@ class EmojiManager:
|
||||||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||||
|
try:
|
||||||
|
with get_db_session_manual() as session:
|
||||||
|
statement = (
|
||||||
|
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
|
)
|
||||||
|
if image_record := session.exec(statement).first():
|
||||||
|
if image_record.no_file_flag:
|
||||||
|
image_record.no_file_flag = False
|
||||||
|
image_record.is_banned = False
|
||||||
|
image_record.is_registered = True
|
||||||
|
image_record.full_path = str(target_emoji.full_path)
|
||||||
|
session.add(image_record)
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
# 1. 计算哈希值和格式
|
# 1. 计算哈希值和格式
|
||||||
calc_success = await target_emoji.calculate_hash_format()
|
calc_success = await target_emoji.calculate_hash_format()
|
||||||
if not calc_success:
|
if not calc_success:
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||||
|
|
||||||
async def calculate_hash_format(self) -> bool:
|
async def calculate_hash_format(self) -> bool:
|
||||||
"""
|
"""
|
||||||
异步计算表情包的哈希值和格式
|
异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue