diff --git a/pytests/image_sys_test/emoji_manager_test.py b/pytests/image_sys_test/emoji_manager_test.py index 7c50e9e8..08f7fa64 100644 --- a/pytests/image_sys_test/emoji_manager_test.py +++ b/pytests/image_sys_test/emoji_manager_test.py @@ -1015,13 +1015,14 @@ def test_update_emoji_usage_success(monkeypatch): def __init__(self): self.query_count = 2 self.last_used_time = None + record = _Record() class _Result: def scalars(self): return self def first(self): - return _Record() + return record class _Session: def __enter__(self): @@ -1048,6 +1049,8 @@ def test_update_emoji_usage_success(monkeypatch): result = manager.update_emoji_usage(emoji) assert result is True + assert emoji.query_count == 1 + assert record.query_count == 1 assert any("成功记录表情包使用" in m for m in _messages(logger.info_calls)) @@ -1156,6 +1159,169 @@ def test_update_emoji_usage_get_db_session_error(monkeypatch): assert any("记录使用时出错" in m for m in _messages(logger.error_calls)) +def test_update_emoji_success(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def limit(self, _num): + return self + + def _select(_model): + return _Select() + + class _Record: + def __init__(self): + self.description = None + self.emotion = None + + class _Result: + def scalars(self): + return self + + def first(self): + return _Record() + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def exec(self, _statement): + return _Result() + + def add(self, _record): + self.added = True + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.file_hash = "hash-update" + emoji.description = "new-desc" + emoji.emotion = ["a", "b"] + + result = manager.update_emoji(emoji) + + assert result is True + assert any("成功更新表情包信息" in m for m in _messages(logger.info_calls)) + + +def test_update_emoji_missing_record(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def limit(self, _num): + return self + + def _select(_model): + return _Select() + + class _Result: + def scalars(self): + return self + + def first(self): + return None + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def exec(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.file_hash = "hash-missing" + + result = manager.update_emoji(emoji) + + assert result is False + assert any("未找到表情包记录" in m for m in _messages(logger.error_calls)) + + +def test_update_emoji_execute_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def limit(self, _num): + return self + + def _select(_model): + return _Select() + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def exec(self, _statement): + raise RuntimeError("execute failed") + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.file_hash = "hash-execute" + + result = manager.update_emoji(emoji) + + assert result is False + assert any("更新数据库记录时出错" in m for m in _messages(logger.error_calls)) + + +def test_update_emoji_get_db_session_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + def _get_db_session(): + raise RuntimeError("get_db_session failed") + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.file_hash = "hash-session" + + result = manager.update_emoji(emoji) + + assert result is False + assert any("更新数据库记录时出错" in m for m in _messages(logger.error_calls)) + + @pytest.mark.asyncio async def test_get_emoji_for_emotion_empty_list(monkeypatch): emoji_manager_new = import_emoji_manager_new(monkeypatch) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 798cf411..5295dfac 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -160,6 +160,7 @@ class EmojiManager: return True def update_emoji_usage(self, emoji: MaiEmoji) -> bool: + # sourcery skip: extract-method """ 更新表情包的使用情况,更新查询次数和上次使用时间 @@ -168,12 +169,17 @@ class EmojiManager: Returns: return (bool): 更新是否成功 """ + if not emoji or not emoji.file_hash: + logger.error("[更新表情包使用] 无效的表情包对象") + return False try: with get_db_session() as session: statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) if image_record := session.exec(statement).first(): - image_record.query_count += 1 - image_record.last_used_time = datetime.now() + emoji.query_count += 1 + image_record.query_count= emoji.query_count + emoji.last_used_time = datetime.now() + image_record.last_used_time = emoji.last_used_time session.add(image_record) logger.info(f"[记录表情包使用] 成功记录表情包使用: {emoji.file_hash}") else: @@ -184,6 +190,35 @@ class EmojiManager: return False return True + def update_emoji(self, emoji: MaiEmoji) -> bool: + """ + 更新表情包的情感标签和描述信息 + + Args: + emoji (MaiEmoji): 需要更新的表情包对象,必须包含有效的 file_hash + Returns: + return (bool): 更新是否成功 + """ + if not emoji or not emoji.file_hash: + logger.error("[更新表情包] 无效的表情包对象") + return False + + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) + if image_record := session.exec(statement).first(): + image_record.description = emoji.description + image_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None + session.add(image_record) + logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}") + else: + logger.error(f"[更新表情包] 未找到表情包记录: {emoji.file_hash}") + return False + except Exception as e: + logger.error(f"[更新表情包] 更新数据库记录时出错: {e}") + return False + return True + def get_emoji_by_hash(self, emoji_hash: str) -> Optional[MaiEmoji]: """ 根据哈希值获取表情包对象