From 1c0580c5770a4d2c4a1b0572d5d3abbf9493c876 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 19 Feb 2026 19:02:44 +0800 Subject: [PATCH] =?UTF-8?q?ImageManager=E5=8F=8A=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/image_sys_test/image_manager_test.py | 200 ++++++++++++++ src/chat/image_system/image_manager.py | 264 +++++++++++++++++++ src/common/data_models/image_data_model.py | 4 + 3 files changed, 468 insertions(+) create mode 100644 pytests/image_sys_test/image_manager_test.py create mode 100644 src/chat/image_system/image_manager.py diff --git a/pytests/image_sys_test/image_manager_test.py b/pytests/image_sys_test/image_manager_test.py new file mode 100644 index 00000000..360ba50c --- /dev/null +++ b/pytests/image_sys_test/image_manager_test.py @@ -0,0 +1,200 @@ +import sys +import types +import importlib +import pytest +from pathlib import Path +import importlib.util +import asyncio + + +class DummyLogger: + def info(self, *a, **k): + pass + + def warning(self, *a, **k): + pass + + def error(self, *a, **k): + pass + + +class DummySession: + def exec(self, *a, **k): + class R: + def first(self): + return None + + def yield_per(self, n): + return iter(()) + + return R() + + def add(self, *a, **k): + pass + + def flush(self, *a, **k): + pass + + def delete(self, *a, **k): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class DummyMaiImage: + def __init__(self, full_path=None, image_bytes=None): + self.full_path = full_path + self.image_bytes = image_bytes + self.image_format = "png" + self.description = "" + self.vlm_processed = False + + @classmethod + def from_db_instance(cls, record): + return cls() + + def to_db_instance(self): + return types.SimpleNamespace(id=1, full_path=str(self.full_path) if self.full_path is not None else "") + + async def calculate_hash_format(self): + return None + + +class DummyLLMRequest: + def __init__(self, *a, **k): + pass + + async def generate_response_for_image(self, prompt, image_base64, image_format, temp): + return ("dummy description", {}) + +class DummySelect: + def __init__(self, *a, **k): + pass + + def filter_by(self, *a, **k): + return self + + def limit(self, n): + return self + +@pytest.fixture(autouse=True) +def patch_external_dependencies(monkeypatch): + # Provide dummy implementations as modules so that importing image_manager is safe + # Patch LLMRequest + llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest) + monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod) + + # Patch logger + logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger()) + monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod) + + # Patch DB session provider + db_mod = types.SimpleNamespace(get_db_session=lambda: DummySession()) + monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod) + + # Patch database model types + db_model_mod = types.SimpleNamespace(Images=types.SimpleNamespace, ImageType=types.SimpleNamespace(IMAGE="image")) + monkeypatch.setitem(sys.modules, "src.common.database.database_model", db_model_mod) + + # Patch MaiImage data model + data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage) + monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod) + + # Patch SQLModel select function + sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect()) + monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod) + + # Patch config values used at import-time + cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style")) + model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm")) + config_mod = types.SimpleNamespace(global_config=cfg, model_config=model_cfg) + monkeypatch.setitem(sys.modules, "src.config.config", config_mod) + + # If module already imported, reload it to apply patches + mod_name = "src.chat.image_system.image_manager" + if mod_name in sys.modules: + importlib.reload(sys.modules[mod_name]) + + yield + + +def _load_image_manager_module(tmp_path=None): + repo_root = Path(__file__).parent.parent.parent + file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py" + spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path)) + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + # Redirect IMAGE_DIR to pytest's tmp_path when provided + try: + if tmp_path is not None: + tmpdir = Path(tmp_path) + tmpdir.mkdir(parents=True, exist_ok=True) + setattr(mod, "IMAGE_DIR", tmpdir) + except Exception: + pass + return mod + + +@pytest.mark.asyncio +async def test_get_image_description_generates(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + desc = await mgr.get_image_description(image_bytes=b"abc") + assert desc == "dummy description" + + +def test_get_image_from_db_none(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + assert mgr.get_image_from_db("nohash") is None + + +def test_register_image_to_db(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + p = tmp_path / "img.png" + p.write_bytes(b"data") + img = DummyMaiImage(full_path=p, image_bytes=b"data") + assert mgr.register_image_to_db(img) is True + + +def test_update_image_description_not_found(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + img = DummyMaiImage() + img.file_hash = "nohash" + img.description = "desc" + assert mgr.update_image_description(img) is False + + +def test_delete_image_not_found(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + img = DummyMaiImage() + img.file_hash = "nohash" + img.full_path = tmp_path = None + assert mgr.delete_image(img) is False + + +@pytest.mark.asyncio +async def test_save_image_and_process_and_cleanup(tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + mgr = image_manager.ImageManager() + # call save_image_and_process + image = await mgr.save_image_and_process(b"binarydata") + assert getattr(image, "description", None) == "dummy description" + + # cleanup should run without error + mgr.cleanup_invalid_descriptions_in_db() + diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py new file mode 100644 index 00000000..92f5168b --- /dev/null +++ b/src/chat/image_system/image_manager.py @@ -0,0 +1,264 @@ +from datetime import datetime +from pathlib import Path +from rich.traceback import install +from sqlmodel import select +from typing import Optional + +import base64 +import hashlib + +from src.common.logger import get_logger +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType +from src.common.data_models.image_data_model import MaiImage +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +install(extra_lines=3) + +PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve() +DATA_DIR = PROJECT_ROOT / "data" +IMAGE_DIR = DATA_DIR / "images" + +logger = get_logger("image") + + +def _ensure_image_dir_exists(): + IMAGE_DIR.mkdir(parents=True, exist_ok=True) + + +vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") + + +class ImageManager: + def __init__(self): + _ensure_image_dir_exists() + + logger.info("图片管理器初始化完成") + + async def get_image_description(self, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None) -> str: + """ + 获取图片描述的封装方法 + + 如果图片已存在于数据库中,则直接返回描述 + + 如果不存在,则**保存图片**并**生成描述**后返回 + + Args: + image_hash (Optional[str]): 图片的哈希值,如果提供则优先使用该 + image_bytes (Optional[bytes]): 图片的字节数据,如果提供则在数据库中找不到哈希值时使用该数据生成描述 + Returns: + return (str): 图片描述,如果发生错误或无法生成描述则返回空字符串 + Raises: + ValueError: 如果未提供有效的图片哈希值或图片字节数据 + Exception: 在查询数据库、保存图片或生成描述过程中发生的其他异常 + """ + if image_hash: + hash_str = image_hash + elif not image_bytes: + raise ValueError("必须提供图片哈希值或图片字节数据") + else: + hash_str = hashlib.sha256(image_bytes).hexdigest() + + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1) + if record := session.exec(statement).first(): + return record.description + except Exception as e: + logger.error(f"查询图片描述时发生错误: {e}") + + if not image_bytes: + logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述") + return "" + logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述") + try: + image = await self.save_image_and_process(image_bytes) + return image.description + except Exception as e: + logger.error(f"生成图片描述时发生错误: {e}") + return "" + + def get_image_from_db(self, image_hash: str) -> Optional[MaiImage]: + """ + 从数据库中根据图片哈希值获取图片记录 + + """ + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1) + if record := session.exec(statement).first(): + if record.no_file_flag: + logger.warning(f"数据库记录标记为文件不存在,哈希值: {image_hash}") + return None + return MaiImage.from_db_instance(record) + logger.info(f"未找到哈希值为 {image_hash} 的图片记录") + return None + + def register_image_to_db(self, image: MaiImage) -> bool: + """ + 将图片对象注册到数据库中 + Args: + image (MaiImage): 包含图片信息的 MaiImage 对象,必须包含有效的 full_path 和 image_format + Returns: + return (bool): 注册成功返回 True,失败返回 False + """ + # sourcery skip: extract-method + if not image or not isinstance(image, MaiImage): + logger.error("无效的图片对象,无法注册到数据库") + return False + if not image.full_path.exists(): + logger.error(f"图片文件不存在,无法注册到数据库: {image.full_path}") + return False + + try: + with get_db_session() as session: + record = image.to_db_instance() + record.is_registered = True + record.register_time = record.last_used_time = datetime.now() + session.add(record) + session.flush() # 确保记录被写入数据库以获取ID + record_id = record.id + logger.info(f"成功保存图片记录到数据库: ID: {record_id},路径: {record.full_path}") + except Exception as e: + logger.error(f"保存图片记录到数据库时发生错误: {e}") + return False + return True + + def update_image_description(self, image: MaiImage) -> bool: + """ + 更新图片描述 + + Args: + image (MaiImage): 包含新描述的图片对象,必须包含有效的 file_hash 和 full_path + Returns: + return (bool): 更新成功返回 True,失败返回 False + """ + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=image.file_hash, image_type=ImageType.IMAGE).limit(1) + record = session.exec(statement).first() + if not record: + logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述") + return False + record.description = image.description + record.last_used_time = datetime.now() + session.add(record) + logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}") + except Exception as e: + logger.error(f"更新图片描述时发生错误: {e}") + return False + return True + + def delete_image(self, image: MaiImage) -> bool: + """ + 删除图片记录和对应的文件 + + Args: + image (MaiImage): 包含要删除图片信息的对象,必须包含有效的 file_hash 和 full_path + Returns: + return (bool): 删除成功返回 True,失败返回 False + """ + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=image.file_hash, image_type=ImageType.IMAGE).limit(1) + record = session.exec(statement).first() + if not record: + logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法删除") + return False + session.delete(record) + logger.info(f"成功删除图片记录: {image.file_hash}") + + if image.full_path.exists(): + image.full_path.unlink() + logger.info(f"成功删除图片文件: {image.full_path}") + else: + logger.warning(f"图片文件不存在,无法删除: {image.full_path}") + except Exception as e: + logger.error(f"删除图片时发生错误: {e}") + if image.full_path.exists(): + logger.warning(f"图片文件未被删除: {image.full_path}") + return False + return True + + async def save_image_and_process(self, image_bytes: bytes) -> MaiImage: + """ + 保存图片并生成描述 + + Args: + image_bytes (bytes): 图片的字节数据 + Returns: + return (MaiImage): 包含图片信息的 MaiImage 对象 + Raises: + Exception: 如果在保存或处理过程中发生错误 + """ + hash_str = hashlib.sha256(image_bytes).hexdigest() + + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_hash=hash_str).limit(1) + if record := session.exec(statement).first(): + logger.info(f"图片已存在于数据库中,哈希值: {hash_str}") + record.last_used_time = datetime.now() + record.query_count += 1 + session.add(record) + session.flush() + return MaiImage.from_db_instance(record) + except Exception as e: + logger.error(f"查询图片记录时发生错误: {e}") + raise e + + logger.info(f"图片不存在于数据库中,准备保存新图片,哈希值: {hash_str}") + tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp" + with tmp_file_path.open("wb") as f: + f.write(image_bytes) + mai_image = MaiImage(full_path=(IMAGE_DIR / f"{hash_str}.tmp"), image_bytes=image_bytes) + await mai_image.calculate_hash_format() + desc = await self._generate_image_description(image_bytes, mai_image.image_format) + mai_image.description = desc + mai_image.vlm_processed = True + try: + self.register_image_to_db(mai_image) + except Exception as e: + logger.error(f"保存新图片记录到数据库时发生错误: {e}") + raise e + return mai_image + + def cleanup_invalid_descriptions_in_db(self): + """ + 清理数据库中无效的图片记录 + + 无效的判定:`description` 为空或仅包含空白字符,或者文件路径不存在 + """ + invalid_values = {"", None} + invalid_counter: int = 0 + null_path_counter: int = 0 + logger.info("开始清理数据库中无效的图片记录...") + + try: + with get_db_session() as session: + for record in session.exec(select(Images)).yield_per(100): + if record.description in invalid_values: + if record.full_path and Path(record.full_path).exists(): + try: + Path(record.full_path).unlink() + logger.info(f"已删除无效描述的图片文件: {record.full_path}") + except Exception as e: + logger.error(f"删除无效描述的图片文件时发生错误: {e}") + session.delete(record) + invalid_counter += 1 + elif record.full_path and not Path(record.full_path).exists(): + session.delete(record) + null_path_counter += 1 + except Exception as e: + logger.error(f"清理数据库中无效图片记录时发生错误: {e}") + + logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录") + + async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str: + prompt = global_config.personality.visual_style + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4) + if not description: + logger.warning("VLM未能生成图片描述") + return description or "" diff --git a/src/common/data_models/image_data_model.py b/src/common/data_models/image_data_model.py index 529a6423..9481e30b 100644 --- a/src/common/data_models/image_data_model.py +++ b/src/common/data_models/image_data_model.py @@ -124,6 +124,8 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]): class MaiEmoji(BaseImageDataModel): + """麦麦的表情包对象,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象""" + def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None): # self.embedding = [] self.description: str = "" @@ -173,6 +175,8 @@ class MaiEmoji(BaseImageDataModel): class MaiImage(BaseImageDataModel): + """麦麦图片数据模型,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象""" + def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None): self.description: str = "" self.vlm_processed: bool = False