mirror of https://github.com/Mai-with-u/MaiBot.git
447 lines
16 KiB
Python
447 lines
16 KiB
Python
import base64
|
||
import io
|
||
import os
|
||
import time
|
||
import zlib
|
||
import aiohttp
|
||
import hashlib
|
||
from typing import Optional, Tuple, Union
|
||
from urllib.parse import urlparse
|
||
|
||
from loguru import logger
|
||
from nonebot import get_driver
|
||
from PIL import Image
|
||
|
||
from ...common.database import Database
|
||
from ..chat.config import global_config
|
||
from ..models.utils_model import LLM_request
|
||
# NoneBot driver and its parsed configuration, resolved once when this module
# is imported.
# NOTE(review): `config` is not referenced elsewhere in this file as shown.
driver = get_driver()
config = driver.config
|
||
|
||
class ImageManager:
    """Process-wide singleton that stores images and caches their descriptions.

    Image files are written under ``IMAGE_DIR``; their metadata goes to the
    ``images`` collection and VLM-generated descriptions to the
    ``image_descriptions`` collection of the project database.

    NOTE(review): every collection call in this class is made synchronously
    (the same ``self.db.db`` handle is used without ``await`` in ``__init__``,
    ``save_image`` and the ``check_*`` helpers), so the driver is presumably
    pymongo rather than an async driver such as motor — confirm against
    ``Database``.
    """

    _instance = None
    IMAGE_DIR = "data"  # root directory for stored image files

    def __new__(cls):
        # Singleton: every ImageManager() call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.db = None
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ImageManager() call; only set up once.
        if not self._initialized:
            self.db = Database.get_instance()
            self._ensure_image_collection()
            self._ensure_description_collection()
            self._ensure_image_dir()
            self._initialized = True
            self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)

    def _ensure_image_dir(self):
        """Create the image storage directory if it does not exist."""
        os.makedirs(self.IMAGE_DIR, exist_ok=True)

    def _ensure_image_collection(self):
        """Ensure the ``images`` collection exists and has its indexes."""
        if 'images' not in self.db.db.list_collection_names():
            self.db.db.create_collection('images')
        self.db.db.images.create_index([('hash', 1)], unique=True)
        self.db.db.images.create_index([('url', 1)])
        self.db.db.images.create_index([('path', 1)])

    def _ensure_description_collection(self):
        """Ensure the ``image_descriptions`` collection exists and has its indexes.

        Descriptions are looked up and upserted by ``(hash, type)``, so the
        unique constraint covers both fields. A unique index on ``hash`` alone
        would reject storing an 'image' description for a hash that already
        has an 'emoji' description (and vice versa).
        """
        if 'image_descriptions' not in self.db.db.list_collection_names():
            self.db.db.create_collection('image_descriptions')
        # NOTE(review): databases created by earlier versions may still carry
        # a unique index on ``hash`` alone; it must be dropped manually for
        # the compound constraint below to be the effective one.
        self.db.db.image_descriptions.create_index([('hash', 1), ('type', 1)], unique=True)
        self.db.db.image_descriptions.create_index([('type', 1)])

    async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
        """Return the cached description of an image, or None if there is none.

        Args:
            image_hash: MD5 hex digest of the image bytes.
            description_type: description kind, 'emoji' or 'image'.

        Returns:
            Optional[str]: the stored description text, or None when absent.
        """
        # The collection API is synchronous (see class docstring); awaiting
        # its return value (a dict or None) would raise TypeError.
        result = self.db.db.image_descriptions.find_one({
            'hash': image_hash,
            'type': description_type,
        })
        return result['description'] if result else None

    async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
        """Insert or update the cached description for ``(image_hash, description_type)``.

        Args:
            image_hash: MD5 hex digest of the image bytes.
            description: description text to store.
            description_type: description kind, 'emoji' or 'image'.
        """
        # Synchronous call: its UpdateResult must not be awaited.
        self.db.db.image_descriptions.update_one(
            {'hash': image_hash, 'type': description_type},
            {'$set': {'description': description, 'timestamp': int(time.time())}},
            upsert=True,
        )

    @staticmethod
    def _to_bytes(image_data: Union[str, bytes], is_base64: bool) -> Optional[bytes]:
        """Convert raw image input to bytes.

        A ``str`` is required when ``is_base64`` is true and ``bytes``
        otherwise; a mismatched type yields None. Base64 decoding errors
        propagate to the caller.
        """
        if is_base64:
            return base64.b64decode(image_data) if isinstance(image_data, str) else None
        return image_data if isinstance(image_data, bytes) else None

    @staticmethod
    def _hash_bytes(image_bytes: bytes) -> str:
        """Return the MD5 hex digest used as the image's deduplication key."""
        return hashlib.md5(image_bytes).hexdigest()

    async def save_image(self,
                         image_data: Union[str, bytes],
                         url: Optional[str] = None,
                         description: Optional[str] = None,
                         is_base64: bool = False) -> Optional[str]:
        """Save an image to disk and record it in the ``images`` collection.

        Images are deduplicated by the MD5 of their bytes. If an image with
        the same hash is already recorded, its stored path is returned and
        nothing is written.

        Args:
            image_data: image data, as a base64 string or raw bytes.
            url: source URL of the image, stored with the record.
            description: image description, stored with the record.
            is_base64: whether ``image_data`` is a base64 string.

        Returns:
            Optional[str]: path of the saved (or previously saved) file; None
            on failure or when ``image_data`` does not match ``is_base64``.
        """
        try:
            image_bytes = self._to_bytes(image_data, is_base64)
            if image_bytes is None:
                return None

            image_hash = self._hash_bytes(image_bytes)

            # Deduplicate by content hash.
            existing = self.db.db.images.find_one({'hash': image_hash})
            if existing:
                return existing['path']

            # NOTE: the extension is always .jpg, whatever the actual format.
            timestamp = int(time.time())
            filename = f"{timestamp}_{image_hash[:8]}.jpg"
            file_path = os.path.join(self.IMAGE_DIR, filename)

            with open(file_path, "wb") as f:
                f.write(image_bytes)

            self.db.db.images.insert_one({
                'hash': image_hash,
                'path': file_path,
                'url': url,
                'description': description,
                'timestamp': timestamp,
            })
            return file_path
        except Exception as e:
            logger.error(f"保存图像失败: {str(e)}")
            return None

    async def get_image_by_url(self, url: str) -> Optional[str]:
        """Return the local path of the image at ``url``, downloading it if needed.

        Args:
            url: image URL.

        Returns:
            Optional[str]: local file path, or None if the download fails
            (non-200 status or any error).
        """
        try:
            existing = self.db.db.images.find_one({'url': url})
            if existing:
                return existing['path']

            async with aiohttp.ClientSession() as session:
                async with session.get(url) as resp:
                    if resp.status != 200:
                        return None
                    image_bytes = await resp.read()
            # Session is closed before the disk/database work in save_image.
            return await self.save_image(image_bytes, url=url)
        except Exception as e:
            logger.error(f"获取图像失败: {str(e)}")
            return None

    async def get_base64_by_url(self, url: str) -> Optional[str]:
        """Return the image at ``url`` as a base64 string (downloaded at most once).

        Args:
            url: image URL.

        Returns:
            Optional[str]: base64 string, or None on failure.
        """
        try:
            image_path = await self.get_image_by_url(url)
            if not image_path:
                return None
            with open(image_path, 'rb') as f:
                image_bytes = f.read()
            return base64.b64encode(image_bytes).decode('utf-8')
        except Exception as e:
            logger.error(f"获取base64失败: {str(e)}")
            return None

    async def save_base64_image(self, base64_str: str, description: Optional[str] = None) -> Optional[str]:
        """Save a base64-encoded image (deduplicated by content hash).

        Args:
            base64_str: base64 image data.
            description: image description stored with the record.

        Returns:
            Optional[str]: saved file path, or None on failure.
        """
        return await self.save_image(base64_str, description=description, is_base64=True)

    def check_url_exists(self, url: str) -> bool:
        """Return whether an image record with this URL exists."""
        return self.db.db.images.find_one({'url': url}) is not None

    def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
        """Return whether an image with the same content hash is already recorded.

        Args:
            image_data: image data, as a base64 string or raw bytes.
            is_base64: whether ``image_data`` is a base64 string.

        Returns:
            bool: True if a record exists; False otherwise, on a type mismatch,
            or on error.
        """
        try:
            image_bytes = self._to_bytes(image_data, is_base64)
            if image_bytes is None:
                return False
            return self.db.db.images.find_one({'hash': self._hash_bytes(image_bytes)}) is not None
        except Exception as e:
            logger.error(f"检查哈希失败: {str(e)}")
            return False

    def _store_described_image(self, image_bytes: bytes, image_hash: str,
                               description: str, description_type: str, label: str) -> None:
        """Write a described image to disk and upsert its ``images`` record.

        Best-effort: failures are logged and swallowed so the caller can still
        return the description.

        Args:
            image_bytes: decoded image bytes.
            image_hash: MD5 hex digest of ``image_bytes``.
            description: description produced by the VLM.
            description_type: 'emoji' or 'image'; used as the filename prefix
                and as the record's ``type``.
            label: noun used in log messages ('表情包' or '图片').
        """
        timestamp = int(time.time())
        filename = f"{description_type}_{timestamp}_{image_hash[:8]}.jpg"
        file_path = os.path.join(self.IMAGE_DIR, filename)
        try:
            with open(file_path, "wb") as f:
                f.write(image_bytes)
            # Upsert by hash so re-describing an image updates its record.
            self.db.db.images.update_one(
                {'hash': image_hash},
                {'$set': {
                    'hash': image_hash,
                    'path': file_path,
                    'type': description_type,
                    'description': description,
                    'timestamp': timestamp,
                }},
                upsert=True,
            )
            logger.success(f"保存{label}: {file_path}")
        except Exception as e:
            logger.error(f"保存{label}文件失败: {str(e)}")

    async def _describe(self, image_base64: str, description_type: str, label: str, prompt: str) -> str:
        """Return ``[label:description]`` for an image, reading and filling the cache.

        On a cache miss the VLM is queried with ``prompt``. When
        ``global_config.EMOJI_SAVE`` is set, the image file is also stored.
        Any failure yields the bare ``[label]`` placeholder.
        """
        try:
            image_bytes = base64.b64decode(image_base64)
            image_hash = self._hash_bytes(image_bytes)

            cached_description = await self._get_description_from_db(image_hash, description_type)
            if cached_description:
                return f"[{label}:{cached_description}]"

            description, _ = await self._llm.generate_response_for_image(prompt, image_base64)

            # NOTE(review): ordinary images are gated by the EMOJI_SAVE flag
            # too — confirm that is intended.
            if global_config.EMOJI_SAVE:
                self._store_described_image(image_bytes, image_hash, description, description_type, label)

            await self._save_description_to_db(image_hash, description, description_type)
            return f"[{label}:{description}]"
        except Exception as e:
            logger.error(f"获取{label}描述失败: {str(e)}")
            return f"[{label}]"

    async def get_emoji_description(self, image_base64: str) -> str:
        """Describe a sticker/emoji image, reusing cached descriptions.

        Returns ``[表情包:<description>]``, or ``[表情包]`` on failure.
        """
        prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
        return await self._describe(image_base64, 'emoji', '表情包', prompt)

    async def get_image_description(self, image_base64: str) -> str:
        """Describe an ordinary image, reusing cached descriptions.

        Returns ``[图片:<description>]``, or ``[图片]`` on failure.
        """
        prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
        return await self._describe(image_base64, 'image', '图片', prompt)
|
||
|
||
|
||
|
||
# Module-level shared instance (ImageManager is a singleton; see its __new__).
image_manager: ImageManager = ImageManager()
|
||
|
||
|
||
def compress_base64_image_by_scale(base64_data: str, target_size: float = 0.8 * 1024 * 1024,
                                   size_threshold: int = 2 * 1024 * 1024) -> str:
    """Downscale a base64-encoded image so its size approaches ``target_size``.

    Images whose decoded size is at most ``size_threshold`` bytes are returned
    unchanged. Larger images have both sides scaled by
    ``sqrt(target_size / original_size)``, because pixel count grows with the
    square of the side length. Animated images are halved again, frame by
    frame, and re-encoded as GIF. Static images are re-encoded as PNG when
    they are PNGs with an alpha channel, otherwise as JPEG.

    Args:
        base64_data: base64-encoded image data.
        target_size: desired size in bytes after compression (default 0.8 MB).
        size_threshold: images up to this many bytes are not compressed
            (default 2 MB).

    Returns:
        str: base64 of the compressed image, or ``base64_data`` unchanged when
        no compression is needed or compression fails.
    """
    try:
        image_data = base64.b64decode(base64_data)

        # Small enough already: skip the decode/re-encode round trip.
        if len(image_data) <= size_threshold:
            return base64_data

        img = Image.open(io.BytesIO(image_data))
        original_width, original_height = img.size

        scale = min(1.0, (target_size / len(image_data)) ** 0.5)
        # Clamp to at least 1 px: Pillow cannot resize to a zero dimension.
        new_width = max(1, int(original_width * scale))
        new_height = max(1, int(original_height * scale))

        output_buffer = io.BytesIO()

        if getattr(img, "is_animated", False):
            # Animated images get an extra halving on top of the scale factor.
            new_width = max(1, new_width // 2)
            new_height = max(1, new_height // 2)
            frames = []
            durations = []
            for frame_idx in range(img.n_frames):
                img.seek(frame_idx)
                # Read the duration after each seek so every frame keeps its
                # own timing instead of inheriting the last frame's value.
                durations.append(img.info.get('duration', 100))
                frames.append(img.copy().resize((new_width, new_height), Image.Resampling.LANCZOS))

            frames[0].save(
                output_buffer,
                format='GIF',
                save_all=True,
                append_images=frames[1:],
                optimize=True,
                duration=durations,
                loop=img.info.get('loop', 0),
            )
        else:
            resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
                resized_img.save(output_buffer, format='PNG', optimize=True)
            else:
                # JPEG cannot store palette ('P') or alpha ('RGBA', 'LA')
                # modes; convert those instead of failing the whole compression.
                if resized_img.mode not in ('RGB', 'L', 'CMYK'):
                    resized_img = resized_img.convert('RGB')
                resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)

        compressed_data = output_buffer.getvalue()
        logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
        logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")

        return base64.b64encode(compressed_data).decode('utf-8')

    except Exception as e:
        logger.error(f"压缩图片失败: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return base64_data
|
||
|
||
def image_path_to_base64(image_path: str) -> Optional[str]:
    """Read an image file and return its contents as a base64 string.

    Args:
        image_path: path of the image file.

    Returns:
        Optional[str]: base64-encoded file contents (as UTF-8 text), or None
        if the file cannot be read; the error is logged.
    """
    try:
        with open(image_path, 'rb') as f:
            image_data = f.read()
    except Exception as e:
        logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}")
        return None
    return base64.b64encode(image_data).decode('utf-8')