diff --git a/changelogs/mai_next_todo.md b/changelogs/mai_next_todo.md index 2a5aa1b3..2c68b177 100644 --- a/changelogs/mai_next_todo.md +++ b/changelogs/mai_next_todo.md @@ -140,6 +140,8 @@ version 0.3.0 - 2026-01-11 - [ ] 指令类型文档 - [ ] 文本说明 - [ ] 代码示例 +## 消息链构建(仿Astrbot模式) +将消息仿照Astrbot的消息链模式进行构建,消息链中的每个元素都是一个消息组件,消息链本身也是一个数据模型,包含了消息组件列表以及一些元信息(如是否为转发消息等)。 ## 表情包系统 - [ ] 移除大量冗余代码,全部返回单一对象MaiEmoji diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 680e1674..c5b40bd9 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -57,11 +57,11 @@ class EmojiManager: logger.info("启动表情包管理器") async def get_emoji_description( - self, emoji_bytes: Optional[bytes] = None, emoji_hash: Optional[str] = None + self, *, emoji_bytes: Optional[bytes] = None, emoji_hash: Optional[str] = None ) -> Optional[Tuple[str, List[str]]]: """ - 根据表情包哈希获取表情包描述的封装方法 - + 根据表情包哈希获取表情包描述和情感列表的封装方法 + Args: emoji_bytes (Optional[bytes]): 表情包的字节数据,如果提供了字节数据但数据库中没有找到对应记录,则会尝试构建表情包描述 emoji_hash (Optional[str]): 表情包的哈希值,如果提供了哈希值则优先使用哈希值查找表情包描述 diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index 92f5168b..ec545894 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -36,7 +36,9 @@ class ImageManager: logger.info("图片管理器初始化完成") - async def get_image_description(self, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None) -> str: + async def get_image_description( + self, *, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None + ) -> str: """ 获取图片描述的封装方法 @@ -82,7 +84,7 @@ class ImageManager: def get_image_from_db(self, image_hash: str) -> Optional[MaiImage]: """ 从数据库中根据图片哈希值获取图片记录 - + """ with get_db_session() as session: statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1) @@ -262,3 +264,6 @@ class ImageManager: if not description: logger.warning("VLM未能生成图片描述") return description or "" + + +image_manager = ImageManager() diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py deleted file mode 100644 index 65eec4f0..00000000 --- a/src/chat/utils/utils_image.py +++ /dev/null @@ -1,787 +0,0 @@ -import base64 -from datetime import datetime -from typing import Optional, Tuple - -import hashlib -import io -import os -import time -import uuid - -import numpy as np -from PIL import Image -from rich.traceback import install - -from sqlmodel import select, col -from src.common.logger import get_logger -from src.common.database.database import get_db_session -from src.common.database.database_model import Images, ImageType -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest - -install(extra_lines=3) - -logger = get_logger("chat_image") - - -class ImageManager: - _instance = None - IMAGE_DIR = "data" # 图像存储根目录 - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if not self._initialized: - self._ensure_image_dir() - - self._initialized = True - self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") - - get_db_session() - - try: - self._cleanup_invalid_descriptions() - except Exception as e: - logger.warning(f"数据库清理失败: {e}") - - try: - self._cleanup_emoji_from_image_descriptions() - except Exception as e: - logger.warning(f"清理ImageDescriptions中的emoji记录失败: {e}") - - self._initialized = True - - def _ensure_image_dir(self): - """确保图像存储目录存在""" - os.makedirs(self.IMAGE_DIR, exist_ok=True) - - @staticmethod - def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: - """从数据库获取图片描述 - - Args: - image_hash: 图片哈希值 - description_type: 描述类型 ('emoji' 或 'image') - - Returns: - Optional[str]: 描述文本,如果不存在则返回None - """ - try: - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type)) - ) - record = session.exec(statement).first() - return record.description if record else None - except Exception as e: - logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") - return None - - @staticmethod - def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: - """保存图片描述到数据库 - - Args: - image_hash: 图片哈希值 - description: 描述文本 - description_type: 描述类型 ('emoji' 或 'image') - """ - try: - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type)) - ) - record = session.exec(statement).first() - if record: - record.description = description - session.add(record) - return - - new_record = Images( - image_hash=image_hash, - description=description, - full_path="", - image_type=ImageType(description_type), - query_count=0, - is_registered=False, - is_banned=False, - vlm_processed=True, - ) - session.add(new_record) - except Exception as e: - logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") - - @staticmethod - def _cleanup_invalid_descriptions(): - """清理数据库中 description 为空或为 'None' 的记录""" - invalid_values = ["", "None"] - - with get_db_session() as session: - statement = ( - select(Images) - .where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values)) - .limit(1000) - ) - records = session.exec(statement).all() - for record in records: - session.delete(record) - - if records: - logger.info(f"[清理完成] 删除 Images: {len(records)} 条") - else: - logger.info("[清理完成] 未发现无效描述记录") - - @staticmethod - def _cleanup_emoji_from_image_descriptions(): - """清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)""" - try: - with get_db_session() as session: - statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI) - records = session.exec(statement).all() - for record in records: - session.delete(record) - - total_deleted = len(records) - if total_deleted > 0: - logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录") - else: - logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录") - except Exception as e: - logger.error(f"清理Images和ImageDescriptions中的emoji记录时出错: {str(e)}") - raise - - async def get_emoji_tag(self, image_base64: str) -> str: - from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance - - emoji_manager = emoji_manager_instance - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - emoji = emoji_manager.get_emoji_by_hash(image_hash) - if not emoji: - return "[表情包:未知]" - emotion_list = emoji.emotion - tag_str = ",".join(emotion_list) - return f"[表情包:{tag_str}]" - - async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None: - """如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录 - - Args: - image_base64: 图片的base64编码 - image_hash: 图片的MD5哈希值 - image_format: 图片格式 - """ - if not global_config.emoji.steal_emoji: - return - - try: - from src.chat.emoji_system.emoji_manager import EMOJI_DIR - from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance - - # 确保目录存在 - os.makedirs(EMOJI_DIR, exist_ok=True) - - # 检查是否已存在该表情包(通过哈希值) - emoji_manager = emoji_manager_instance - existing_emoji = emoji_manager.get_emoji_by_hash(image_hash) - if existing_emoji: - logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...") - return - - # 生成文件名:使用哈希值前8位 + 格式 - filename = f"{image_hash[:8]}.{image_format}" - file_path = os.path.join(EMOJI_DIR, filename) - - # 检查文件是否已存在(可能之前保存过但未注册) - if not os.path.exists(file_path): - # 保存文件 - if base64_to_image(image_base64, file_path): - logger.info(f"[自动保存] 表情包已保存到 {file_path} (Hash: {image_hash[:8]}...)") - else: - logger.warning(f"[自动保存] 保存表情包文件失败: {file_path}") - else: - logger.debug(f"[自动保存] 表情包文件已存在,跳过: {file_path}") - except Exception as save_error: - logger.warning(f"[自动保存] 保存表情包文件时出错: {save_error}") - - async def get_emoji_description(self, image_base64: str) -> str: - """获取表情包描述,优先使用EmojiDescriptionCache表中的缓存数据""" - try: - # 计算图片哈希 - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore - - # 优先使用EmojiManager查询已注册表情包的描述 - try: - from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance - - emoji_manager = emoji_manager_instance - emoji = emoji_manager.get_emoji_by_hash(image_hash) - tags = emoji.emotion if emoji else None - if tags: - tag_str = ",".join(tags) - logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...") - return f"[表情包:{tag_str}]" - except Exception as e: - logger.debug(f"查询EmojiManager时出错: {e}") - - try: - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) - ) - cache_record = session.exec(statement).first() - if cache_record: - result_text = "" - if cache_record.emotion: - logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...") - result_text = f"[表情包:{cache_record.emotion}]" - elif cache_record.description: - logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...") - result_text = f"[表情包:{cache_record.description}]" - - if result_text: - await self._save_emoji_file_if_needed(image_base64, image_hash, image_format) - return result_text - except Exception as e: - logger.debug(f"查询Images缓存时出错: {e}") - - # === 二步走识别流程 === - - # 第一步:VLM视觉分析 - 生成详细描述 - if image_format in ["gif", "GIF"]: - image_base64_processed = self.transform_gif(image_base64) - if image_base64_processed is None: - logger.warning("GIF转换失败,无法获取描述") - return "[表情包(GIF处理失败)]" - vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - detailed_description, _ = await self.vlm.generate_response_for_image( - vlm_prompt, image_base64_processed, "jpg", temperature=0.4 - ) - else: - vlm_prompt = ( - "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - ) - detailed_description, _ = await self.vlm.generate_response_for_image( - vlm_prompt, image_base64, image_format, temperature=0.4 - ) - - if detailed_description is None: - logger.warning("VLM未能生成表情包详细描述") - return "[表情包(VLM描述生成失败)]" - - # 第二步:LLM情感分析 - 基于详细描述生成简短的情感标签 - emotion_prompt = f""" - 请你基于这个表情包的详细描述,提取出最核心的情感含义,用1-2个词概括。 - 详细描述:'{detailed_description}' - - 要求: - 1. 只输出1-2个最核心的情感词汇 - 2. 从互联网梗、meme的角度理解 - 3. 输出简短精准,不要解释 - 4. 如果有多个词用逗号分隔 - """ - - # 使用较低温度确保输出稳定 - emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji") - emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt, temperature=0.3) - - if not emotion_result: - logger.warning("LLM未能生成情感标签,使用详细描述的前几个词") - # 降级处理:从详细描述中提取关键词 - import jieba - - words = list(jieba.cut(detailed_description)) - emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情") - - # 处理情感结果,取前1-2个最重要的标签 - emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] - final_emotion = emotions[0] if emotions else "表情" - - # 如果有第二个情感且不重复,也包含进来 - if len(emotions) > 1 and emotions[1] != emotions[0]: - final_emotion = f"{emotions[0]},{emotions[1]}" - - logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}") - - try: - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) - ) - cache_record = session.exec(statement).first() - if cache_record and cache_record.emotion: - logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}") - return f"[表情包:{cache_record.emotion}]" - except Exception as e: - logger.debug(f"再次查询Images缓存时出错: {e}") - - try: - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) - ) - cache_record = session.exec(statement).first() - if cache_record: - cache_record.description = detailed_description - cache_record.emotion = final_emotion - session.add(cache_record) - else: - cache_record = Images( - image_hash=image_hash, - description=detailed_description, - full_path="", - image_type=ImageType.EMOJI, - emotion=final_emotion, - query_count=0, - is_registered=False, - is_banned=False, - vlm_processed=True, - ) - session.add(cache_record) - logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...") - except Exception as e: - logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}") - - # 如果启用了steal_emoji,自动保存表情包文件到data/emoji目录 - await self._save_emoji_file_if_needed(image_base64, image_hash, image_format) - - return f"[表情包:{final_emotion}]" - - except Exception as e: - logger.error(f"获取表情包描述失败: {str(e)}") - return "[表情包(处理失败)]" - - async def get_image_description(self, image_base64: str) -> str: - """获取普通图片描述,优先使用Images表中的缓存数据""" - try: - # 计算图片哈希 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - - # 优先检查Images表中是否已有完整的描述 - with get_db_session() as session: - statement = select(Images).where(col(Images.image_hash) == image_hash) - existing_image = session.exec(statement).first() - if existing_image: - existing_image.query_count += 1 - with get_db_session() as session: - session.add(existing_image) - - # 如果已有描述,直接返回 - if existing_image.description: - logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...") - return f"[图片:{existing_image.description}]" - - if cached_description := self._get_description_from_db(image_hash, "image"): - logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") - return f"[图片:{cached_description}]" - - # 调用AI获取描述 - image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore - prompt = global_config.personality.visual_style - logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.4 - ) - - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片(描述生成失败)]" - - # 保存图片和描述 - current_timestamp = time.time() - filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" - image_dir = os.path.join(self.IMAGE_DIR, "image") - os.makedirs(image_dir, exist_ok=True) - file_path = os.path.join(image_dir, filename) - - try: - # 保存文件 - with open(file_path, "wb") as f: - f.write(image_bytes) - - # 保存到数据库,补充缺失字段 - if existing_image: - existing_image.full_path = file_path - existing_image.description = description - existing_image.record_time = datetime.fromtimestamp(current_timestamp) - existing_image.vlm_processed = True - with get_db_session() as session: - session.add(existing_image) - logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...") - else: - with get_db_session() as session: - new_record = Images( - image_hash=image_hash, - description=description, - full_path=file_path, - image_type=ImageType.IMAGE, - query_count=1, - is_registered=False, - is_banned=False, - record_time=datetime.fromtimestamp(current_timestamp), - vlm_processed=True, - ) - session.add(new_record) - logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") - except Exception as e: - logger.error(f"保存图片文件或元数据失败: {str(e)}") - - # 保存描述到ImageDescriptions表作为备用缓存 - self._save_description_to_db(image_hash, description, "image") - - logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...") - return f"[图片:{description}]" - except Exception as e: - logger.error(f"获取图片描述失败: {str(e)}") - return "[图片(处理失败)]" - - @staticmethod - def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]: - # sourcery skip: use-contextlib-suppress - """将GIF转换为水平拼接的静态图像, 跳过相似的帧 - - Args: - gif_base64: GIF的base64编码字符串 - similarity_threshold: 判定帧相似的阈值 (MSE),越小表示要求差异越大才算不同帧,默认1000.0 - max_frames: 最大抽取的帧数,默认15 - - Returns: - Optional[str]: 拼接后的JPG图像的base64编码字符串, 或者在失败时返回None - """ - try: - # 确保base64字符串只包含ASCII字符 - if isinstance(gif_base64, str): - gif_base64 = gif_base64.encode("ascii", errors="ignore").decode("ascii") - # 解码base64 - gif_data = base64.b64decode(gif_base64) - gif = Image.open(io.BytesIO(gif_data)) - - # 收集所有帧 - all_frames = [] - try: - while True: - gif.seek(len(all_frames)) - # 确保是RGB格式方便比较 - frame = gif.convert("RGB") - all_frames.append(frame.copy()) - except EOFError: - pass # 读完啦 - - if not all_frames: - logger.warning("GIF中没有找到任何帧") - return None # 空的GIF直接返回None - - # --- 新的帧选择逻辑 --- - selected_frames = [] - last_selected_frame_np = None - - for i, current_frame in enumerate(all_frames): - current_frame_np = np.array(current_frame) - - # 第一帧总是要选的 - if i == 0: - selected_frames.append(current_frame) - last_selected_frame_np = current_frame_np - continue - - # 计算和上一张选中帧的差异(均方误差 MSE) - if last_selected_frame_np is not None: - mse = np.mean((current_frame_np - last_selected_frame_np) ** 2) - # logger.debug(f"帧 {i} 与上一选中帧的 MSE: {mse}") # 可以取消注释来看差异值 - - # 如果差异够大,就选它! - if mse > similarity_threshold: - selected_frames.append(current_frame) - last_selected_frame_np = current_frame_np - # 检查是不是选够了 - if len(selected_frames) >= max_frames: - # logger.debug(f"已选够 {max_frames} 帧,停止选择。") - break - # 如果差异不大就跳过这一帧啦 - - # --- 帧选择逻辑结束 --- - - # 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None - if not selected_frames: - logger.warning("处理后没有选中任何帧") - return None - - # logger.debug(f"总帧数: {len(all_frames)}, 选中帧数: {len(selected_frames)}") - - # 获取选中的第一帧的尺寸(假设所有帧尺寸一致) - frame_width, frame_height = selected_frames[0].size - - # 计算目标尺寸,保持宽高比 - target_height = 200 # 固定高度 - # 防止除以零 - if frame_height == 0: - logger.error("帧高度为0,无法计算缩放尺寸") - return None - target_width = int((target_height / frame_height) * frame_width) - # 宽度也不能是0 - if target_width == 0: - logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height}),调整为1") - target_width = 1 - - # 调整所有选中帧的大小 - resized_frames = [ - frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames - ] - - # 创建拼接图像 - total_width = target_width * len(resized_frames) - # 防止总宽度为0 - if total_width == 0 and resized_frames: - logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小") - # 至少给点宽度吧 - total_width = len(resized_frames) - elif total_width == 0: - logger.error("计算出的总宽度为0且无选中帧") - return None - - combined_image = Image.new("RGB", (total_width, target_height)) - - # 水平拼接图像 - for idx, frame in enumerate(resized_frames): - combined_image.paste(frame, (idx * target_width, 0)) - - # 转换为base64 - buffer = io.BytesIO() - combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG - return base64.b64encode(buffer.getvalue()).decode("utf-8") - except MemoryError: - logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多") - return None # 内存不够啦 - except Exception as e: - logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息 - return None # 其他错误也返回None - - async def process_image(self, image_base64: str) -> Tuple[str, str]: - # sourcery skip: hoist-if-from-if - """处理图片并返回图片ID和描述 - - Args: - image_base64: 图片的base64编码 - - Returns: - Tuple[str, str]: (图片ID, 描述) - """ - try: - # 生成图片ID - # 计算图片哈希 - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE) - ) - existing_image = session.exec(statement).first() - if existing_image: - existing_image.query_count += 1 - session.add(existing_image) - return str(existing_image.id), f"[picid:{existing_image.id}]" - - image_id = str(uuid.uuid4()) - - # 保存新图片 - current_timestamp = time.time() - image_dir = os.path.join(self.IMAGE_DIR, "images") - os.makedirs(image_dir, exist_ok=True) - filename = f"{image_id}.png" - file_path = os.path.join(image_dir, filename) - - # 保存文件 - with open(file_path, "wb") as f: - f.write(image_bytes) - - # 保存到数据库 - with get_db_session() as session: - new_record = Images( - image_hash=image_hash, - description="", - full_path=file_path, - image_type=ImageType.IMAGE, - query_count=1, - is_registered=False, - is_banned=False, - record_time=datetime.fromtimestamp(current_timestamp), - vlm_processed=False, - ) - session.add(new_record) - - # 启动异步VLM处理 - await self._process_image_with_vlm(image_id, image_base64) - - return image_id, f"[picid:{image_id}]" - - except Exception as e: - logger.error(f"处理图片失败: {str(e)}") - return "", "[图片]" - - async def _process_image_with_vlm(self, image_id: str, image_base64: str) -> None: - """使用VLM处理图片并更新数据库 - - Args: - image_id: 图片ID - image_base64: 图片的base64编码 - """ - try: - # 计算图片哈希 - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - image_bytes = base64.b64decode(image_base64) - image_hash = hashlib.md5(image_bytes).hexdigest() - - # 获取当前图片记录 - with get_db_session() as session: - image = session.get(Images, int(image_id)) if image_id.isdigit() else None - if image is None: - logger.warning(f"未找到图片记录: {image_id}") - return - - # 优先检查是否已有其他相同哈希的图片记录包含描述 - with get_db_session() as session: - statement = select(Images).where( - (col(Images.image_hash) == image_hash) - & (col(Images.description).is_not(None)) - & (col(Images.description) != "") - ) - existing_with_description = session.exec(statement).first() - if existing_with_description and existing_with_description.id != image.id: - logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") - image.description = existing_with_description.description - image.vlm_processed = True - with get_db_session() as session: - session.add(image) - # 同时保存到ImageDescriptions表作为备用缓存 - self._save_description_to_db(image_hash, existing_with_description.description, "image") - return - - # 检查ImageDescriptions表的缓存描述 - if cached_description := self._get_description_from_db(image_hash, "image"): - logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") - image.description = cached_description - image.vlm_processed = True - with get_db_session() as session: - session.add(image) - return - - # 获取图片格式 - image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore - - # 构建prompt - prompt = global_config.personality.visual_style - - # 获取VLM描述 - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.4 - ) - - if description is None: - logger.warning("VLM未能生成图片描述") - description = "" - - if cached_description := self._get_description_from_db(image_hash, "image"): - logger.info(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}") - description = cached_description - - # 更新数据库 - image.description = description - image.vlm_processed = True - with get_db_session() as session: - session.add(image) - - # 保存描述到ImageDescriptions表作为备用缓存 - self._save_description_to_db(image_hash, description, "image") - - except Exception as e: - logger.error(f"VLM处理图片失败: {str(e)}") - - -# 创建全局单例 -image_manager = None - - -def get_image_manager() -> ImageManager: - """获取全局图片管理器单例""" - global image_manager - if image_manager is None: - image_manager = ImageManager() - return image_manager - - -def image_path_to_base64(image_path: str) -> str: - """将图片路径转换为base64编码 - Args: - image_path: 图片文件路径 - Returns: - str: base64编码的图片数据 - Raises: - FileNotFoundError: 当图片文件不存在时 - IOError: 当读取图片文件失败时 - """ - if not os.path.exists(image_path): - raise FileNotFoundError(f"图片文件不存在: {image_path}") - - with open(image_path, "rb") as f: - if image_data := f.read(): - return base64.b64encode(image_data).decode("utf-8") - else: - raise IOError(f"读取图片文件失败: {image_path}") - - -def base64_to_image(image_base64: str, output_path: str) -> bool: - """将base64编码的图片保存为文件 - - Args: - image_base64: 图片的base64编码 - output_path: 输出文件路径 - - Returns: - bool: 是否成功保存 - - Raises: - ValueError: 当base64编码无效时 - IOError: 当保存文件失败时 - """ - try: - # 确保base64字符串只包含ASCII字符 - if isinstance(image_base64, str): - image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") - - # 解码base64 - image_bytes = base64.b64decode(image_base64) - - # 确保输出目录存在 - output_dir = os.path.dirname(output_path) - if output_dir: - os.makedirs(output_dir, exist_ok=True) - - # 保存文件 - with open(output_path, "wb") as f: - f.write(image_bytes) - - return True - - except Exception as e: - logger.error(f"保存base64图片失败: {e}") - return False diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py deleted file mode 100644 index 49ec1079..00000000 --- a/src/chat/utils/utils_voice.py +++ /dev/null @@ -1,29 +0,0 @@ -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest - -from src.common.logger import get_logger -from rich.traceback import install - -install(extra_lines=3) - -logger = get_logger("chat_voice") - - -async def get_voice_text(voice_base64: str) -> str: - """获取音频文件转录文本""" - if not global_config.voice.enable_asr: - logger.warning("语音识别未启用,无法处理语音消息") - return "[语音]" - try: - _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio") - text = await _llm.generate_response_for_voice(voice_base64) - if text is None: - logger.warning("未能生成语音文本") - return "[语音(文本生成失败)]" - - logger.debug(f"描述是{text}") - - return f"[语音:{text}]" - except Exception as e: - logger.error(f"语音转文字失败: {str(e)}") - return "[语音]" diff --git a/src/common/utils/utils_person.py b/src/common/utils/utils_person.py new file mode 100644 index 00000000..6089e2d2 --- /dev/null +++ b/src/common/utils/utils_person.py @@ -0,0 +1,39 @@ +from rich.traceback import install +from sqlmodel import select +from typing import Optional + +import hashlib + +from src.common.logger import get_logger +from src.common.data_models.person_info_data_model import MaiPersonInfo +from src.common.database.database_model import PersonInfo +from src.common.database.database import get_db_session + +install(extra_lines=3) + +logger = get_logger("person_utils") + + +class PersonUtils: + @staticmethod + def get_person_info_by_id(person_id: str) -> Optional[MaiPersonInfo]: + """根据person_id获取用户信息""" + try: + with get_db_session() as session: + statement = select(PersonInfo).filter_by(person_id=person_id) + if result := session.exec(statement).first(): + return MaiPersonInfo.from_db_instance(result) + except Exception as e: + logger.error(f"查询用户信息失败: {str(e)}") + return None + + @staticmethod + def calculate_person_id(platform: str, user_id: str) -> str: + """根据平台和用户ID计算person_id""" + return hashlib.sha256(f"{platform}_{user_id}".encode("utf-8")).hexdigest() + + @staticmethod + def get_person_info_by_user_id_and_platform(user_id: str, platform: str) -> Optional[MaiPersonInfo]: + """根据user_id和platform获取用户信息""" + person_id = PersonUtils.calculate_person_id(platform, user_id) + return PersonUtils.get_person_info_by_id(person_id) diff --git a/src/common/utils/utils_voice.py b/src/common/utils/utils_voice.py new file mode 100644 index 00000000..fafe19a0 --- /dev/null +++ b/src/common/utils/utils_voice.py @@ -0,0 +1,42 @@ +from rich.traceback import install +from typing import Optional + +import base64 + +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + + +install(extra_lines=3) + +logger = get_logger("chat_voice") + +# TODO: 在LLMRequest重构后修改这里 +asr_model = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio") + + +async def get_voice_text(voice_bytes: bytes) -> Optional[str]: + """ + 获取音频文件转录文本 + + Args: + voice_bytes (bytes): 语音消息的字节数据 + Returns: + return (Optional[str]): 转录后的文本描述,如果转录失败或未启用语音识别功能,则返回 None + """ + if not global_config.voice.enable_asr: + logger.warning("语音识别未启用,无法处理语音消息") + return None + try: + voice_base64 = base64.b64encode(voice_bytes).decode("utf-8") + text = await asr_model.generate_response_for_voice(voice_base64) + if not text: + logger.warning("语音转文字结果为空") + + # logger.debug(f"转录结果是是{text}") + + return text + except Exception as e: + logger.error(f"语音转文字失败: {str(e)}") + return None