ImageManager及测试

pull/1496/head
UnCLAS-Prommer 2026-02-19 19:02:44 +08:00
parent 0a572515ba
commit 1c0580c577
No known key found for this signature in database
3 changed files with 468 additions and 0 deletions

View File

@ -0,0 +1,200 @@
import sys
import types
import importlib
import pytest
from pathlib import Path
import importlib.util
import asyncio
class DummyLogger:
    """Silent logger stub: every logging call accepts anything and does nothing."""

    def info(self, *args, **kwargs):
        """Discard an info-level message."""
        return None

    def warning(self, *args, **kwargs):
        """Discard a warning-level message."""
        return None

    def error(self, *args, **kwargs):
        """Discard an error-level message."""
        return None
class DummySession:
    """Database-session stub that behaves like an always-empty database.

    Usable as a context manager; write operations are accepted and ignored.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # A falsy return value lets exceptions from the ``with`` body propagate.
        return False

    def exec(self, *args, **kwargs):
        """Return a result object that never contains any rows."""

        class _EmptyResult:
            def first(self):
                # No row ever matches.
                return None

            def yield_per(self, n):
                # Batch size is irrelevant: there is nothing to iterate over.
                return iter(())

        return _EmptyResult()

    def add(self, *args, **kwargs):
        """Ignore the object that would have been staged."""
        return None

    def flush(self, *args, **kwargs):
        """Nothing is pending, so flushing is a no-op."""
        return None

    def delete(self, *args, **kwargs):
        """Ignore the object that would have been deleted."""
        return None
class DummyMaiImage:
    """Lightweight replacement for the project's ``MaiImage`` data model."""

    def __init__(self, full_path=None, image_bytes=None):
        self.full_path = full_path
        self.image_bytes = image_bytes
        # Fixed defaults that the code under test reads or overwrites.
        self.image_format = "png"
        self.description = ""
        self.vlm_processed = False

    @classmethod
    def from_db_instance(cls, record):
        """Build a default instance; the record's contents are ignored."""
        return cls()

    def to_db_instance(self):
        """Return a namespace standing in for a database row (id is always 1)."""
        if self.full_path is None:
            path_text = ""
        else:
            path_text = str(self.full_path)
        return types.SimpleNamespace(id=1, full_path=path_text)

    async def calculate_hash_format(self):
        """Awaitable no-op; the stub keeps its preset ``image_format``."""
        return None
class DummyLLMRequest:
    """Stub for ``LLMRequest`` that answers every image request with a fixed description."""

    def __init__(self, *args, **kwargs):
        # Constructor arguments (model set, request type, ...) are accepted and ignored.
        pass

    async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
        """Return a canned ``(description, extra)`` pair regardless of the inputs."""
        canned_description = "dummy description"
        extra = {}
        return (canned_description, extra)
class DummySelect:
    """Chainable stand-in for a ``select(...)`` statement."""

    def __init__(self, *args, **kwargs):
        # The selected entities are irrelevant to the stub.
        pass

    def filter_by(self, *args, **kwargs):
        """Ignore the filter criteria and return this statement for chaining."""
        return self

    def limit(self, n):
        """Ignore the row limit and return this statement for chaining."""
        return self
@pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch):
    """Replace the modules imported by ``image_manager`` with in-memory stubs.

    Each stub is a ``SimpleNamespace`` registered in ``sys.modules`` under the
    real module's dotted name, so ``from <module> import <name>`` in the code
    under test resolves to the dummy objects defined in this file.
    ``monkeypatch`` restores the original ``sys.modules`` entries afterwards.
    """
    # Config values that image_manager reads (one of them at import time).
    global_cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style"))
    model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm"))

    stub_modules = {
        # LLM request class
        "src.llm_models.utils_model": types.SimpleNamespace(LLMRequest=DummyLLMRequest),
        # Logger factory
        "src.common.logger": types.SimpleNamespace(get_logger=lambda name: DummyLogger()),
        # DB session provider
        "src.common.database.database": types.SimpleNamespace(get_db_session=lambda: DummySession()),
        # Database model types
        "src.common.database.database_model": types.SimpleNamespace(
            Images=types.SimpleNamespace,
            ImageType=types.SimpleNamespace(IMAGE="image"),
        ),
        # MaiImage data model
        "src.common.data_models.image_data_model": types.SimpleNamespace(MaiImage=DummyMaiImage),
        # SQLModel select function
        "sqlmodel": types.SimpleNamespace(select=lambda *a, **k: DummySelect()),
        # Global and model configuration
        "src.config.config": types.SimpleNamespace(global_config=global_cfg, model_config=model_cfg),
    }
    for dotted_name, stub in stub_modules.items():
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    # A previously imported copy of the module is still bound to the real
    # dependencies; reload it so it picks up the stubs.
    target_module = "src.chat.image_system.image_manager"
    if target_module in sys.modules:
        importlib.reload(sys.modules[target_module])
    yield
def _load_image_manager_module(tmp_path=None):
    """Load ``image_manager.py`` from its file path as a fresh module object.

    The file is located relative to this test file (three directory levels up,
    then ``src/chat/image_system/image_manager.py``) and executed under the
    module name ``image_manager_test_loaded``.

    Args:
        tmp_path: Optional directory (typically pytest's ``tmp_path``). When
            given, it is created if needed and the loaded module's
            ``IMAGE_DIR`` is pointed at it so tests write files there.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
        Exception: Anything raised while executing the module or creating the
            redirect directory; such failures are no longer hidden, because a
            silently skipped redirect would leave ``IMAGE_DIR`` at its default.
    """
    repo_root = Path(__file__).parent.parent.parent
    file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py"
    spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {file_path}")

    mod = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module registered for later tests.
        sys.modules.pop(spec.name, None)
        raise

    # Redirect IMAGE_DIR to pytest's tmp_path when provided
    if tmp_path is not None:
        tmpdir = Path(tmp_path)
        tmpdir.mkdir(parents=True, exist_ok=True)
        mod.IMAGE_DIR = tmpdir
    return mod
@pytest.mark.asyncio
async def test_get_image_description_generates(tmp_path):
    """With no matching DB record, raw bytes yield the stub VLM's description."""
    manager = _load_image_manager_module(tmp_path).ImageManager()
    description = await manager.get_image_description(image_bytes=b"abc")
    assert description == "dummy description"
def test_get_image_from_db_none(tmp_path):
    """Looking up a hash against the empty stub database returns ``None``."""
    manager = _load_image_manager_module(tmp_path).ImageManager()
    looked_up = manager.get_image_from_db("nohash")
    assert looked_up is None
def test_register_image_to_db(tmp_path):
    """Registering an image whose file exists on disk reports success."""
    manager = _load_image_manager_module(tmp_path).ImageManager()
    image_file = tmp_path / "img.png"
    image_file.write_bytes(b"data")
    image = DummyMaiImage(full_path=image_file, image_bytes=b"data")
    registered = manager.register_image_to_db(image)
    assert registered is True
def test_update_image_description_not_found(tmp_path):
    """Updating the description of an unknown hash reports failure."""
    manager = _load_image_manager_module(tmp_path).ImageManager()
    unknown_image = DummyMaiImage()
    unknown_image.file_hash = "nohash"
    unknown_image.description = "desc"
    updated = manager.update_image_description(unknown_image)
    assert updated is False
def test_delete_image_not_found(tmp_path):
    """Deleting an image whose hash has no DB record reports failure."""
    image_manager = _load_image_manager_module(tmp_path)
    mgr = image_manager.ImageManager()
    img = DummyMaiImage()
    img.file_hash = "nohash"
    # The record lookup fails before the path is ever used, so no file is needed.
    # (The original chained assignment also rebound the ``tmp_path`` fixture to None.)
    img.full_path = None
    assert mgr.delete_image(img) is False
@pytest.mark.asyncio
async def test_save_image_and_process_and_cleanup(tmp_path):
    """Saving new bytes attaches the stub description; cleanup then runs cleanly."""
    manager = _load_image_manager_module(tmp_path).ImageManager()

    # Save and describe a brand-new image.
    saved = await manager.save_image_and_process(b"binarydata")
    saved_description = getattr(saved, "description", None)
    assert saved_description == "dummy description"

    # Cleanup must complete without raising against the empty stub database.
    manager.cleanup_invalid_descriptions_in_db()

View File

@ -0,0 +1,264 @@
from datetime import datetime
from pathlib import Path
from rich.traceback import install
from sqlmodel import select
from typing import Optional
import base64
import hashlib
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.data_models.image_data_model import MaiImage
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# Install rich's traceback handler (configured with extra_lines=3).
install(extra_lines=3)
# Project root: the fourth ancestor directory of this file, made absolute and resolved.
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
# Data directory located directly under the project root.
DATA_DIR = PROJECT_ROOT / "data"
# Directory in which this module stores image files.
IMAGE_DIR = DATA_DIR / "images"
# Module logger, obtained under the name "image".
logger = get_logger("image")
def _ensure_image_dir_exists() -> None:
    """Create ``IMAGE_DIR`` and any missing parents; do nothing if it already exists."""
    IMAGE_DIR.mkdir(exist_ok=True, parents=True)
# Shared request object for the VLM task, built once at import time from
# model_config.model_task_config.vlm; used below to generate image descriptions.
vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
class ImageManager:
    """Stores images on disk, tracks them in the database and describes them via the VLM."""

    def __init__(self):
        """Make sure the image directory exists and log that the manager is ready."""
        _ensure_image_dir_exists()
        logger.info("图片管理器初始化完成")

    async def get_image_description(self, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None) -> str:
        """Return the description of an image, generating it when necessary.

        The database is queried first. On a miss, the image is saved and a
        description is generated, provided ``image_bytes`` was supplied.

        Args:
            image_hash (Optional[str]): Hash of the image; used for the lookup
                when provided.
            image_bytes (Optional[bytes]): Raw image data. Used to compute the
                hash when ``image_hash`` is missing and to generate a
                description when the lookup finds nothing.

        Returns:
            str: The image description, or an empty string when it cannot be
            found or generated.

        Raises:
            ValueError: If neither ``image_hash`` nor ``image_bytes`` is given.

        Note:
            Errors from the database lookup and from description generation are
            logged rather than raised; a failed lookup is treated as a miss.
        """
        if image_hash:
            hash_str = image_hash
        elif not image_bytes:
            raise ValueError("必须提供图片哈希值或图片字节数据")
        else:
            hash_str = hashlib.sha256(image_bytes).hexdigest()

        try:
            with get_db_session() as session:
                statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
                if record := session.exec(statement).first():
                    return record.description
        except Exception as e:
            logger.error(f"查询图片描述时发生错误: {e}")

        if not image_bytes:
            logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
            return ""

        logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
        try:
            image = await self.save_image_and_process(image_bytes)
            return image.description
        except Exception as e:
            logger.error(f"生成图片描述时发生错误: {e}")
            return ""

    def get_image_from_db(self, image_hash: str) -> Optional[MaiImage]:
        """Fetch the image record for ``image_hash`` and convert it to a ``MaiImage``.

        Args:
            image_hash (str): Hash of the image to look up.

        Returns:
            Optional[MaiImage]: The image object, or ``None`` when no record
            exists or the record is flagged as having no file on disk.
        """
        with get_db_session() as session:
            statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
            if record := session.exec(statement).first():
                if record.no_file_flag:
                    logger.warning(f"数据库记录标记为文件不存在,哈希值: {image_hash}")
                    return None
                return MaiImage.from_db_instance(record)
        logger.info(f"未找到哈希值为 {image_hash} 的图片记录")
        return None

    def register_image_to_db(self, image: MaiImage) -> bool:
        """Insert a database record for ``image``.

        Args:
            image (MaiImage): Image object with a valid ``full_path`` pointing
                at an existing file, and a valid ``image_format``.

        Returns:
            bool: ``True`` on success, ``False`` if the object is invalid, the
            file is missing, or the database operation fails.
        """
        # sourcery skip: extract-method
        if not image or not isinstance(image, MaiImage):
            logger.error("无效的图片对象,无法注册到数据库")
            return False
        if not image.full_path.exists():
            logger.error(f"图片文件不存在,无法注册到数据库: {image.full_path}")
            return False
        try:
            with get_db_session() as session:
                record = image.to_db_instance()
                record.is_registered = True
                record.register_time = record.last_used_time = datetime.now()
                session.add(record)
                session.flush()  # Flush so the record is written and receives its ID.
                # Capture the values needed for logging while the record is still
                # attached to the session.
                record_id = record.id
                record_path = record.full_path
            logger.info(f"成功保存图片记录到数据库: ID: {record_id},路径: {record_path}")
        except Exception as e:
            logger.error(f"保存图片记录到数据库时发生错误: {e}")
            return False
        return True

    def update_image_description(self, image: MaiImage) -> bool:
        """Write ``image.description`` to the record matching ``image.file_hash``.

        Args:
            image (MaiImage): Image object carrying the new description; must
                have a valid ``file_hash``.

        Returns:
            bool: ``True`` on success, ``False`` if no record exists or the
            database operation fails.
        """
        try:
            with get_db_session() as session:
                statement = select(Images).filter_by(image_hash=image.file_hash, image_type=ImageType.IMAGE).limit(1)
                record = session.exec(statement).first()
                if not record:
                    logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
                    return False
                record.description = image.description
                record.last_used_time = datetime.now()
                session.add(record)
            logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}")
        except Exception as e:
            logger.error(f"更新图片描述时发生错误: {e}")
            return False
        return True

    def delete_image(self, image: MaiImage) -> bool:
        """Delete the database record of ``image`` and its file on disk.

        Args:
            image (MaiImage): Image to delete; must have a valid ``file_hash``
                and ``full_path``.

        Returns:
            bool: ``True`` when the record was deleted (a file that is already
            missing only produces a warning), ``False`` if no record exists or
            an error occurs.
        """
        try:
            with get_db_session() as session:
                statement = select(Images).filter_by(image_hash=image.file_hash, image_type=ImageType.IMAGE).limit(1)
                record = session.exec(statement).first()
                if not record:
                    logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法删除")
                    return False
                session.delete(record)
                logger.info(f"成功删除图片记录: {image.file_hash}")
                if image.full_path.exists():
                    image.full_path.unlink()
                    logger.info(f"成功删除图片文件: {image.full_path}")
                else:
                    logger.warning(f"图片文件不存在,无法删除: {image.full_path}")
        except Exception as e:
            logger.error(f"删除图片时发生错误: {e}")
            # The path may be the very thing that failed (e.g. ``None``), so the
            # follow-up check must never raise out of the error handler.
            try:
                file_left_behind = image.full_path.exists()
            except (AttributeError, OSError):
                file_left_behind = False
            if file_left_behind:
                logger.warning(f"图片文件未被删除: {image.full_path}")
            return False
        return True

    async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
        """Save an image to disk, generate its description and register it.

        If a record with the same SHA-256 hash already exists, its usage
        statistics are updated and the stored image is returned instead.

        Args:
            image_bytes (bytes): Raw image data.

        Returns:
            MaiImage: The existing or newly created image object.

        Raises:
            Exception: Errors from the database lookup, from writing the file,
                or from description generation propagate to the caller.
        """
        hash_str = hashlib.sha256(image_bytes).hexdigest()
        try:
            with get_db_session() as session:
                # NOTE(review): unlike the other lookups in this class, this one
                # does not filter on image_type — confirm that is intentional.
                statement = select(Images).filter_by(image_hash=hash_str).limit(1)
                if record := session.exec(statement).first():
                    logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
                    record.last_used_time = datetime.now()
                    record.query_count += 1
                    session.add(record)
                    session.flush()
                    return MaiImage.from_db_instance(record)
        except Exception as e:
            logger.error(f"查询图片记录时发生错误: {e}")
            raise  # Bare raise keeps the original traceback intact.

        logger.info(f"图片不存在于数据库中,准备保存新图片,哈希值: {hash_str}")
        tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp"
        with tmp_file_path.open("wb") as f:
            f.write(image_bytes)

        mai_image = MaiImage(full_path=tmp_file_path, image_bytes=image_bytes)
        await mai_image.calculate_hash_format()
        desc = await self._generate_image_description(image_bytes, mai_image.image_format)
        mai_image.description = desc
        mai_image.vlm_processed = True
        try:
            # register_image_to_db logs and returns False on failure; the image
            # is returned either way so the generated description is not lost.
            self.register_image_to_db(mai_image)
        except Exception as e:
            logger.error(f"保存新图片记录到数据库时发生错误: {e}")
            raise
        return mai_image

    def cleanup_invalid_descriptions_in_db(self):
        """Remove invalid image records from the database.

        A record is invalid when its ``description`` is ``None``, empty or
        whitespace-only (its file, if present, is deleted as well), or when its
        ``full_path`` is set but the file no longer exists. Errors are logged,
        not raised.
        """
        invalid_counter: int = 0
        null_path_counter: int = 0
        logger.info("开始清理数据库中无效的图片记录...")
        try:
            with get_db_session() as session:
                for record in session.exec(select(Images)).yield_per(100):
                    description = record.description
                    # Whitespace-only text counts as "no description", as documented.
                    if not description or not description.strip():
                        if record.full_path and Path(record.full_path).exists():
                            try:
                                Path(record.full_path).unlink()
                                logger.info(f"已删除无效描述的图片文件: {record.full_path}")
                            except Exception as e:
                                # Best effort: the record is still removed below.
                                logger.error(f"删除无效描述的图片文件时发生错误: {e}")
                        session.delete(record)
                        invalid_counter += 1
                    elif record.full_path and not Path(record.full_path).exists():
                        session.delete(record)
                        null_path_counter += 1
        except Exception as e:
            logger.error(f"清理数据库中无效图片记录时发生错误: {e}")
        logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")

    async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
        """Ask the VLM to describe an image.

        Args:
            image_bytes (bytes): Raw image data; sent to the model base64-encoded.
            image_format (str): Image format passed through to the model request.

        Returns:
            str: The generated description, or an empty string when the model
            returns nothing.
        """
        prompt = global_config.personality.visual_style
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
        description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4)
        if not description:
            logger.warning("VLM未能生成图片描述")
        return description or ""

View File

@ -124,6 +124,8 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
class MaiEmoji(BaseImageDataModel):
"""麦麦的表情包对象,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
# self.embedding = []
self.description: str = ""
@ -173,6 +175,8 @@ class MaiEmoji(BaseImageDataModel):
class MaiImage(BaseImageDataModel):
"""麦麦图片数据模型,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
self.description: str = ""
self.vlm_processed: bool = False