diff --git a/changelogs/mai_next_todo.md b/changelogs/mai_next_todo.md index 40c7bdcb..faa00d97 100644 --- a/changelogs/mai_next_todo.md +++ b/changelogs/mai_next_todo.md @@ -142,6 +142,18 @@ version 0.3.0 - 2026-01-11 - [x] Prompt删除 - [x] **只保存被标记为需要保存的Prompt,其他的Prompt文件全部删除** +## LLM相关内容 +- [ ] 统一LLM调用接口 + - [ ] 统一LLM调用返回格式为专有数据模型 + - [ ] 取消所有__init__方法中对LLM Client的初始化,转而使用获取方式 + - [ ] 统一使用`get_llm_client`方法获取LLM Client实例 + - [ ] __init__方法中只保存配置信息 +- [ ] LLM Client管理器 + - [ ] LLM Client单例/多例管理 + - [ ] LLM Client缓存管理/生命周期管理 + - [ ] LLM Client根据配置热重载 + + ## 一些细枝末节的东西 - [ ] 将`stream_id`和`chat_id`统一命名为`session_id` - [ ] 映射表 diff --git a/prompts/emoji_content_analysis.prompt b/prompts/emoji_content_analysis.prompt new file mode 100644 index 00000000..c3834ce3 --- /dev/null +++ b/prompts/emoji_content_analysis.prompt @@ -0,0 +1,5 @@ +这是一个聊天场景中的表情包描述:"{description}" + +请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 +你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗、meme的角度去分析 +请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 \ No newline at end of file diff --git a/prompts/emoji_content_filtration.prompt b/prompts/emoji_content_filtration.prompt new file mode 100644 index 00000000..6bb73a53 --- /dev/null +++ b/prompts/emoji_content_filtration.prompt @@ -0,0 +1,6 @@ +这是一个表情包,请对这个表情包进行审核,标准如下: +1. 必须符合"{demand}"的要求 +2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 +3. 不能是任何形式的截图,聊天记录或视频截图 +4. 不要出现5个以上文字 +请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 \ No newline at end of file diff --git a/prompts/emoji_replace.prompt b/prompts/emoji_replace.prompt new file mode 100644 index 00000000..69093fda --- /dev/null +++ b/prompts/emoji_replace.prompt @@ -0,0 +1,12 @@ +{nickname}的表情包存储已满({emoji_num}/{emoji_num_max}),需要决定是否删除一个旧表情包来为新表情包腾出空间。 + +新表情包信息: +描述: {description} + +现有表情包列表: +{emoji_list} + +请决定: +1. 是否要删除某个现有表情包来为新表情包腾出空间? +2. 如果要删除,应该删除哪一个(给出编号)? +请只回答:'不删除'或'删除编号X'(X为表情包编号)。 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ec2dff16..1d432bd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ version = "0.11.6" description = "MaiCore 是一个基于大语言模型的可交互智能体" requires-python = ">=3.10" dependencies = [ - "aiohttp>=3.12.14", "aiohttp-cors>=0.8.1", + "aiohttp>=3.12.14", "colorama>=0.4.6", "faiss-cpu>=1.11.0", "fastapi>=0.116.0", @@ -14,7 +14,6 @@ dependencies = [ "json-repair>=0.47.6", "maim-message>=0.6.2", "matplotlib>=3.10.3", - "msgpack>=1.1.2", "numpy>=2.2.6", "openai>=1.95.0", "pandas>=2.3.1", @@ -25,6 +24,7 @@ dependencies = [ "pypinyin>=0.54.0", "python-dotenv>=1.1.1", "python-multipart>=0.0.20", + "python-levenshtein", "quick-algo>=0.1.3", "rich>=14.0.0", "ruff>=0.12.2", @@ -34,7 +34,6 @@ dependencies = [ "tomlkit>=0.13.3", "urllib3>=2.5.0", "uvicorn>=0.35.0", - "zstandard>=0.25.0", ] diff --git a/pytests/image_sys_test/emoji_manager_test.py b/pytests/image_sys_test/emoji_manager_test.py new file mode 100644 index 00000000..24e68c75 --- /dev/null +++ b/pytests/image_sys_test/emoji_manager_test.py @@ -0,0 +1,1950 @@ +# 本文件为测试文件,含有大量的MonkeyPatch和Mock代码,请忽略TypeChecker的报错 +import importlib.util +import sys +from dataclasses import dataclass +from types import ModuleType +from pathlib import Path + +import pytest + + +def _install_stub_modules(monkeypatch): + def _stub_module(name: str) -> ModuleType: + module = ModuleType(name) + monkeypatch.setitem(sys.modules, name, module) + return module + + # src.common.logger + logger_mod = _stub_module("src.common.logger") + + class _Logger: + def __init__(self): + self.info_calls = [] + self.debug_calls = [] + self.warning_calls = [] + self.error_calls = [] + self.critical_calls = [] + + def info(self, *args, **kwargs): + self.info_calls.append(args) + + def debug(self, *args, **kwargs): + self.debug_calls.append(args) + + def warning(self, *args, **kwargs): + self.warning_calls.append(args) + + def error(self, *args, **kwargs): + self.error_calls.append(args) + + def critical(self, *args, **kwargs): + self.critical_calls.append(args) + + def get_logger(_name: str): + return _Logger() + + logger_mod.get_logger = get_logger + + # src.common.data_models.image_data_model + data_model_mod = _stub_module("src.common.data_models.image_data_model") + + @dataclass + class MaiEmoji: + full_path: Path | None = None + file_name: str = "" + description: str | None = None + emotion: list[str] | None = None + emoji_hash: str | None = None + is_deleted: bool = False + query_count: int = 0 + register_time: object | None = None + _format: str | None = None + + @staticmethod + def from_db_instance(_record): + return MaiEmoji() + + def to_db_instance(self): + return Images() + + async def calculate_hash_format(self): + return True + + @staticmethod + def read_image_bytes(_path): + return b"" + + data_model_mod.MaiEmoji = MaiEmoji + + # src.common.database.database_model + db_model_mod = _stub_module("src.common.database.database_model") + + class Images: + id = 0 + is_registered = False + is_banned = False + register_time = None + query_count = 0 + last_used_time = None + full_path = "" + image_hash = "" + image_type = None + + class ImageType: + EMOJI = "EMOJI" + + db_model_mod.Images = Images + db_model_mod.ImageType = ImageType + + # src.common.database.database + db_mod = _stub_module("src.common.database.database") + + class _DummySession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + class _Result: + def scalars(self): + return self + + def all(self): + return [] + + def first(self): + return None + + return _Result() + + def add(self, _record): + pass + + def delete(self, _record): + pass + + def flush(self): + pass + + def get_db_session(): + return _DummySession() + + db_mod.get_db_session = get_db_session + + # src.common.utils.utils_image + image_utils_mod = _stub_module("src.common.utils.utils_image") + + class ImageUtils: + @staticmethod + def gif_2_static_image(_image_bytes): + return b"" + + @staticmethod + def image_bytes_to_base64(_image_bytes): + return "" + + image_utils_mod.ImageUtils = ImageUtils + + # src.prompt.prompt_manager + prompt_manager_mod = _stub_module("src.prompt.prompt_manager") + + class _Prompt: + def add_context(self, _key, _value): + pass + + class _PromptManager: + def get_prompt(self, _name): + return _Prompt() + + async def render_prompt(self, _prompt): + return "" + + prompt_manager_mod.prompt_manager = _PromptManager() + + # src.config.config + config_mod = _stub_module("src.config.config") + + class _EmojiConfig: + max_reg_num = 20 + content_filtration = False + filtration_prompt = "" + steal_emoji = False + do_replace = False + check_interval = 1 + + class _BotConfig: + nickname = "bot" + + class _ModelTaskConfig: + vlm = None + utils = None + + class _ModelConfig: + model_task_config = _ModelTaskConfig() + + class _GlobalConfig: + emoji = _EmojiConfig() + bot = _BotConfig() + + config_mod.global_config = _GlobalConfig() + config_mod.model_config = _ModelConfig() + + # src.llm_models.utils_model + llm_mod = _stub_module("src.llm_models.utils_model") + + class LLMRequest: + def __init__(self, *args, **kwargs): + pass + + async def generate_response_async(self, *args, **kwargs): + return "", None + + async def generate_response_for_image(self, *args, **kwargs): + return "", None + + llm_mod.LLMRequest = LLMRequest + + # third-party stubs + rich_traceback_mod = _stub_module("rich.traceback") + + def install(*_args, **_kwargs): + pass + + rich_traceback_mod.install = install + + sqlmodel_mod = _stub_module("sqlmodel") + + def select(_model): + return object() + + sqlmodel_mod.select = select + + levenshtein_mod = _stub_module("Levenshtein") + + def distance(a, b): + return abs(len(str(a)) - len(str(b))) + + levenshtein_mod.distance = distance + + +def import_emoji_manager_new(monkeypatch): + _install_stub_modules(monkeypatch) + file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager_new.py" + spec = importlib.util.spec_from_file_location("emoji_manager_new", file_path) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "emoji_manager_new", module) + spec.loader.exec_module(module) + return module + + +def _messages(calls): + return [" ".join(map(str, args)) for args in calls] + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_decision_no_delete(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "不删除", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("不删除任何表情包" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_decision_parse_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号1", None + + def _bad_search(*_args, **_kwargs): + raise RuntimeError("search failed") + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(emoji_manager_new.re, "search", _bad_search) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("解析决策结果时出错" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_decision_missing_number(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号ABC", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("未能解析删除编号" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_decision_index_out_of_range(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号3", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("无效的表情包编号" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_delete_failed(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号1", None + + def _delete(_emoji): + return False + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(manager, "delete_emoji", _delete) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("删除表情包失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_register_failed(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号1", None + + def _delete(_emoji): + return True + + def _register(_emoji): + return False + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(manager, "delete_emoji", _delete) + monkeypatch.setattr(manager, "register_emoji_to_db", _register) + + result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) + + assert result is False + assert any("注册新表情包失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_replace_an_emoji_by_llm_success(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + old_emoji = emoji_manager_new.MaiEmoji() + old_emoji.description = "old" + manager.emojis = [old_emoji] + + async def _generate_response_async(*_args, **_kwargs): + return "删除编号1", None + + def _delete(_emoji): + return True + + def _register(_emoji): + return True + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(manager, "delete_emoji", _delete) + monkeypatch.setattr(manager, "register_emoji_to_db", _register) + + new_emoji = emoji_manager_new.MaiEmoji() + new_emoji.description = "new" + + result = await manager.replace_an_emoji_by_llm(new_emoji) + + assert result is True + assert new_emoji in manager.emojis + assert old_emoji not in manager.emojis + assert any("成功替换并注册新表情包" in m for m in _messages(logger.info_calls)) + + +def test_load_emojis_from_db_empty(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _Result: + def scalars(self): + return self + + def all(self): + return [] + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + manager = emoji_manager_new.EmojiManager() + + manager.load_emojis_from_db() + + assert manager.emojis == [] + assert manager._emoji_num == 0 + assert any("成功加载" in m for m in _messages(logger.info_calls)) + + +def test_load_emojis_from_db_partial_bad_records(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _Record: + def __init__(self, record_id, full_path): + self.id = record_id + self.full_path = full_path + + records = [_Record(1, "bad"), _Record(2, "ok")] + + class _Result: + def scalars(self): + return self + + def all(self): + return records + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + def _from_db_instance(record): + if record.id == 1: + raise ValueError("bad record") + emoji = emoji_manager_new.MaiEmoji() + emoji.file_name = "ok" + return emoji + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "from_db_instance", staticmethod(_from_db_instance)) + manager = emoji_manager_new.EmojiManager() + + manager.load_emojis_from_db() + + assert len(manager.emojis) == 1 + assert manager.emojis[0].file_name == "ok" + assert manager._emoji_num == 1 + assert any("加载表情包记录时出错" in m for m in _messages(logger.error_calls)) + assert any("成功加载" in m for m in _messages(logger.info_calls)) + + +def test_load_emojis_from_db_execute_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + raise RuntimeError("execute failed") + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + manager._emoji_num = 1 + + with pytest.raises(RuntimeError): + manager.load_emojis_from_db() + + assert manager.emojis == [] + assert manager._emoji_num == 0 + assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) + + +def test_load_emojis_from_db_get_db_session_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + def _get_db_session(): + raise RuntimeError("get_db_session failed") + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + manager._emoji_num = 1 + + with pytest.raises(RuntimeError): + manager.load_emojis_from_db() + + assert manager.emojis == [] + assert manager._emoji_num == 0 + assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) + + +def test_load_emojis_from_db_scalars_all_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _Result: + def scalars(self): + return self + + def all(self): + raise RuntimeError("all failed") + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + manager._emoji_num = 1 + + with pytest.raises(RuntimeError): + manager.load_emojis_from_db() + + assert manager.emojis == [] + assert manager._emoji_num == 0 + assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) + + +def test_register_emoji_to_db_invalid_object(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + result = manager.register_emoji_to_db(None) + + assert result is False + assert any("无效的表情包对象" in m for m in _messages(logger.error_calls)) + + +def test_register_emoji_to_db_wrong_type(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + result = manager.register_emoji_to_db(object()) + + assert result is False + assert any("无效的表情包对象" in m for m in _messages(logger.error_calls)) + + +def test_register_emoji_to_db_file_missing(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = Path("/missing/file.png") + + result = manager.register_emoji_to_db(emoji) + + assert result is False + assert any("表情包文件不存在" in m for m in _messages(logger.error_calls)) + + +def test_register_emoji_to_db_move_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "a.png" + self._exists = True + + def exists(self): + return self._exists + + def replace(self, _target): + raise RuntimeError("move failed") + + @property + def name(self): + return self._name + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "a.png" + + result = manager.register_emoji_to_db(emoji) + + assert result is False + assert any("移动表情包文件时出错" in m for m in _messages(logger.error_calls)) + + +def test_register_emoji_to_db_db_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "a.png" + self._exists = True + self._replaced = False + + def exists(self): + return self._exists + + def replace(self, _target): + self._replaced = True + + @property + def name(self): + return self._name + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "a.png" + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def add(self, _record): + raise RuntimeError("db add failed") + + def flush(self): + pass + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + result = manager.register_emoji_to_db(emoji) + + assert result is False + assert any("注册到数据库时出错" in m for m in _messages(logger.error_calls)) + + +def test_register_emoji_to_db_success(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self, name): + self._name = name + self._exists = True + self._replaced = False + self._target = None + + def exists(self): + return self._exists + + def replace(self, target): + self._replaced = True + self._target = target + + @property + def name(self): + return self._name + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath("a.png") + emoji.file_name = "a.png" + + class _Record: + def __init__(self): + self.id = 123 + self.is_registered = False + self.is_banned = False + self.register_time = None + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def add(self, _record): + pass + + def flush(self): + pass + + def _get_db_session(): + return _Session() + + def _to_db_instance(self): + return _Record() + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "to_db_instance", _to_db_instance, raising=False) + + result = manager.register_emoji_to_db(emoji) + + assert result is True + assert any("成功注册表情包到数据库" in m for m in _messages(logger.info_calls)) + + +def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "missing.png" + + def unlink(self): + raise FileNotFoundError("missing") + + def exists(self): + return False + + @property + def name(self): + return self._name + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Result: + def scalars(self): + return self + + def first(self): + return None + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "missing.png" + emoji.emoji_hash = "hash-missing" + + result = manager.delete_emoji(emoji) + + assert result is True + assert any("不存在" in m for m in _messages(logger.warning_calls)) + assert any("未找到表情包记录" in m for m in _messages(logger.warning_calls)) + + +def test_delete_emoji_file_delete_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "boom.png" + + def unlink(self): + raise RuntimeError("unlink failed") + + @property + def name(self): + return self._name + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "boom.png" + emoji.emoji_hash = "hash-boom" + + result = manager.delete_emoji(emoji) + + assert result is False + assert any("删除表情包文件时出错" in m for m in _messages(logger.error_calls)) + + +def test_delete_emoji_db_error_file_still_exists(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "keep.png" + + def unlink(self): + return None + + def exists(self): + return True + + @property + def name(self): + return self._name + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + raise RuntimeError("db delete failed") + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "keep.png" + emoji.emoji_hash = "hash-keep" + + result = manager.delete_emoji(emoji) + + assert result is False + assert any("删除数据库记录时出错" in m for m in _messages(logger.error_calls)) + assert any("数据库记录删除失败,但文件仍存在" in m for m in _messages(logger.warning_calls)) + + +def test_delete_emoji_success(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _DummyPath: + def __init__(self): + self._name = "ok.png" + self._deleted = False + + def unlink(self): + self._deleted = True + + def exists(self): + return not self._deleted + + @property + def name(self): + return self._name + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Record: + pass + + class _Result: + def scalars(self): + return self + + def first(self): + return _Record() + + class _Session: + def __init__(self): + self.deleted = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def delete(self, _record): + self.deleted = True + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.full_path = _DummyPath() + emoji.file_name = "ok.png" + emoji.emoji_hash = "hash-ok" + + result = manager.delete_emoji(emoji) + + assert result is True + assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls)) + assert any("成功删除数据库中的表情包记录" in m for m in _messages(logger.info_calls)) + + +def test_update_emoji_usage_success(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Record: + def __init__(self): + self.query_count = 2 + self.last_used_time = None + + class _Result: + def scalars(self): + return self + + def first(self): + return _Record() + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def add(self, _record): + self.added = True + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash-ok" + + result = manager.update_emoji_usage(emoji) + + assert result is True + assert any("成功记录表情包使用" in m for m in _messages(logger.info_calls)) + + +def test_update_emoji_usage_missing_record(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Result: + def scalars(self): + return self + + def first(self): + return None + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + return _Result() + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash-missing" + + result = manager.update_emoji_usage(emoji) + + assert result is False + assert any("未找到表情包记录" in m for m in _messages(logger.error_calls)) + + +def test_update_emoji_usage_execute_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + class _Select: + def filter_by(self, **_kwargs): + return self + + def _select(_model): + return _Select() + + class _Session: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _statement): + raise RuntimeError("execute failed") + + def _get_db_session(): + return _Session() + + monkeypatch.setattr(emoji_manager_new, "select", _select) + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash-execute" + + result = manager.update_emoji_usage(emoji) + + assert result is False + assert any("记录使用时出错" in m for m in _messages(logger.error_calls)) + + +def test_update_emoji_usage_get_db_session_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + def _get_db_session(): + raise RuntimeError("get_db_session failed") + + monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash-session" + + result = manager.update_emoji_usage(emoji) + + assert result is False + assert any("记录使用时出错" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_get_emoji_for_emotion_empty_list(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [] + + result = await manager.get_emoji_for_emotion("开心") + + assert result is None + assert any("表情包列表为空" in m for m in _messages(logger.warning_calls)) + + +@pytest.mark.asyncio +async def test_get_emoji_for_emotion_no_matches(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + def _calc(_label): + return [] + + monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) + + result = await manager.get_emoji_for_emotion("无匹配") + + assert result is None + assert any("未找到匹配的表情包" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_get_emoji_for_emotion_success_updates_usage(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + emoji1 = emoji_manager_new.MaiEmoji() + emoji1.file_name = "e1.png" + emoji1.emotion = ["开心"] + emoji2 = emoji_manager_new.MaiEmoji() + emoji2.file_name = "e2.png" + emoji2.emotion = ["难过"] + manager.emojis = [emoji1, emoji2] + + def _calc(_label): + return [(emoji1, 0.9), (emoji2, 0.2)] + + monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) + monkeypatch.setattr(emoji_manager_new.random, "choice", lambda items: items[0]) + + called = {"emoji": None} + + def _update(emoji): + called["emoji"] = emoji + return True + + monkeypatch.setattr(manager, "update_emoji_usage", _update) + + result = await manager.get_emoji_for_emotion("开心") + + assert result is emoji1 + assert called["emoji"] is emoji1 + assert any("选中表情包" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_get_emoji_for_emotion_similarity_error_propagates(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + manager = emoji_manager_new.EmojiManager() + manager.emojis = [emoji_manager_new.MaiEmoji()] + + def _calc(_label): + raise RuntimeError("calc failed") + + monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) + + with pytest.raises(RuntimeError): + await manager.get_emoji_for_emotion("异常") + + +@pytest.mark.asyncio +async def test_build_emoji_description_calls_hash_and_sets_description(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + called = {"hash": False, "vlm": False} + + async def _calc(self): + called["hash"] = True + return True + + def _read_bytes(_path): + return b"" + + async def _vlm_response(*_args, **_kwargs): + called["vlm"] = True + return "desc", None + + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "calculate_hash_format", _calc, raising=False) + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) + monkeypatch.setattr( + emoji_manager_new.emoji_manager_vlm, + "generate_response_for_image", + _vlm_response, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = None + emoji._format = "png" + emoji.full_path = Path("/tmp/a.png") + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) + + assert result is True + assert updated.description == "desc" + assert called["hash"] is True + assert called["vlm"] is True + assert any("成功为表情包构建描述" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_description_gif_conversion_error(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + def _read_bytes(_path): + return b"" + + def _gif_to_static(_bytes): + raise RuntimeError("gif fail") + + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) + monkeypatch.setattr(emoji_manager_new.ImageUtils, "gif_2_static_image", staticmethod(_gif_to_static)) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash" + emoji._format = "gif" + emoji.full_path = Path("/tmp/a.gif") + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) + + assert result is False + assert updated.description is None + assert any("转换 GIF 图片时出错" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_description_content_filtration_reject(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + emoji_manager_new.global_config.emoji.content_filtration = True + emoji_manager_new.global_config.emoji.filtration_prompt = "rule" + + def _read_bytes(_path): + return b"" + + call_count = {"n": 0} + + async def _vlm_response(*_args, **_kwargs): + call_count["n"] += 1 + if call_count["n"] == 2: + return "否", None + return "desc", None + + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) + monkeypatch.setattr( + emoji_manager_new.emoji_manager_vlm, + "generate_response_for_image", + _vlm_response, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash" + emoji._format = "png" + emoji.full_path = Path("/tmp/a.png") + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) + + assert result is False + assert updated.description is None + assert any("表情包内容不符合要求" in m for m in _messages(logger.warning_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_description_content_filtration_pass(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + emoji_manager_new.global_config.emoji.content_filtration = True + emoji_manager_new.global_config.emoji.filtration_prompt = "rule" + + def _read_bytes(_path): + return b"" + + async def _vlm_response(prompt, *_args, **_kwargs): + if "rule" in str(prompt): + return "是", None + return "desc", None + + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) + monkeypatch.setattr( + emoji_manager_new.emoji_manager_vlm, + "generate_response_for_image", + _vlm_response, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash" + emoji._format = "png" + emoji.full_path = Path("/tmp/a.png") + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) + + assert result is True + assert updated.description == "desc" + assert any("成功为表情包构建描述" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_description_vlm_exception_propagates(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + + def _read_bytes(_path): + return b"" + + async def _vlm_response(*_args, **_kwargs): + raise RuntimeError("vlm failed") + + monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) + monkeypatch.setattr( + emoji_manager_new.emoji_manager_vlm, + "generate_response_for_image", + _vlm_response, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.emoji_hash = "hash" + emoji._format = "png" + emoji.full_path = Path("/tmp/a.png") + + with pytest.raises(RuntimeError): + await emoji_manager_new.EmojiManager().build_emoji_description(emoji) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_description_missing(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = None + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + assert result is False + assert updated.emotion is None + assert any("表情包描述为空" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_llm_exception_propagates(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + + async def _generate_response_async(*_args, **_kwargs): + raise RuntimeError("llm failed") + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = "desc" + + with pytest.raises(RuntimeError): + await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_empty_result(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + async def _generate_response_async(*_args, **_kwargs): + return " , , ", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = "desc" + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + assert result is True + assert updated.emotion == [] + assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_more_than_five_random_sample(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + async def _generate_response_async(*_args, **_kwargs): + return "a,b,c,d,e,f", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(emoji_manager_new.random, "sample", lambda items, _k: items[:3]) + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = "desc" + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + assert result is True + assert updated.emotion == ["a", "b", "c"] + assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_three_items_random_sample(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + async def _generate_response_async(*_args, **_kwargs): + return "a,b,c", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + monkeypatch.setattr(emoji_manager_new.random, "sample", lambda items, _k: items[:2]) + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = "desc" + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + assert result is True + assert updated.emotion == ["a", "b"] + assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) + + +def test_check_emoji_file_integrity_no_issues(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _DummyPath: + def __init__(self, name): + self._name = name + self._exists = True + + def exists(self): + return self._exists + + @property + def name(self): + return self._name + + manager = emoji_manager_new.EmojiManager() + emoji = emoji_manager_new.MaiEmoji() + emoji.file_name = "ok.png" + emoji.full_path = _DummyPath("ok.png") + emoji.is_deleted = False + emoji.description = "desc" + manager.emojis = [emoji] + manager._emoji_num = 1 + + called = {"count": 0} + + def _delete(_emoji): + called["count"] += 1 + return True + + monkeypatch.setattr(manager, "delete_emoji", _delete) + + manager.check_emoji_file_integrity() + + assert manager.emojis == [emoji] + assert manager._emoji_num == 1 + assert called["count"] == 0 + assert logger.warning_calls == [] + assert any("完整性检查完成" in m for m in _messages(logger.info_calls)) + + +def test_check_emoji_file_integrity_removes_invalid_records(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _DummyPath: + def __init__(self, name, exists=True): + self._name = name + self._exists = exists + + def exists(self): + return self._exists + + @property + def name(self): + return self._name + + manager = emoji_manager_new.EmojiManager() + missing_file = emoji_manager_new.MaiEmoji() + missing_file.file_name = "missing.png" + missing_file.full_path = _DummyPath("missing.png", exists=False) + missing_file.description = "desc" + + deleted_flag = emoji_manager_new.MaiEmoji() + deleted_flag.file_name = "deleted.png" + deleted_flag.full_path = _DummyPath("deleted.png", exists=True) + deleted_flag.is_deleted = True + deleted_flag.description = "desc" + + missing_desc = emoji_manager_new.MaiEmoji() + missing_desc.file_name = "nodesc.png" + missing_desc.full_path = _DummyPath("nodesc.png", exists=True) + missing_desc.description = None + + manager.emojis = [missing_file, deleted_flag, missing_desc] + manager._emoji_num = 3 + + deleted = [] + + def _delete(emoji): + deleted.append(emoji.file_name) + return True + + monkeypatch.setattr(manager, "delete_emoji", _delete) + + manager.check_emoji_file_integrity() + + assert manager.emojis == [] + assert manager._emoji_num == 0 + assert set(deleted) == {"missing.png", "deleted.png", "nodesc.png"} + messages = _messages(logger.warning_calls) + assert any("文件缺失" in m for m in messages) + assert any("标记为已删除" in m for m in messages) + assert any("缺失描述" in m for m in messages) + assert any("成功删除缺失文件的表情包记录" in m for m in _messages(logger.info_calls)) + assert any("删除了 3 条记录" in m for m in _messages(logger.info_calls)) + + +def test_check_emoji_file_integrity_delete_failed(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + class _DummyPath: + def __init__(self, name): + self._name = name + self._exists = False + + def exists(self): + return self._exists + + @property + def name(self): + return self._name + + manager = emoji_manager_new.EmojiManager() + emoji = emoji_manager_new.MaiEmoji() + emoji.file_name = "bad.png" + emoji.full_path = _DummyPath("bad.png") + emoji.description = "desc" + manager.emojis = [emoji] + manager._emoji_num = 1 + + def _delete(_emoji): + return False + + monkeypatch.setattr(manager, "delete_emoji", _delete) + + manager.check_emoji_file_integrity() + + assert manager.emojis == [emoji] + assert manager._emoji_num == 1 + assert any("表情包文件缺失" in m for m in _messages(logger.warning_calls)) + assert any("删除缺失文件的表情包记录失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_build_emoji_emotion_two_items_no_sample(monkeypatch): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + + async def _generate_response_async(*_args, **_kwargs): + return "a, b", None + + monkeypatch.setattr( + emoji_manager_new.emoji_manager_emotion_judge_llm, + "generate_response_async", + _generate_response_async, + ) + + emoji = emoji_manager_new.MaiEmoji() + emoji.description = "desc" + + result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) + + assert result is True + assert updated.emotion == ["a", "b"] + assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_file_missing(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + missing_file = tmp_path / "missing.png" + + result = await manager.register_emoji_by_filename(missing_file) + + assert result is False + assert any("表情包文件不存在" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_create_object_error(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + file_path = tmp_path / "ok.png" + file_path.write_bytes(b"") + + class _BadEmoji: + def __init__(self, *args, **kwargs): + raise RuntimeError("create failed") + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _BadEmoji) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("创建表情包对象时出错" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + file_path = tmp_path / "hash.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + return False + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("计算表情包哈希值和格式失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_duplicate_hash(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + file_path = tmp_path / "dup.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-dup" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + existing = emoji_manager_new.MaiEmoji() + existing.file_name = "exist.png" + monkeypatch.setattr(manager, "get_emoji_by_hash", lambda _h: existing) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("表情包已存在" in m for m in _messages(logger.warning_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_build_description_failed(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + file_path = tmp_path / "desc.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-desc" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return False, _e + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("构建表情包描述失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_build_emotion_failed(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + + file_path = tmp_path / "emo.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-emo" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return True, _e + + async def _build_emo(_e): + return False, _e + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("构建表情包情感标签失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_capacity_replace_failed(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager._emoji_num = 1 + emoji_manager_new.global_config.emoji.max_reg_num = 1 + + file_path = tmp_path / "full.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-full" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return True, _e + + async def _build_emo(_e): + return True, _e + + async def _replace(_e): + return False + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) + monkeypatch.setattr(manager, "replace_an_emoji_by_llm", _replace) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("数量已达上限" in m for m in _messages(logger.warning_calls)) + assert any("替换表情包失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_capacity_replace_success(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager._emoji_num = 1 + emoji_manager_new.global_config.emoji.max_reg_num = 1 + + file_path = tmp_path / "full-ok.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-full-ok" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return True, _e + + async def _build_emo(_e): + return True, _e + + async def _replace(_e): + return True + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) + monkeypatch.setattr(manager, "replace_an_emoji_by_llm", _replace) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is True + assert any("数量已达上限" in m for m in _messages(logger.warning_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_register_db_failed(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager._emoji_num = 0 + emoji_manager_new.global_config.emoji.max_reg_num = 10 + + file_path = tmp_path / "db-fail.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-db-fail" + self.full_path = file_path + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return True, _e + + async def _build_emo(_e): + return True, _e + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) + monkeypatch.setattr(manager, "register_emoji_to_db", lambda _e: False) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is False + assert any("注册表情包到数据库失败" in m for m in _messages(logger.error_calls)) + + +@pytest.mark.asyncio +async def test_register_emoji_by_filename_register_db_success(monkeypatch, tmp_path): + emoji_manager_new = import_emoji_manager_new(monkeypatch) + logger = emoji_manager_new.logger + manager = emoji_manager_new.EmojiManager() + manager._emoji_num = 0 + emoji_manager_new.global_config.emoji.max_reg_num = 10 + + file_path = tmp_path / "db-ok.png" + file_path.write_bytes(b"") + + class _Emoji(emoji_manager_new.MaiEmoji): + async def calculate_hash_format(self): + self.emoji_hash = "hash-db-ok" + self.full_path = file_path + self.file_name = "db-ok.png" + return True + + monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) + + async def _build_desc(_e): + return True, _e + + async def _build_emo(_e): + return True, _e + + monkeypatch.setattr(manager, "build_emoji_description", _build_desc) + monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) + monkeypatch.setattr(manager, "register_emoji_to_db", lambda _e: True) + + result = await manager.register_emoji_by_filename(file_path) + + assert result is True + assert manager._emoji_num == 1 + assert len(manager.emojis) == 1 + assert any("成功注册新表情包" in m for m in _messages(logger.info_calls)) diff --git a/requirements.txt b/requirements.txt index 56c944e8..4cc63bc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -aiohttp>=3.12.14 aiohttp-cors>=0.8.1 +aiohttp>=3.12.14 colorama>=0.4.6 faiss-cpu>=1.11.0 fastapi>=0.116.0 google-genai>=1.39.1 jieba>=0.42.1 json-repair>=0.47.6 -maim-message +maim-message>=0.6.2 matplotlib>=3.10.3 numpy>=2.2.6 openai>=1.95.0 @@ -17,6 +17,7 @@ pyarrow>=20.0.0 pydantic>=2.11.7 pypinyin>=0.54.0 python-dotenv>=1.1.1 +python-levenshtein python-multipart>=0.0.20 quick-algo>=0.1.3 rich>=14.0.0 @@ -26,6 +27,4 @@ structlog>=25.4.0 toml>=0.10.2 tomlkit>=0.13.3 urllib3>=2.5.0 -uvicorn>=0.35.0 -msgpack -zstandard \ No newline at end of file +uvicorn>=0.35.0 \ No newline at end of file diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index e0c4d103..c7cb1dd1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,1154 +1,558 @@ -import asyncio -import base64 -import hashlib -import os -import random -import time -import traceback -import io -import re -import binascii - -from typing import Optional, Tuple, List, Any -from PIL import Image +from datetime import datetime +from pathlib import Path from rich.traceback import install +from sqlmodel import select +from typing import Optional, Tuple, List + +import asyncio +import heapq +import Levenshtein +import random +import re -from src.common.database.database_model import Emoji, EmojiDescriptionCache -from src.common.database.database import db as peewee_db from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.chat.utils.utils_image import image_path_to_base64, get_image_manager +from src.common.data_models.image_data_model import MaiEmoji +from src.common.database.database_model import Images, ImageType +from src.common.database.database import get_db_session +from src.common.utils.utils_image import ImageUtils +from src.prompt.prompt_manager import prompt_manager +from src.config.config import global_config +from src.config.config import model_config from src.llm_models.utils_model import LLMRequest -install(extra_lines=3) - logger = get_logger("emoji") -BASE_DIR = os.path.join("data") -EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录 -EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 +install(extra_lines=3) + +PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve() +DATA_DIR = PROJECT_ROOT / "data" +EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录 +EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录 MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 -""" -还没经过测试,有些地方数据库和内存数据同步可能不完全 -""" +def _ensure_directories(): + """确保表情包相关目录存在""" + EMOJI_DIR.mkdir(parents=True, exist_ok=True) + EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True) -class MaiEmoji: - """定义一个表情包""" - - def __init__(self, full_path: str): - if not full_path: - raise ValueError("full_path cannot be empty") - self.full_path = full_path # 文件的完整路径 (包括文件名) - self.path = os.path.dirname(full_path) # 文件所在的目录路径 - self.filename = os.path.basename(full_path) # 文件名 - self.embedding = [] - self.hash = "" # 初始为空,在创建实例时会计算 - self.description = "" - self.emotion: List[str] = [] - self.usage_count = 0 - self.last_used_time = time.time() - self.register_time = time.time() - self.is_deleted = False # 标记是否已被删除 - self.format = "" - - async def initialize_hash_format(self) -> Optional[bool]: - """从文件创建表情包实例, 计算哈希值和格式""" - try: - # 使用 full_path 检查文件是否存在 - if not os.path.exists(self.full_path): - logger.error(f"[初始化错误] 表情包文件不存在: {self.full_path}") - self.is_deleted = True - return None - - # 使用 full_path 读取文件 - logger.debug(f"[初始化] 正在读取文件: {self.full_path}") - image_base64 = image_path_to_base64(self.full_path) - if image_base64 is None: - logger.error(f"[初始化错误] 无法读取或转换Base64: {self.full_path}") - self.is_deleted = True - return None - logger.debug(f"[初始化] 文件读取成功 (Base64预览: {image_base64[:50]}...)") - - # 计算哈希值 - logger.debug(f"[初始化] 正在解码Base64并计算哈希: {self.filename}") - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - self.hash = hashlib.md5(image_bytes).hexdigest() - logger.debug(f"[初始化] 哈希计算成功: {self.hash}") - - # 获取图片格式 - logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") - try: - with Image.open(io.BytesIO(image_bytes)) as img: - self.format = img.format.lower() # type: ignore - logger.debug(f"[初始化] 格式获取成功: {self.format}") - except Exception as pil_error: - logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") - logger.error(traceback.format_exc()) - self.is_deleted = True - return None - - # 如果所有步骤成功,返回 True - return True - - except FileNotFoundError: - logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") - self.is_deleted = True - return None - except (binascii.Error, ValueError) as b64_error: - logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") - self.is_deleted = True - return None - except Exception as e: - logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}") - logger.error(traceback.format_exc()) - self.is_deleted = True - return None - - async def register_to_db(self) -> bool: - """ - 注册表情包 - 将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下 - 并修改对应的实例属性,然后将表情包信息保存到数据库中 - """ - try: - # 确保目标目录存在 - - # 源路径是当前实例的完整路径 self.full_path - source_full_path = self.full_path - # 目标完整路径 - destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename) - - # 检查源文件是否存在 - if not os.path.exists(source_full_path): - logger.error(f"[错误] 源文件不存在: {source_full_path}") - return False - - # --- 文件移动 --- - try: - # 如果目标文件已存在,先删除 (确保移动成功) - if os.path.exists(destination_full_path): - os.remove(destination_full_path) - - os.rename(source_full_path, destination_full_path) - logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") - # 更新实例的路径属性为新路径 - self.full_path = destination_full_path - self.path = EMOJI_REGISTERED_DIR - # self.filename 保持不变 - except Exception as move_error: - logger.error(f"[错误] 移动文件失败: {str(move_error)}") - # 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败 - return False - - # --- 数据库操作 --- - try: - # 准备数据库记录 for emoji collection - emotion_str = ",".join(self.emotion) if self.emotion else "" - - Emoji.create( - emoji_hash=self.hash, - full_path=self.full_path, - format=self.format, - description=self.description, - emotion=emotion_str, # Store as comma-separated string - query_count=0, # Default value - is_registered=True, - is_banned=False, # Default value - record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time - register_time=self.register_time, - usage_count=self.usage_count, - last_used_time=self.last_used_time, - ) - - logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") - - return True - - except Exception as db_error: - logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") - return False - - except Exception as e: - logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}") - logger.error(traceback.format_exc()) - return False - - async def delete(self) -> bool: - """删除表情包 - - 删除表情包的文件和数据库记录 - - 返回: - bool: 是否成功删除 - """ - try: - # 1. 删除文件 - file_to_delete = self.full_path - if os.path.exists(file_to_delete): - try: - os.remove(file_to_delete) - logger.debug(f"[删除] 文件: {file_to_delete}") - except Exception as e: - logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}") - # 文件删除失败,但仍然尝试删除数据库记录 - - # 2. 删除数据库记录 - try: - will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) - result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. - except Emoji.DoesNotExist: # type: ignore - logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted - - if result > 0: - logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") - # 3. 标记对象已被删除 - self.is_deleted = True - return True - else: - # 如果数据库记录删除失败,但文件可能已删除,记录一个警告 - if not os.path.exists(file_to_delete): - logger.warning( - f"[警告] 表情包文件 {file_to_delete} 已删除,但数据库记录删除失败 (Hash: {self.hash})" - ) - else: - logger.error(f"[错误] 删除表情包数据库记录失败: {self.hash}") - return False - - except Exception as e: - logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}") - return False - - -def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: - """将表情包对象列表转换为可读的字符串列表 - - 参数: - emoji_objects: MaiEmoji对象列表 - - 返回: - list[str]: 可读的表情包信息字符串列表 - """ - emoji_info_list = [] - for i, emoji in enumerate(emoji_objects): - # 转换时间戳为可读时间 - time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(emoji.register_time)) - # 构建每个表情包的信息字符串 - emoji_info = f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.usage_count}\n添加时间: {time_str}\n" - emoji_info_list.append(emoji_info) - return emoji_info_list - - -def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: - emoji_objects = [] - load_errors = 0 - # data is now an iterable of Peewee Emoji model instances - emoji_data_list = list(data) - - for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance - full_path = emoji_data.full_path - if not full_path: - logger.warning( - f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" - ) - load_errors += 1 - continue - - try: - emoji = MaiEmoji(full_path=full_path) - - emoji.hash = emoji_data.emoji_hash - if not emoji.hash: - logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") - load_errors += 1 - continue - - emoji.description = emoji_data.description - # Deserialize emotion string from DB to list - emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else [] - emoji.usage_count = emoji_data.usage_count - - db_last_used_time = emoji_data.last_used_time - db_register_time = emoji_data.register_time - - # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time - emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time - # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) - emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time - - emoji.format = emoji_data.format - - emoji_objects.append(emoji) - - except ValueError as ve: - logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") - load_errors += 1 - except Exception as e: - logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}") - load_errors += 1 - return emoji_objects, load_errors - - -def _ensure_emoji_dir() -> None: - """确保表情存储目录存在""" - os.makedirs(EMOJI_DIR, exist_ok=True) - os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) - - -async def clear_temp_emoji() -> None: - """清理临时表情包 - 清理/data/emoji、/data/image和/data/images目录下的所有文件 - 当目录中文件数超过100时,会全部删除 - """ - - logger.info("[清理] 开始清理缓存...") - - for need_clear in ( - os.path.join(BASE_DIR, "emoji"), - os.path.join(BASE_DIR, "image"), - os.path.join(BASE_DIR, "images"), - ): - if os.path.exists(need_clear): - files = os.listdir(need_clear) - # 如果文件数超过100就全部删除 - if len(files) > 100: - for filename in files: - file_path = os.path.join(need_clear, filename) - if os.path.isfile(file_path): - os.remove(file_path) - logger.debug(f"[清理] 删除: {filename}") - - -async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int: - """清理指定目录中未被 emoji_objects 追踪的表情包文件""" - if not os.path.exists(emoji_dir): - logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") - return removed_count - - cleaned_count = 0 - try: - # 获取内存中所有有效表情包的完整路径集合 - tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} - - # 遍历指定目录中的所有文件 - for file_name in os.listdir(emoji_dir): - file_full_path = os.path.join(emoji_dir, file_name) - - # 确保处理的是文件而不是子目录 - if not os.path.isfile(file_full_path): - continue - - # 如果文件不在被追踪的集合中,则删除 - if file_full_path not in tracked_full_paths: - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}") - cleaned_count += 1 - except Exception as e: - logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}") - - if cleaned_count > 0: - logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。") - else: - logger.debug(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") - - except Exception as e: - logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") - - return removed_count + cleaned_count +# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法 +emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see") +emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji") class EmojiManager: - _instance = None + def __init__(self): + _ensure_directories() - def __new__(cls) -> "EmojiManager": - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self) -> None: - if self._initialized: - return # 如果已经初始化过,直接返回 - - self._scan_task = None - - self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see") - self.llm_emotion_judge = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="emoji" - ) - - self.emoji_num = 0 - self.emoji_num_max = global_config.emoji.max_reg_num - self.emoji_num_max_reach_deletion = global_config.emoji.do_replace - self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 + self._emoji_num: int = 0 + self.emojis: list[MaiEmoji] = [] logger.info("启动表情包管理器") - def initialize(self) -> None: - """初始化数据库连接和表情目录""" - peewee_db.connect(reuse_if_open=True) - if peewee_db.is_closed(): - raise RuntimeError("数据库连接失败") - _ensure_emoji_dir() - Emoji.create_table(safe=True) # Ensures table exists - EmojiDescriptionCache.create_table(safe=True) - self._initialized = True + def load_emojis_from_db(self) -> None: + """ + 从数据库加载已注册的表情包 - def _ensure_db(self) -> None: - """确保数据库已初始化""" - if not self._initialized: - self.initialize() - if not self._initialized: - raise RuntimeError("EmojiManager not initialized") - - def record_usage(self, emoji_hash: str) -> None: - """记录表情使用次数""" + Raises: + Exception: 如果加载过程中发生不可恢复错误,则抛出异常 + """ + logger.debug("[数据库] 开始加载所有表情包记录...") try: - emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash) - emoji_update.usage_count += 1 - emoji_update.last_used_time = time.time() # Update last used time - emoji_update.save() # Persist changes to DB - except Emoji.DoesNotExist: # type: ignore - logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") + with get_db_session() as session: + statement = select(Images) + results = session.execute(statement).scalars().all() + for record in results: + try: + emoji = MaiEmoji.from_db_instance(record) + self.emojis.append(emoji) + except Exception as e: + logger.error( + f"[数据库] 加载表情包记录时出错: {e}\n记录ID: {record.id}, 路径: {record.full_path}" + ) + self._emoji_num = len(self.emojis) + logger.info(f"[数据库] 成功加载 {self._emoji_num} 个已注册表情包") except Exception as e: - logger.error(f"记录表情使用失败: {str(e)}") + logger.critical(f"[数据库] 加载表情包记录时发生不可恢复错误: {e}") + self.emojis = [] + self._emoji_num = 0 + raise e + + def register_emoji_to_db(self, emoji: MaiEmoji) -> bool: + # sourcery skip: extract-method + """ + 将表情包注册到数据库中 - async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]: - """根据文本内容获取相关表情包 Args: - text_emotion: 输入的情感描述文本 + emoji (MaiEmoji): 需要注册的表情包对象 Returns: - Optional[Tuple[str, str]]: (表情包完整文件路径, 表情包描述),如果没有找到则返回None + return (bool): 注册是否成功 + """ + if not emoji or not isinstance(emoji, MaiEmoji): + logger.error("[注册表情包] 无效的表情包对象") + return False + if not emoji.full_path.exists(): + logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}") + return False + + # 将表情包移动到已注册目录 + target_path = EMOJI_REGISTERED_DIR / emoji.file_name + try: + emoji.full_path.replace(target_path) + emoji.full_path = target_path + except Exception as e: + logger.error(f"[注册表情包] 移动表情包文件时出错: {e}") + return False + + # 注册到数据库 + try: + with get_db_session() as session: + image_record = emoji.to_db_instance() + image_record.is_registered = True + image_record.is_banned = False + image_record.register_time = datetime.now() + session.add(image_record) + session.flush() + record_id = image_record.id + logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}") + except Exception as e: + logger.error(f"[注册表情包] 注册到数据库时出错: {e}") + return False + return True + + def delete_emoji(self, emoji: MaiEmoji) -> bool: + """ + 删除表情包的文件和数据库记录 + + Args: + emoji (MaiEmoji): 需要删除的表情包对象 + Returns: + return (bool): 删除是否成功 + """ + # 删除文件 + file_to_delete = emoji.full_path + try: + file_to_delete.unlink() + logger.info(f"[删除表情包] 成功删除表情包文件: {emoji.file_name}") + except FileNotFoundError: + logger.warning(f"[删除表情包] 表情包文件 {emoji.file_name} 不存在,跳过文件删除") + except Exception as e: + logger.error(f"[删除表情包] 删除表情包文件时出错: {e}") + return False + + # 删除数据库记录 + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI) + if image_record := session.execute(statement).scalars().first(): + session.delete(image_record) + logger.info(f"[删除表情包] 成功删除数据库中的表情包记录: {emoji.emoji_hash}") + else: + logger.warning(f"[删除表情包] 数据库中未找到表情包记录: {emoji.emoji_hash}") + except Exception as e: + logger.error(f"[删除表情包] 删除数据库记录时出错: {e}") + # 如果数据库记录删除失败,但文件可能已删除,记录一个警告 + if file_to_delete.exists(): + logger.warning(f"[删除表情包] 数据库记录删除失败,但文件仍存在: {emoji.file_name}") + return False + + return True + + def update_emoji_usage(self, emoji: MaiEmoji) -> bool: + """ + 更新表情包的使用情况,更新查询次数和上次使用时间 + + Args: + emoji (MaiEmoji): 使用的表情包对象 + Returns: + return (bool): 更新是否成功 """ try: - self._ensure_db() - _time_start = time.time() - - # 获取所有表情包 (从内存缓存中获取) - all_emojis = self.emoji_objects - - if not all_emojis: - logger.warning("内存中没有任何表情包对象") - return None - - # 计算每个表情包与输入文本的最大情感相似度 - emoji_similarities = [] - for emoji in all_emojis: - # 跳过已标记为删除的对象 - if emoji.is_deleted: - continue - - emotions = emoji.emotion - if not emotions: - continue - - # 计算与每个emotion标签的相似度,取最大值 - max_similarity = 0 - best_matching_emotion = "" - for emotion in emotions: - # 使用编辑距离计算相似度 - distance = self._levenshtein_distance(text_emotion, emotion) - max_len = max(len(text_emotion), len(emotion)) - similarity = 1 - (distance / max_len if max_len > 0 else 0) - if similarity > max_similarity: - max_similarity = similarity - best_matching_emotion = emotion - - if best_matching_emotion: - emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) - - # 按相似度降序排序 - emoji_similarities.sort(key=lambda x: x[1], reverse=True) - - # 获取前10个最相似的表情包 - top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities - - if not top_emojis: - logger.warning("未找到匹配的表情包") - return None - - # 从前几个中随机选择一个 - selected_emoji, similarity, matched_emotion = random.choice(top_emojis) - - # 更新使用次数 - self.record_usage(selected_emoji.hash) - - _time_end = time.time() - - logger.info( - f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" - ) - # 返回完整文件路径和描述 - return selected_emoji.full_path, f"[ {selected_emoji.description} ]", matched_emotion - + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=emoji.emoji_hash, image_type=ImageType.EMOJI) + if image_record := session.execute(statement).scalars().first(): + image_record.query_count += 1 + image_record.last_used_time = datetime.now() + session.add(image_record) + logger.info(f"[记录表情包使用] 成功记录表情包使用: {emoji.emoji_hash}") + else: + logger.error(f"[记录表情包使用] 未找到表情包记录: {emoji.emoji_hash}") + return False except Exception as e: - logger.error(f"[错误] 获取表情包失败: {str(e)}") + logger.error(f"[记录表情包使用] 记录使用时出错: {e}") + return False + return True + + def get_emoji_by_hash(self, emoji_hash: str) -> Optional[MaiEmoji]: + """ + 根据哈希值获取表情包对象 + + Args: + emoji_hash (str): 表情包的哈希值 + Returns: + return (Optional[MaiEmoji]): 返回表情包对象,如果未找到则返回 None + """ + for emoji in self.emojis: + if emoji.emoji_hash == emoji_hash and not emoji.is_deleted: + return emoji + logger.info(f"[获取表情包] 未找到哈希值为 {emoji_hash} 的表情包") + return None + + def get_emoji_by_hash_from_db(self, emoji_hash: str) -> Optional[MaiEmoji]: + """ + 根据哈希值从数据库获取表情包对象 + + Args: + emoji_hash (str): 表情包的哈希值 + Returns: + return (Optional[MaiEmoji]): 返回表情包对象,如果未找到则返回 None + """ + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI) + if image_record := session.execute(statement).scalars().first(): + return MaiEmoji.from_db_instance(image_record) + logger.info(f"[数据库] 未找到哈希值为 {emoji_hash} 的表情包记录") + return None + except Exception as e: + logger.error(f"[数据库] 获取表情包时出错: {e}") return None - def _levenshtein_distance(self, s1: str, s2: str) -> int: - # sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison - """计算两个字符串的编辑距离 + async def get_emoji_for_emotion(self, emotion_label: str) -> Optional[MaiEmoji]: + """ + 根据文本情感标签获取合适的表情包 Args: - s1: 第一个字符串 - s2: 第二个字符串 - + text_emotion (str): 文本的情感标签 Returns: - int: 编辑距离 + return (Optional[MaiEmoji]): 返回表情包对象,如果未找到则返回 None """ - if len(s1) < len(s2): - return self._levenshtein_distance(s2, s1) + if not self.emojis: + logger.warning("[获取表情包] 表情包列表为空") + return None - if len(s2) == 0: - return len(s1) + emoji_similarities = await asyncio.to_thread(self._calculate_emotion_similarity_list, emotion_label) + if not emoji_similarities: + logger.info("[获取表情包] 未找到匹配的表情包") + return None - previous_row = range(len(s2) + 1) - for i, c1 in enumerate(s1): - current_row = [i + 1] - for j, c2 in enumerate(s2): - insertions = previous_row[j + 1] + 1 - deletions = current_row[j] + 1 - substitutions = previous_row[j] + (c1 != c2) - current_row.append(min(insertions, deletions, substitutions)) - previous_row = current_row + # 获取前10个相似度最高的表情包 + top_emojis = heapq.nlargest(10, emoji_similarities, key=lambda x: x[1]) + selected_emoji, similarity = random.choice(top_emojis) + self.update_emoji_usage(selected_emoji) + logger.info( + f"[获取表情包] 为[{emotion_label}]选中表情包: {selected_emoji.file_name}({selected_emoji.emotion}),相似度: {similarity:.4f}" + ) + return selected_emoji - return previous_row[-1] - - async def check_emoji_file_integrity(self) -> None: - """检查表情包文件完整性 - 遍历self.emoji_objects中的所有对象,检查文件是否存在 - 如果文件已被删除,则执行对象的删除方法并从列表中移除 + async def replace_an_emoji_by_llm(self, new_emoji: MaiEmoji) -> bool: """ + 使用 LLM 决策替换一个表情包 + + Args: + new_emoji (MaiEmoji): 新添加的表情包对象 + Returns: + return (bool): 是否成功替换了一个表情包 + """ + # sourcery skip: use-getitem-for-re-match-groups + probabilities = [1 / (emoji.query_count + 1) for emoji in self.emojis] + selected_emojis = random.choices( + self.emojis, weights=probabilities, k=min(MAX_EMOJI_FOR_PROMPT, len(self.emojis)) + ) + emoji_info_list: list[str] = [] + for i, emoji in enumerate(selected_emojis): + time_str = emoji.register_time.strftime("%Y-%m-%d %H:%M:%S") if emoji.register_time else "未知时间" + emoji_info = ( + f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.query_count}\n添加时间: {time_str}\n" + ) + emoji_info_list.append(emoji_info) + + emoji_replace_prompt_template = prompt_manager.get_prompt("emoji_replace") + emoji_replace_prompt_template.add_context("nickname", global_config.bot.nickname) + emoji_replace_prompt_template.add_context("emoji_num", str(self._emoji_num)) + emoji_replace_prompt_template.add_context("emoji_num_max", str(global_config.emoji.max_reg_num)) + emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list)) + emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template) + + decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async( + emoji_replace_prompt, temperature=0.8, max_tokens=600 + ) + logger.info(f"[决策] 结果: {decision}") + + # 解析决策结果 + if "不删除" in decision: + logger.info("[决策] 不删除任何表情包") + return False try: - # if not self.emoji_objects: - # logger.warning("[检查] emoji_objects为空,跳过完整性检查") - # return - - total_count = len(self.emoji_objects) - self.emoji_num = total_count - removed_count = 0 - # 使用列表复制进行遍历,因为我们会在遍历过程中修改列表 - objects_to_remove = [] - for emoji in self.emoji_objects: - try: - # 跳过已经标记为删除的,避免重复处理 - if emoji.is_deleted: - objects_to_remove.append(emoji) # 收集起来一次性移除 - continue - - # 检查文件是否存在 - if not os.path.exists(emoji.full_path): - logger.warning(f"[检查] 表情包文件丢失: {emoji.full_path}") - # 执行表情包对象的删除方法 - await emoji.delete() # delete 方法现在会标记 is_deleted - objects_to_remove.append(emoji) # 标记删除后,也收集起来移除 - # 更新计数 - self.emoji_num -= 1 - removed_count += 1 - continue - - # 检查描述是否为空 (如果为空也视为无效) - if not emoji.description: - logger.warning(f"[检查] 表情包描述为空,视为无效: {emoji.filename}") - await emoji.delete() - objects_to_remove.append(emoji) - self.emoji_num -= 1 - removed_count += 1 - continue - - except Exception as item_error: - logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}") - # 即使出错,也尝试继续检查下一个 - continue - - # 从 self.emoji_objects 中移除标记的对象 - if objects_to_remove: - self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove] - - # 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件 - removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count) - - # 输出清理结果 - if removed_count > 0: - logger.info(f"[清理] 已清理 {removed_count} 个失效/文件丢失的表情包记录") - logger.info(f"[统计] 清理前记录数: {total_count} | 清理后有效记录数: {len(self.emoji_objects)}") - else: - logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好") - + match = re.search(r"删除编号(\d+)", decision) except Exception as e: - logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") - logger.error(traceback.format_exc()) + logger.error(f"[决策] 解析决策结果时出错: {e}") + return False + if match: + emoji_index = int(match.group(1)) - 1 # 转换为0-based索引 + # 检查索引是否有效 + if 0 <= emoji_index < len(selected_emojis): + emoji_to_delete = selected_emojis[emoji_index] + logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") + if self.delete_emoji(emoji_to_delete): + self.emojis.remove(emoji_to_delete) + if self.register_emoji_to_db(new_emoji): + self.emojis.append(new_emoji) + logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}") + return True + else: + logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}") + else: + logger.error("[错误] 删除表情包失败,无法完成替换") + else: + logger.error(f"[决策] 无效的表情包编号: {emoji_index + 1}") + else: + logger.error("[决策] 未能解析删除编号") + return False - async def start_periodic_check_register(self) -> None: - """定期检查表情包完整性和数量""" - await self.get_all_emoji_from_db() - while True: - # logger.info("[扫描] 开始检查表情包完整性...") - await self.check_emoji_file_integrity() - await clear_temp_emoji() - logger.info("[扫描] 开始扫描新表情包...") + async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]: + """ + 构建表情包描述 - # 检查表情包目录是否存在 - if not os.path.exists(EMOJI_DIR): - logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}") - os.makedirs(EMOJI_DIR, exist_ok=True) - logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}") - await asyncio.sleep(global_config.emoji.check_interval * 60) - continue + Args: + target_emoji (MaiEmoji): 目标表情包对象 + Returns: + return (Tuple[bool, MaiEmoji]): 返回是否成功构建描述,及表情包对象 + """ + if not target_emoji.emoji_hash: + # Should not happen, but just in case + await target_emoji.calculate_hash_format() - # 检查目录是否为空 - files = os.listdir(EMOJI_DIR) - if not files: - logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}") - await asyncio.sleep(global_config.emoji.check_interval * 60) - continue + # 调用VLM生成描述 + image_format = target_emoji._format + image_bytes = target_emoji.read_image_bytes(target_emoji.full_path) - # 检查是否需要处理表情包(数量超过最大值或不足) - if global_config.emoji.steal_emoji and ( - (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) - or (self.emoji_num < self.emoji_num_max) - ): + if image_format == "gif": + try: + image_bytes = await asyncio.to_thread(ImageUtils.gif_2_static_image, image_bytes) + except Exception as e: + logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}") + return False, target_emoji + prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答" + image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) + description, _ = await emoji_manager_vlm.generate_response_for_image( + prompt, image_base64, "jpg", temperature=0.5 + ) + else: + prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答" + image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) + description, _ = await emoji_manager_vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.5 + ) + + # 表情包审查 + if global_config.emoji.content_filtration: + filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration") + filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt) + filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template) + llm_response, _ = await emoji_manager_vlm.generate_response_for_image( + filtration_prompt, image_base64, image_format, temperature=0.3 + ) + if "否" in llm_response: + logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}") + return False, target_emoji + target_emoji.description = description + logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}") + return True, target_emoji + + async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]: + """ + 构建表情包情感,使用场景标签 + + Args: + target_emoji (MaiEmoji): 目标表情包对象 + Returns: + return (Tuple[bool, MaiEmoji]): 返回是否成功构建情感标签,及表情包对象 + """ + if not target_emoji.description: + logger.error("[构建情感标签] 表情包描述为空,无法构建情感标签") + return False, target_emoji + + # 获取Prompt + emotion_prompt_template = prompt_manager.get_prompt("emoji_content_analysis") + emotion_prompt_template.add_context("description", target_emoji.description) + emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template) + # 调用LLM生成情感标签 + emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async( + emotion_prompt, temperature=0.7, max_tokens=200 + ) + + # 解析情感标签结果 + emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] + + # 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个 + if len(emotions) > 5: + emotions = random.sample(emotions, 3) + elif len(emotions) > 2: + emotions = random.sample(emotions, 2) + + logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}") + target_emoji.emotion = emotions + return True, target_emoji + + def check_emoji_file_integrity(self) -> None: + """ + 检查表情包完整性,删除文件缺失的表情包记录 + """ + logger.info("[完整性检查] 开始检查表情包文件完整性...") + to_delete_emojis: list[MaiEmoji] = [] + removal_count = 0 + for emoji in self.emojis: + if not emoji.full_path.exists(): + logger.warning(f"[完整性检查] 表情包文件缺失,准备删除记录: {emoji.file_name}") + to_delete_emojis.append(emoji) + if emoji.is_deleted: + logger.warning(f"[完整性检查] 表情包记录标记为已删除,准备删除记录: {emoji.file_name}") + to_delete_emojis.append(emoji) + if not emoji.description: + logger.warning(f"[完整性检查] 表情包记录缺失描述,准备删除记录: {emoji.file_name}") + to_delete_emojis.append(emoji) + + for emoji in to_delete_emojis: + if self.delete_emoji(emoji): + self.emojis.remove(emoji) + self._emoji_num -= 1 + removal_count += 1 + logger.info(f"[完整性检查] 成功删除缺失文件的表情包记录: {emoji.file_name}") + else: + logger.error(f"[完整性检查] 删除缺失文件的表情包记录失败: {emoji.file_name}") + + logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录") + + def remove_untracked_emoji_files(self) -> None: + """ + 删除未被数据库记录跟踪的表情包文件 + """ + logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...") + tracked_files = {emoji.full_path.name for emoji in self.emojis} + all_files = set(EMOJI_REGISTERED_DIR.glob("*")) + removal_count = 0 + + for file_path in all_files: + if file_path.name not in tracked_files: try: - # 获取目录下所有图片文件 - files_to_process = [ - f - for f in files - if os.path.isfile(os.path.join(EMOJI_DIR, f)) - and f.lower().endswith((".jpg", ".jpeg", ".png", ".gif")) - ] - - # 处理每个符合条件的文件 - for filename in files_to_process: - # 尝试注册表情包 - success = await self.register_emoji_by_filename(filename) - if success: - # 注册成功则跳出循环 - break - - # 注册失败则删除对应文件 - file_path = os.path.join(EMOJI_DIR, filename) - os.remove(file_path) - logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") + file_path.unlink() + removal_count += 1 + logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}") except Exception as e: - logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") + logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}") + logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件") + + async def periodic_emoji_maintenance(self) -> None: + """ + 定期执行表情包维护任务,包括完整性检查和未跟踪文件清理 + """ + while True: + EMOJI_DIR.mkdir(parents=True, exist_ok=True) + EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True) + try: + self.check_emoji_file_integrity() + self.remove_untracked_emoji_files() + except Exception as e: + logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}") + + if global_config.emoji.steal_emoji and ( + self._emoji_num < global_config.emoji.max_reg_num + or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace) + ): + logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...") + for emoji_file in EMOJI_DIR.iterdir(): + if not emoji_file.is_file(): + continue + if await self.register_emoji_by_filename(emoji_file): + break # 每次只注册一个表情包 + try: + emoji_file.unlink() + logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}") + except Exception as e: + logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}") await asyncio.sleep(global_config.emoji.check_interval * 60) - async def get_all_emoji_from_db(self) -> None: - """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" - try: - self._ensure_db() - logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...") - - emoji_peewee_instances = Emoji.select() - emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) - - # 更新内存中的列表和数量 - self.emoji_objects = emoji_objects - self.emoji_num = len(emoji_objects) - - logger.info(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。") - if load_errors > 0: - logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。") - - except Exception as e: - logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}") - self.emoji_objects = [] # 加载失败则清空列表 - self.emoji_num = 0 - - async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: - """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) - - 参数: - emoji_hash: 可选,如果提供则只返回指定哈希值的表情包 - - 返回: - list[MaiEmoji]: 表情包对象列表 + async def register_emoji_by_filename(self, filename: Path | str) -> bool: """ - try: - self._ensure_db() - - if emoji_hash: - query = Emoji.select().where(Emoji.emoji_hash == emoji_hash) - else: - logger.warning( - "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" - ) - query = Emoji.select() - - emoji_peewee_instances = query - emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) - - if load_errors > 0: - logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") - - return emoji_objects - - except Exception as e: - logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") - return [] - - async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: - # sourcery skip: use-next - """从内存中的 emoji_objects 列表获取表情包 - - 参数: - emoji_hash: 要查找的表情包哈希值 - 返回: - MaiEmoji 或 None: 如果找到则返回 MaiEmoji 对象,否则返回 None - """ - for emoji in self.emoji_objects: - # 确保对象未被标记为删除且哈希值匹配 - if not emoji.is_deleted and emoji.hash == emoji_hash: - return emoji - return None # 如果循环结束还没找到,则返回 None - - async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]: - """根据哈希值获取已注册表情包的情感标签列表 + 根据指定的表情包图片,分析并注册到数据库 Args: - emoji_hash: 表情包的哈希值 + filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正) Returns: - Optional[List[str]]: 情感标签列表,如果未找到则返回None + return (bool): 注册是否成功 """ + file_full_path = Path(filename).absolute().resolve() + if not file_full_path.exists(): + logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}") + return False try: - # 先从内存中查找 - emoji = await self.get_emoji_from_manager(emoji_hash) - if emoji and emoji.emotion: - logger.info(f"[缓存命中] 从内存获取表情包情感标签: {emoji.emotion}...") - return emoji.emotion - - # 如果内存中没有,从数据库查找 - self._ensure_db() - try: - emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) - if emoji_record and emoji_record.emotion: - logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") - return emoji_record.emotion.replace(",", ",").split(",") - except Exception as e: - logger.error(f"从数据库查询表情包情感标签时出错: {e}") - - return None - + target_emoji = MaiEmoji(full_path=file_full_path) except Exception as e: - logger.error(f"获取表情包情感标签失败 (Hash: {emoji_hash}): {str(e)}") - return None + logger.error(f"[注册表情包] 创建表情包对象时出错: {e}") + return False - async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: - """根据哈希值获取已注册表情包的描述 + # 1. 计算哈希值和格式 + calc_success = await target_emoji.calculate_hash_format() + if not calc_success: + logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}") + return False + file_full_path = target_emoji.full_path # 更新为可能修正后的路径 + # 2. 检查是否已经存在过 + if existing_emoji := self.get_emoji_by_hash(target_emoji.emoji_hash): + logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}") + return False + # 3. 构建描述 + desc_success, target_emoji = await self.build_emoji_description(target_emoji) + if not desc_success: + logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}") + return False + # 4. 构建情感标签 + emo_success, target_emoji = await self.build_emoji_emotion(target_emoji) + if not emo_success: + logger.error(f"[注册表情包] 构建表情包情感标签失败: {file_full_path}") + return False - Args: - emoji_hash: 表情包的哈希值 - - Returns: - Optional[str]: 表情包描述,如果未找到则返回None - """ - try: - # 先从内存中查找 - emoji = await self.get_emoji_from_manager(emoji_hash) - if emoji and emoji.description: - logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...") - return emoji.description - - # 如果内存中没有,从数据库查找 - self._ensure_db() - try: - emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) - if emoji_record and emoji_record.description: - logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") - return emoji_record.description - except Exception as e: - logger.error(f"从数据库查询表情包描述时出错: {e}") - - return None - - except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") - return None - - async def delete_emoji(self, emoji_hash: str) -> bool: - """根据哈希值删除表情包 - - Args: - emoji_hash: 表情包的哈希值 - - Returns: - bool: 是否成功删除 - """ - try: - self._ensure_db() - - # 从emoji_objects中查找表情包对象 - emoji = await self.get_emoji_from_manager(emoji_hash) - - if not emoji: - logger.warning(f"[警告] 未找到哈希值为 {emoji_hash} 的表情包") + # 5. 检查容量并决定是否替换或者直接注册 + if self._emoji_num >= global_config.emoji.max_reg_num: + logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包") + replaced = await self.replace_an_emoji_by_llm(target_emoji) + if not replaced: + logger.error("[注册表情包] 替换表情包失败,无法注册新表情包") return False - - # 使用MaiEmoji对象的delete方法删除表情包 - success = await emoji.delete() - - if success: - # 从emoji_objects列表中移除该对象 - self.emoji_objects = [e for e in self.emoji_objects if e.hash != emoji_hash] - # 更新计数 - self.emoji_num -= 1 - logger.info(f"[统计] 当前表情包数量: {self.emoji_num}") - + return True + else: + if self.register_emoji_to_db(target_emoji): + self.emojis.append(target_emoji) + self._emoji_num += 1 + logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}") return True else: - logger.error(f"[错误] 删除表情包失败: {emoji_hash}") + logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}") return False - except Exception as e: - logger.error(f"[错误] 删除表情包失败: {str(e)}") - logger.error(traceback.format_exc()) - return False - - async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: - # sourcery skip: use-getitem-for-re-match-groups - """替换一个表情包 + def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]: + """ + 计算文本情感标签与所有表情包情感标签的相似度列表 Args: - new_emoji: 新表情包对象 - + text_emotion (str): 文本的情感标签 Returns: - bool: 是否成功替换表情包 + return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表 """ - try: - self._ensure_db() - - # 获取所有表情包对象 - emoji_objects = self.emoji_objects - # 计算每个表情包的选择概率 - probabilities = [1 / (emoji.usage_count + 1) for emoji in emoji_objects] - # 归一化概率,确保总和为1 - total_probability = sum(probabilities) - normalized_probabilities = [p / total_probability for p in probabilities] - - # 使用概率分布选择最多20个表情包 - selected_emojis = random.choices( - emoji_objects, weights=normalized_probabilities, k=min(MAX_EMOJI_FOR_PROMPT, len(emoji_objects)) - ) - - # 将表情包信息转换为可读的字符串 - emoji_info_list = _emoji_objects_to_readable_list(selected_emojis) - - # 构建提示词 - prompt = ( - f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})," - f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n" - f"新表情包信息:\n" - f"描述: {new_emoji.description}\n\n" - f"现有表情包列表:\n" + "\n".join(emoji_info_list) + "\n\n" - "请决定:\n" - "1. 是否要删除某个现有表情包来为新表情包腾出空间?\n" - "2. 如果要删除,应该删除哪一个(给出编号)?\n" - "请只回答:'不删除'或'删除编号X'(X为表情包编号)。" - ) - - # 调用大模型进行决策 - decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600) - logger.info(f"[决策] 结果: {decision}") - - # 解析决策结果 - if "不删除" in decision: - logger.info("[决策] 不删除任何表情包") - return False - - if match := re.search(r"删除编号(\d+)", decision): - emoji_index = int(match.group(1)) - 1 # 转换为0-based索引 - - # 检查索引是否有效 - if 0 <= emoji_index < len(selected_emojis): - emoji_to_delete = selected_emojis[emoji_index] - - # 删除选定的表情包 - logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") - delete_success = await self.delete_emoji(emoji_to_delete.hash) - - if delete_success: - # 修复:等待异步注册完成 - register_success = await new_emoji.register_to_db() - if register_success: - self.emoji_objects.append(new_emoji) - self.emoji_num += 1 - logger.info(f"[成功] 注册: {new_emoji.filename}") - return True - else: - logger.error(f"[错误] 注册表情包到数据库失败: {new_emoji.filename}") - return False - else: - logger.error("[错误] 删除表情包失败,无法完成替换") - return False - else: - logger.error(f"[错误] 无效的表情包编号: {emoji_index + 1}") - else: - logger.error(f"[错误] 无法从决策中提取表情包编号: {decision}") - - return False - - except Exception as e: - logger.error(f"[错误] 替换表情包失败: {str(e)}") - logger.error(traceback.format_exc()) - return False - - async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: - """获取表情包描述和情感列表,优化复用已有描述 - - Args: - image_base64: 图片的base64编码 - - Returns: - Tuple[str, list]: 返回表情包描述和情感列表 - """ - try: - # 解码图片并获取格式 - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore - - # 尝试从 EmojiDescriptionCache 表获取已有的详细描述 - existing_description = None - try: - cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash) - if cache_record and cache_record.description: - existing_description = cache_record.description - logger.info(f"[复用描述] 表情描述缓存命中: {existing_description[:50]}...") - except Exception as e: - logger.debug(f"查询表情描述缓存时出错: {e}") - - # 第一步:VLM视觉分析(如果没有已有描述才调用) - if existing_description: - description = existing_description - logger.info("[优化] 复用已有的详细描述,跳过VLM调用") - else: - logger.info("[VLM分析] 生成新的详细描述") - if image_format in ["gif", "GIF"]: - image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore - if not image_base64: - raise RuntimeError("GIF表情包转换失败") - prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗,meme的角度去分析,精简回答" - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, "jpg", temperature=0.5 - ) - else: - prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答" - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.5 - ) - - # 若是新生成的描述,写入缓存表(此时还没有情感标签,稍后会更新) - if not existing_description: - try: - cache_record, created = EmojiDescriptionCache.get_or_create( - emoji_hash=image_hash, - defaults={"description": description, "timestamp": time.time()}, - ) - if not created: - # 更新描述,但保留已有的情感标签(如果有) - cache_record.description = description - cache_record.timestamp = time.time() - cache_record.save() - except Exception as cache_error: - logger.debug(f"写入表情描述缓存失败: {cache_error}") - - # 审核表情包 - if global_config.emoji.content_filtration: - prompt = f''' - 这是一个表情包,请对这个表情包进行审核,标准如下: - 1. 必须符合"{global_config.emoji.filtration_prompt}"的要求 - 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 - 3. 不能是任何形式的截图,聊天记录或视频截图 - 4. 不要出现5个以上文字 - 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 - ''' - content, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.3, max_tokens=1000 - ) - if content == "否": - return "", [] - - # 第二步:LLM情感分析 - 基于详细描述生成情感标签列表 - emotion_prompt = f""" -这是一个聊天场景中的表情包描述:'{description}' - -请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 -你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 -请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 - """ - emotions_text, _ = await self.llm_emotion_judge.generate_response_async( - emotion_prompt, temperature=0.7, max_tokens=256 - ) - - # 处理情感列表 - emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()] - - # 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个 - if len(emotions) > 5: - emotions = random.sample(emotions, 3) - elif len(emotions) > 2: - emotions = random.sample(emotions, 2) - - logger.info(f"[注册分析] 详细描述: {description[:50]}... -> 情感标签: {emotions}") - - # 将情感标签列表转换为逗号分隔的字符串 - emotion_tags_str = ",".join(emotions) - - # 更新EmojiDescriptionCache,保存情感标签 - try: - cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash) - if cache_record: - # 更新已有记录的情感标签 - cache_record.emotion_tags = emotion_tags_str - cache_record.timestamp = time.time() - cache_record.save() - logger.info(f"[缓存更新] 表情包情感标签已更新到EmojiDescriptionCache: {image_hash[:8]}...") - else: - # 如果缓存不存在,创建新记录(包含描述和情感标签) - EmojiDescriptionCache.create( - emoji_hash=image_hash, - description=description, - emotion_tags=emotion_tags_str, - timestamp=time.time(), - ) - logger.info(f"[缓存创建] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...") - except Exception as cache_error: - logger.debug(f"更新表情包情感标签缓存失败: {cache_error}") - - return f"[表情包:{description}]", emotions - - except Exception as e: - logger.error(f"获取表情包描述失败: {str(e)}") - return "", [] - - async def register_emoji_by_filename(self, filename: str) -> bool: - """读取指定文件名的表情包图片,分析并注册到数据库 - - Args: - filename: 表情包文件名,必须位于EMOJI_DIR目录下 - - Returns: - bool: 注册是否成功 - """ - file_full_path = os.path.join(EMOJI_DIR, filename) - if not os.path.exists(file_full_path): - logger.error(f"[注册失败] 文件不存在: {file_full_path}") - return False - - try: - # 1. 创建 MaiEmoji 实例并初始化哈希和格式 - new_emoji = MaiEmoji(full_path=file_full_path) - init_result = await new_emoji.initialize_hash_format() - if init_result is None or new_emoji.is_deleted: # 初始化失败或文件读取错误 - logger.error(f"[注册失败] 初始化哈希和格式失败: {filename}") - # 是否需要删除源文件?看业务需求,暂时不删 - return False - - # 2. 检查哈希是否已存在 (在内存中检查) - if await self.get_emoji_from_manager(new_emoji.hash): - logger.warning(f"[注册跳过] 表情包已存在 (Hash: {new_emoji.hash}): {filename}") - # 删除重复的源文件 - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除重复的待注册文件: {filename}") - except Exception as e: - logger.error(f"[错误] 删除重复文件失败: {str(e)}") - return False # 返回 False 表示未注册新表情 - - # 3. 构建描述和情感 - try: - emoji_base64 = image_path_to_base64(file_full_path) - if emoji_base64 is None: # 再次检查读取 - logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}") - return False - description, emotions = await self.build_emoji_description(emoji_base64) - if not description: # 检查描述是否成功生成或审核通过 - logger.warning(f"[注册失败] 未能生成有效描述或审核未通过: {filename}") - # 删除未能生成描述的文件 - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除描述生成失败的文件: {filename}") - except Exception as e: - logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}") - return False - new_emoji.description = description - new_emoji.emotion = emotions - except Exception as build_desc_error: - logger.error(f"[注册失败] 生成描述/情感时出错 ({filename}): {build_desc_error}") - # 同样考虑删除文件 - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除描述生成异常的文件: {filename}") - except Exception as e: - logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}") - return False - - # 4. 检查容量并决定是否替换或直接注册 - if self.emoji_num >= self.emoji_num_max: - logger.warning(f"表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),尝试替换...") - replaced = await self.replace_a_emoji(new_emoji) - if not replaced: - logger.error("[注册失败] 替换表情包失败,无法完成注册") - # 替换失败,删除新表情包文件 - try: - os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径 - logger.info(f"[清理] 删除替换失败的新表情文件: {filename}") - except Exception as e: - logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}") - return False - # 替换成功时,replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表 - return True - else: - # 直接注册 - register_success = await new_emoji.register_to_db() # 此方法会移动文件并更新 DB - if register_success: - # 注册成功后,添加到内存列表 - self.emoji_objects.append(new_emoji) - self.emoji_num += 1 - logger.info(f"[成功] 注册新表情包: {filename} (当前: {self.emoji_num}/{self.emoji_num_max})") - return True - else: - logger.error(f"[注册失败] 保存表情包到数据库/移动文件失败: {filename}") - # register_to_db 失败时,内部会尝试清理移动后的文件,源文件可能还在 - # 是否需要删除源文件? - if os.path.exists(file_full_path): - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除注册失败的源文件: {filename}") - except Exception as e: - logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}") - return False - - except Exception as e: - logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}") - logger.error(traceback.format_exc()) - # 尝试删除源文件以避免循环处理 - if os.path.exists(file_full_path): - try: - os.remove(file_full_path) - logger.info(f"[清理] 删除处理异常的源文件: {filename}") - except Exception as remove_error: - logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}") - return False + similarity_list: List[Tuple[MaiEmoji, float]] = [] + for emoji in self.emojis: + if not emoji.emotion: + continue + # 计算情感标签相似度,使用 Levenshtein 距离作为相似度指标 + distance = Levenshtein.distance(text_emotion, emoji.emotion) + max_len = max(len(text_emotion), len(emoji.emotion)) + similarity = 1 - (distance / max_len if max_len > 0 else 0) + similarity_list.append((emoji, similarity)) + return similarity_list -emoji_manager = None - - -def get_emoji_manager(): - global emoji_manager - if emoji_manager is None: - emoji_manager = EmojiManager() - return emoji_manager +emoji_manager = EmojiManager() diff --git a/src/common/data_models/image_data_model.py b/src/common/data_models/image_data_model.py new file mode 100644 index 00000000..e6062b66 --- /dev/null +++ b/src/common/data_models/image_data_model.py @@ -0,0 +1,165 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from PIL import Image as PILImage +from rich.traceback import install +from typing import Optional, List + +import asyncio +import hashlib +import io +import traceback + +from src.common.database.database_model import Images, ImageType +from src.common.logger import get_logger + + +install(extra_lines=3) + +logger = get_logger("emoji") + + +class BaseImageDataModel(ABC): + @classmethod + @abstractmethod + def from_db_instance(cls, image: "Images"): + raise NotImplementedError + + @abstractmethod + def to_db_instance(self) -> "Images": + raise NotImplementedError + + def read_image_bytes(self, path: Path) -> bytes: + """ + 同步读取图片文件的字节内容 + + Args: + path (Path): 图片文件的完整路径 + Returns: + return (bytes): 图片文件的字节内容 + Raises: + FileNotFoundError: 如果文件不存在则抛出该异常 + Exception: 其他读取文件时发生的异常 + """ + try: + with open(path, "rb") as f: + return f.read() + except FileNotFoundError as e: + logger.error(f"[读取图片文件] 文件未找到: {path}") + raise e + except Exception as e: + logger.error(f"[读取图片文件] 读取文件时发生错误: {e}") + raise e + + def get_image_format(self, image_bytes: bytes) -> str: + """ + 获取图片的格式 + + Args: + image_bytes (bytes): 图片的字节内容 + + Returns: + return (str): 图片的格式(小写) + + Raises: + ValueError: 如果无法识别图片格式 + Exception: 其他读取图片格式时发生的异常 + """ + try: + with PILImage.open(io.BytesIO(image_bytes)) as img: + if not img.format: + raise ValueError("无法识别图片格式") + return img.format.lower() + except Exception as e: + logger.error(f"[获取图片格式] 读取图片格式时发生错误: {e}") + raise e + + +class ImageDataModel(BaseImageDataModel): + pass + + +class MaiEmoji(BaseImageDataModel): + def __init__(self, full_path: str | Path): + if not full_path: + # 创建时候即检测文件路径合法性 + raise ValueError("表情包路径不能为空") + if Path(full_path).is_dir() or not Path(full_path).exists(): + raise FileNotFoundError(f"表情包路径无效: {full_path}") + resolved_path = Path(full_path).absolute().resolve() + self.full_path: Path = resolved_path + self.dir_path: Path = resolved_path.parent.resolve() + self.file_name: str = resolved_path.name + # self.embedding = [] + self.emoji_hash: str = None # type: ignore + self.description = "" + self.emotion: List[str] = [] + self.query_count = 0 + self.register_time: Optional[datetime] = None + self.last_used_time: Optional[datetime] = None + + # 私有属性 + self.is_deleted = False + self._format: str = "" # 图片格式 + + @classmethod + def from_db_instance(cls, image: Images): + obj = cls(image.full_path) + obj.emoji_hash = image.image_hash + obj.description = image.description + if image.emotion: + obj.emotion = image.emotion.split(",") + obj.query_count = image.query_count + obj.last_used_time = image.last_used_time + obj.register_time = image.register_time + return obj + + def to_db_instance(self) -> Images: + emotion_str = ",".join(self.emotion) if self.emotion else None + return Images( + image_hash=self.emoji_hash, + description=self.description, + full_path=str(self.full_path), + image_type=ImageType.EMOJI, + emotion=emotion_str, + query_count=self.query_count, + last_used_time=self.last_used_time, + register_time=self.register_time, + ) + + async def calculate_hash_format(self) -> bool: + """ + 异步计算表情包的哈希值和格式 + + Returns: + return (bool): 如果成功计算哈希值和格式则返回True,否则返回False + """ + logger.debug(f"[初始化] 正在读取文件: {self.full_path}") + try: + # 计算哈希值 + logger.debug(f"[初始化] 计算 {self.file_name} 的哈希值...") + image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path) + self.emoji_hash = hashlib.sha256(image_bytes).hexdigest() + logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.emoji_hash}") + + # 用PIL读取图片格式 + logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...") + self._format = await asyncio.to_thread(self.get_image_format, image_bytes) + logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self._format}") + + # 比对文件扩展名和实际格式 + file_ext = self.file_name.split(".")[-1].lower() + if file_ext != self._format: + logger.warning(f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self._format}`") + # 重命名文件以匹配实际格式 + new_file_name = ".".join(self.file_name.split(".")[:-1] + [self._format]) + new_full_path = self.dir_path / new_file_name + self.full_path.rename(new_full_path) + self.full_path = new_full_path + + return True + except Exception as e: + logger.error(f"[初始化] 初始化表情包时发生错误: {e}") + logger.error(traceback.format_exc()) + self.is_deleted = True + return False diff --git a/src/common/utils/utils_image.py b/src/common/utils/utils_image.py new file mode 100644 index 00000000..944e53a0 --- /dev/null +++ b/src/common/utils/utils_image.py @@ -0,0 +1,104 @@ +from PIL import Image as PILImage, ImageSequence + +import base64 +import io +import numpy as np + +from src.common.logger import get_logger + +logger = get_logger("image") + + +class ImageUtils: + @staticmethod + def gif_2_static_image(gif_bytes: bytes, similarity_threshold: float = 1000.0, max_frames: int = 15) -> bytes: + """ + 将GIF图片水平拼接为静态图像,跳过相似帧 + + Args: + gif_bytes (bytes): 输入的GIF图片字节数据 + similarity_threshold (float): 判定帧相似的阈值 (MSE),越小表示要求差异越大才算不同帧,默认1000.0 + max_frames (int): 最大抽取的帧数,默认15 + Returns: + bytes: 拼接后的静态图像字节数据,格式为JPEG + Raises: + ValueError: 如果输入的GIF无效或无法处理 + MemoryError: 如果处理过程中内存不足 + Exception: 其他异常 + """ + with PILImage.open(io.BytesIO(gif_bytes)) as gif_image: + if not gif_image.format or gif_image.format.lower() != "gif": + logger.error("输入的图片不是有效的GIF格式") + raise ValueError("输入的图片不是有效的GIF格式") + # --- 流式迭代并选择帧(避免一次性加载所有帧) --- + selected_frames: list[PILImage.Image] = [] + last_selected_frame_np = None + frame_index = 0 + + for frame in ImageSequence.Iterator(gif_image): + # 确保是RGB格式方便比较 + frame_rgb = frame.convert("RGB") + frame_np = np.array(frame_rgb) + + if frame_index == 0: + selected_frames.append(frame_rgb.copy()) + last_selected_frame_np = frame_np + else: + # 计算和上一张选中帧的差异(均方误差 MSE) + mse = np.mean((frame_np - last_selected_frame_np) ** 2) + # logger.debug(f"帧 {frame_index} 与上一选中帧的 MSE: {mse}") + if mse > similarity_threshold: + selected_frames.append(frame_rgb.copy()) + last_selected_frame_np = frame_np + if len(selected_frames) >= max_frames: + break + frame_index += 1 + + if not selected_frames: + logger.error("未能抽取到任何有效帧") + raise ValueError("未能抽取到任何有效帧") + + # 获取选中的第一帧的尺寸(假设所有帧尺寸一致) + frame_width, frame_height = selected_frames[0].size + # 防止除以零 + if frame_height == 0: + raise ValueError("帧高度为0,无法计算缩放尺寸") + + # 计算目标尺寸,保持宽高比 + target_height = 200 # 固定高度 + target_width = int((target_height / frame_height) * frame_width) + # 宽度也不能是0 + if target_width == 0: + logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height}),调整为1") + target_width = 1 + # 调整所有选中帧的大小 + resized_frames = [ + frame.resize((target_width, target_height), PILImage.Resampling.LANCZOS) for frame in selected_frames + ] + + # 创建拼接图像 + total_width = target_width * len(resized_frames) + combined_image = PILImage.new("RGB", (total_width, target_height)) + # 水平拼接图像 + for idx, frame in enumerate(resized_frames): + combined_image.paste(frame, (idx * target_width, 0)) + buffer = io.BytesIO() + combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG + return buffer.getvalue() + + @staticmethod + def image_bytes_to_base64(image_bytes: bytes) -> str: + """ + 将图片字节数据转换为Base64编码字符串 + + Args: + image_bytes (bytes): 输入的图片字节数据 + Returns: + str: Base64编码的图片字符串 + Raises: + ValueError: 如果输入的图片字节数据无效 + """ + if not image_bytes: + logger.error("输入的图片字节数据无效") + raise ValueError("输入的图片字节数据无效") + return base64.b64encode(image_bytes).decode("utf-8")