From 874b81c0e6589a966029779455824b2890b8b526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 13 Nov 2025 21:22:39 +0800 Subject: [PATCH] Add msgpack and zstandard dependencies; implement MMIPKG tool for emoji packaging and unpacking - Updated requirements.txt to include msgpack and zstandard libraries. - Implemented the MMIPKG tool for exporting and importing registered emojis in MaiBot. - Added functionality for image processing, SHA256 calculation, and database interaction. - Included error handling and user prompts for a better interactive experience. --- requirements.txt | 4 +- scripts/mmipkg_tool.py | 1125 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1128 insertions(+), 1 deletion(-) create mode 100644 scripts/mmipkg_tool.py diff --git a/requirements.txt b/requirements.txt index 30328554..8a0d22c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,6 @@ structlog>=25.4.0 toml>=0.10.2 tomlkit>=0.13.3 urllib3>=2.5.0 -uvicorn>=0.35.0 \ No newline at end of file +uvicorn>=0.35.0 +msgpack +zstandard \ No newline at end of file diff --git a/scripts/mmipkg_tool.py b/scripts/mmipkg_tool.py new file mode 100644 index 00000000..6f14ae5b --- /dev/null +++ b/scripts/mmipkg_tool.py @@ -0,0 +1,1125 @@ +#!/usr/bin/env python3 +""" +.mmipkg 表情包打包工具 +用于导入/导出 MaiBot 已注册表情包 + +版本:1.0 +日期:2025-11-13 +""" + +import hashlib +import io +import os +import struct +import sys +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import BinaryIO, Dict, List, Optional, Tuple + +try: + import msgpack +except ImportError: + print("错误: 需要安装 msgpack 库") + print("请运行: pip install msgpack") + sys.exit(1) + +try: + import zstandard as zstd +except ImportError: + print("警告: zstandard 库未安装,将不支持压缩功能") + print("建议安装: pip install zstandard") + zstd = None + +try: + from PIL import Image +except ImportError: + print("错误: 需要安装 Pillow 库") + print("请运行: pip install Pillow") + sys.exit(1) + +try: + from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, TimeElapsedColumn + from rich.console import Console +except ImportError: + print("错误: 需要安装 rich 库") + print("请运行: pip install rich") + sys.exit(1) + +# 添加项目路径 +PROJECT_ROOT = Path(__file__).parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +# 延迟导入数据库模块(在添加路径之后) +# ruff: noqa: E402 +from src.common.database.database import db +from src.common.database.database_model import Emoji + +# 常量定义 +MAGIC = b'MMIP' +FOOTER_MAGIC = b'MMFF' +VERSION = 1 +FOOTER_VERSION = 1 + +# 安全限制 +MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB +MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB + +# 支持的图片格式 +SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'} + +# 创建控制台对象 +console = Console() + + +class MMIPKGError(Exception): + """MMIPKG 相关错误""" + pass + + +def calculate_sha256(data: bytes) -> bytes: + """计算 SHA256 并返回二进制结果""" + return hashlib.sha256(data).digest() + + +def calculate_file_sha256(file_path: str) -> bytes: + """计算文件的 SHA256""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.digest() + + +def get_image_info(file_path: str) -> Tuple[int, int, str]: + """获取图片信息 (width, height, mime_type)""" + try: + with Image.open(file_path) as img: + width, height = img.size + format_lower = img.format.lower() if img.format else 'unknown' + mime_map = { + 'jpeg': 'image/jpeg', + 'jpg': 'image/jpeg', + 'png': 'image/png', + 'gif': 'image/gif', + 'webp': 'image/webp', + 'avif': 'image/avif', + 'bmp': 'image/bmp' + } + mime_type = mime_map.get(format_lower, f'image/{format_lower}') + return width, height, mime_type + except Exception as e: + print(f"警告: 无法读取图片信息 {file_path}: {e}") + return 0, 0, 'image/unknown' + + +def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes: + """重新编码图片""" + try: + with Image.open(file_path) as img: + # 转换为 RGB(如果需要) + if img.mode in ('RGBA', 'LA', 'P'): + if output_format.lower() == 'jpeg': + # JPEG 不支持透明度,转为白色背景 + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) + img = background + elif output_format.lower() == 'webp': + # WebP 支持透明度 + if img.mode == 'P': + img = img.convert('RGBA') + elif img.mode not in ('RGB', 'RGBA'): + img = img.convert('RGB') + + # 编码图片 + output = io.BytesIO() + save_kwargs = {'format': output_format.upper()} + + if output_format.lower() in {'jpeg', 'jpg'}: + save_kwargs['quality'] = quality + save_kwargs['optimize'] = True + elif output_format.lower() == 'webp': + save_kwargs['quality'] = quality + save_kwargs['method'] = 6 # 更好的压缩 + elif output_format.lower() == 'png': + save_kwargs['optimize'] = True + + img.save(output, **save_kwargs) + return output.getvalue() + except Exception as e: + raise MMIPKGError(f"重新编码图片失败 {file_path}: {e}") from e + + +class MMIPKGPacker: + """MMIPKG 打包器""" + + def __init__(self, + use_compression: bool = True, + zstd_level: int = 3, + reencode: Optional[str] = None, + reencode_quality: int = 80): + self.use_compression = use_compression and zstd is not None + self.zstd_level = zstd_level + self.reencode = reencode + self.reencode_quality = reencode_quality + + if use_compression and zstd is None: + print("警告: zstandard 未安装,将不使用压缩") + self.use_compression = False + + def pack_from_db(self, output_path: str, pack_name: Optional[str] = None, + custom_manifest: Optional[Dict] = None) -> bool: + """从数据库导出已注册的表情包 + + Args: + output_path: 输出文件路径 + pack_name: 包名称 + custom_manifest: 自定义 manifest 额外字段(可选) + """ + try: + # 连接数据库 + if db.is_closed(): + db.connect() + + # 查询所有已注册的表情包 + emojis = Emoji.select().where(Emoji.is_registered) + emoji_count = emojis.count() + + if emoji_count == 0: + print("错误: 数据库中没有已注册的表情包") + return False + + print(f"找到 {emoji_count} 个已注册的表情包") + + # 准备 items + items = [] + image_data_list = [] + + # 使用进度条处理表情包 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console + ) as progress: + task = progress.add_task("[cyan]扫描表情包...", total=emoji_count) + + for idx, emoji in enumerate(emojis, 1): + progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}") + + # 检查文件是否存在 + if not os.path.exists(emoji.full_path): + console.print(" [yellow]警告: 文件不存在,跳过[/yellow]") + progress.advance(task) + continue + + # 读取或重新编码图片 + if self.reencode: + try: + img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality) + except Exception as e: + console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]") + with open(emoji.full_path, 'rb') as f: + img_bytes = f.read() + else: + with open(emoji.full_path, 'rb') as f: + img_bytes = f.read() + + # 计算 SHA256 + img_sha = calculate_sha256(img_bytes) + + # 获取图片信息 + width, height, mime_type = get_image_info(emoji.full_path) + + # 构建 item(使用短字段名) + filename = os.path.basename(emoji.full_path) + item = { + "i": str(idx).zfill(5), # id + "fn": filename, # filename + "s": len(img_bytes), # size + "h": img_sha, # sha256 (binary) + "m": mime_type, # mime + "w": width, # width + "ht": height, # height + "opt": { + # 存储 MaiBot 特有的元数据 - 完整的数据库信息 + "desc": emoji.description or "", + "emotion": emoji.emotion or "", + "usage_count": emoji.usage_count or 0, + "last_used_time": emoji.last_used_time or time.time(), + "register_time": emoji.register_time or time.time(), + "record_time": emoji.record_time or time.time(), + "query_count": emoji.query_count or 0, + "format": emoji.format or "", + "emoji_hash": emoji.emoji_hash or "", + "is_registered": True, + "is_banned": emoji.is_banned or False, + } + } + + items.append(item) + image_data_list.append(img_bytes) + progress.advance(task) + + if not items: + print("错误: 没有有效的表情包可以打包") + return False + + print(f"找到 {len(items)} 个表情包可以打包...") + + # 准备打包 + pack_id = str(uuid.uuid4()) + if pack_name is None: + pack_name = f"MaiBot_Emojis_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + manifest = { + "p": pack_id, # pack_id + "n": pack_name, # pack_name + "t": datetime.now().isoformat(), # created_at + "a": items # items array + } + + # 添加自定义字段 + if custom_manifest: + for key, value in custom_manifest.items(): + if key not in manifest: # 不覆盖核心字段 + manifest[key] = value + + # 序列化 manifest + manifest_bytes = msgpack.packb(manifest, use_bin_type=True) + manifest_len = len(manifest_bytes) + + # 计算 payload 大小 + payload_size = 4 + manifest_len # manifest_len + manifest_bytes + for img_bytes in image_data_list: + payload_size += 4 + len(img_bytes) # img_len + img_bytes + + print(f"Manifest 大小: {manifest_len / 1024:.2f} KB") + print(f"Payload 未压缩大小: {payload_size / 1024 / 1024:.2f} MB") + + # 写入文件 + return self._write_package(output_path, manifest_bytes, image_data_list, payload_size) + + except Exception as e: + print(f"打包失败: {e}") + import traceback + traceback.print_exc() + return False + finally: + if not db.is_closed(): + db.close() + + def _write_package(self, output_path: str, manifest_bytes: bytes, + image_data_list: List[bytes], payload_size: int) -> bool: + """写入打包文件""" + try: + with open(output_path, 'wb') as f: + # 写入 Header (32 bytes) + flags = 0x01 if self.use_compression else 0x00 + header = MAGIC # 4 bytes + header += struct.pack('B', VERSION) # 1 byte + header += struct.pack('B', flags) # 1 byte + header += b'\x00\x00' # 2 bytes reserved + header += struct.pack('>Q', payload_size) # 8 bytes + header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes + header += b'\x00' * 8 # 8 bytes reserved + + assert len(header) == 32, f"Header size mismatch: {len(header)}" + f.write(header) + + # 准备 payload 并计算 SHA256 + payload_sha = hashlib.sha256() + + # 写入 payload(可能压缩) + if self.use_compression: + console.print(f"[cyan]使用 Zstd 压缩 (level={self.zstd_level})...[/cyan]") + compressor = zstd.ZstdCompressor(level=self.zstd_level) + + with compressor.stream_writer(f, closefd=False) as writer: + # 写入 manifest + manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) + writer.write(manifest_len_bytes) + writer.write(manifest_bytes) + payload_sha.update(manifest_len_bytes) + payload_sha.update(manifest_bytes) + + # 使用进度条写入所有图片 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + console=console + ) as progress: + task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list)) + + for idx, img_bytes in enumerate(image_data_list, 1): + progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}") + img_len_bytes = struct.pack('>I', len(img_bytes)) + writer.write(img_len_bytes) + writer.write(img_bytes) + payload_sha.update(img_len_bytes) + payload_sha.update(img_bytes) + progress.advance(task) + else: + # 不压缩,直接写入 + # 写入 manifest + manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) + f.write(manifest_len_bytes) + f.write(manifest_bytes) + payload_sha.update(manifest_len_bytes) + payload_sha.update(manifest_bytes) + + # 使用进度条写入所有图片 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + console=console + ) as progress: + task = progress.add_task("[green]写入图片...", total=len(image_data_list)) + + for idx, img_bytes in enumerate(image_data_list, 1): + progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}") + img_len_bytes = struct.pack('>I', len(img_bytes)) + f.write(img_len_bytes) + f.write(img_bytes) + payload_sha.update(img_len_bytes) + payload_sha.update(img_bytes) + progress.advance(task) + + # 写入 Footer (40 bytes) + file_sha256 = payload_sha.digest() + footer = FOOTER_MAGIC # 4 bytes + footer += file_sha256 # 32 bytes + footer += struct.pack('B', FOOTER_VERSION) # 1 byte + footer += b'\x00' * 3 # 3 bytes reserved + + assert len(footer) == 40, f"Footer size mismatch: {len(footer)}" + f.write(footer) + + file_size = f.tell() + print("\n打包完成!") + print(f"输出文件: {output_path}") + print(f"文件大小: {file_size / 1024 / 1024:.2f} MB") + if self.use_compression: + ratio = (1 - file_size / (payload_size + 32 + 40)) * 100 + print(f"压缩率: {ratio:.1f}%") + + return True + + except Exception as e: + print(f"写入文件失败: {e}") + import traceback + traceback.print_exc() + return False + + +class MMIPKGUnpacker: + """MMIPKG 解包器""" + + def __init__(self, verify_sha: bool = True): + self.verify_sha = verify_sha + + def import_to_db(self, package_path: str, + output_dir: Optional[str] = None, + replace_existing: bool = False, + batch_size: int = 500) -> bool: + """导入到数据库""" + try: + if not os.path.exists(package_path): + print(f"错误: 文件不存在: {package_path}") + return False + + # 连接数据库 + if db.is_closed(): + db.connect() + + # 如果未指定输出目录,使用默认的已注册表情包目录 + if output_dir is None: + output_dir = os.path.join(PROJECT_ROOT, "data", "emoji_registed") + + os.makedirs(output_dir, exist_ok=True) + + print(f"正在读取包: {package_path}") + + with open(package_path, 'rb') as f: + # 读取 Header + header = f.read(32) + if len(header) != 32: + raise MMIPKGError("Header 大小不正确") + + magic = header[:4] + if magic != MAGIC: + raise MMIPKGError(f"无效的 MAGIC: {magic}") + + version = struct.unpack('B', header[4:5])[0] + if version != VERSION: + print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配") + + flags = struct.unpack('B', header[5:6])[0] + is_compressed = bool(flags & 0x01) + + payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0] + manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0] + + # 安全检查 + if manifest_uncompressed_len > MAX_MANIFEST_SIZE: + raise MMIPKGError(f"Manifest 过大: {manifest_uncompressed_len} bytes") + if payload_uncompressed_len > MAX_PAYLOAD_SIZE: + raise MMIPKGError(f"Payload 过大: {payload_uncompressed_len} bytes") + + print(f"压缩: {'是' if is_compressed else '否'}") + print(f"Payload 大小: {payload_uncompressed_len / 1024 / 1024:.2f} MB") + + # 读取 payload + payload_start = f.tell() + + # 找到 footer 位置 + f.seek(-40, 2) # 从文件末尾向前 40 bytes + footer = f.read(40) + + if footer[:4] != FOOTER_MAGIC: + raise MMIPKGError("无效的 Footer MAGIC") + + expected_sha = footer[4:36] + + # 回到 payload 开始 + f.seek(payload_start) + + # 读取整个 payload(用于计算 SHA) + footer_start = os.path.getsize(package_path) - 40 + payload_data_size = footer_start - payload_start + + # 解压或直接读取 + if is_compressed: + if zstd is None: + raise MMIPKGError("需要 zstandard 库来解压此包") + + print("解压 payload...") + compressed_data = f.read(payload_data_size) + + # 使用流式解压,不需要预知解压后大小 + decompressor = zstd.ZstdDecompressor() + try: + # 方法1:使用 stream_reader(推荐) + dctx = zstd.ZstdDecompressor() + with io.BytesIO(compressed_data) as compressed_stream: + with dctx.stream_reader(compressed_stream) as reader: + payload_data = reader.read() + except Exception as e: + # 方法2:如果流式失败,尝试直接解压(兼容旧格式) + print(f" 流式解压失败,尝试直接解压: {e}") + try: + payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len) + except Exception as e2: + raise MMIPKGError(f"解压失败: {e2}") from e2 + else: + payload_data = f.read(payload_data_size) + + # 验证 SHA256 + actual_sha = calculate_sha256(payload_data) + if self.verify_sha and actual_sha != expected_sha: + raise MMIPKGError("SHA256 校验失败!") + if self.verify_sha: + print("✓ SHA256 校验通过") + + # 解析 payload + payload_stream = io.BytesIO(payload_data) + + # 读取 manifest + manifest_len_bytes = payload_stream.read(4) + manifest_len = struct.unpack('>I', manifest_len_bytes)[0] + manifest_bytes = payload_stream.read(manifest_len) + manifest = msgpack.unpackb(manifest_bytes, raw=False) + + pack_id = manifest.get("p", "unknown") + pack_name = manifest.get("n", "unknown") + created_at = manifest.get("t", "unknown") + items = manifest.get("a", []) + + print("\n包信息:") + print(f" ID: {pack_id}") + print(f" 名称: {pack_name}") + print(f" 创建时间: {created_at}") + print(f" 表情包数量: {len(items)}") + + # 导入表情包 + return self._import_items(payload_stream, items, output_dir, + replace_existing, batch_size) + + except Exception as e: + print(f"导入失败: {e}") + import traceback + traceback.print_exc() + return False + finally: + if not db.is_closed(): + db.close() + + def _import_items(self, payload_stream: BinaryIO, items: List[Dict], + output_dir: str, replace_existing: bool, batch_size: int) -> bool: + """导入 items 到数据库""" + try: + imported_count = 0 + skipped_count = 0 + error_count = 0 + + # 开始事务,使用进度条 + with db.atomic(): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + console=console + ) as progress: + task = progress.add_task("[cyan]导入表情包...", total=len(items)) + + for idx, item in enumerate(items, 1): + try: + progress.update(task, description=f"[cyan]导入 {idx}/{len(items)}") + + # 读取图片数据 + img_len_bytes = payload_stream.read(4) + if len(img_len_bytes) != 4: + console.print(f"[red]错误: 读取图片长度失败 (item {idx})[/red]") + error_count += 1 + progress.advance(task) + continue + + img_len = struct.unpack('>I', img_len_bytes)[0] + img_bytes = payload_stream.read(img_len) + + if len(img_bytes) != img_len: + console.print(f"[red]错误: 图片数据不完整 (item {idx})[/red]") + error_count += 1 + progress.advance(task) + continue + + # 验证图片 SHA + if self.verify_sha and (expected_sha := item.get("h")): + actual_sha = calculate_sha256(img_bytes) + if actual_sha != expected_sha: + console.print(f"[yellow]警告: 图片 SHA256 不匹配 (item {idx}), 跳过[/yellow]") + error_count += 1 + progress.advance(task) + continue + + # 获取元数据 + opt = item.get("opt", {}) + # 使用 or 提供回退值,如果 emoji_hash 为空则使用后续计算的值 + emoji_hash = opt.get("emoji_hash") or calculate_sha256(img_bytes).hex() + + # 检查是否已存在 + existing = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) + + if existing and not replace_existing: + skipped_count += 1 + progress.advance(task) + continue + + # 保存图片文件 + filename = item.get("fn", f"{emoji_hash[:8]}.{opt.get('format', 'png')}") + file_path = os.path.join(output_dir, filename) + + # 如果文件已存在且不替换,生成新文件名 + if os.path.exists(file_path) and not replace_existing: + base, ext = os.path.splitext(filename) + counter = 1 + while os.path.exists(file_path): + filename = f"{base}_{counter}{ext}" + file_path = os.path.join(output_dir, filename) + counter += 1 + + with open(file_path, 'wb') as img_file: + img_file.write(img_bytes) + + # 准备数据库记录 + current_time = time.time() + emotion_str = opt.get("emotion", "") + + if existing and replace_existing: + # 更新现有记录 - 恢复完整的数据库信息 + existing.full_path = file_path + existing.format = opt.get("format", "") + existing.description = opt.get("desc", "") + existing.emotion = emotion_str + existing.usage_count = opt.get("usage_count", 0) + existing.last_used_time = opt.get("last_used_time", current_time) + existing.register_time = opt.get("register_time", current_time) + existing.record_time = opt.get("record_time", current_time) + existing.query_count = opt.get("query_count", 0) + existing.is_registered = opt.get("is_registered", True) + existing.is_banned = opt.get("is_banned", False) + existing.save() + else: + # 创建新记录 - 恢复完整的数据库信息 + Emoji.create( + emoji_hash=emoji_hash, + full_path=file_path, + format=opt.get("format", ""), + description=opt.get("desc", ""), + emotion=emotion_str, + usage_count=opt.get("usage_count", 0), + last_used_time=opt.get("last_used_time", current_time), + register_time=opt.get("register_time", current_time), + record_time=opt.get("record_time", current_time), + query_count=opt.get("query_count", 0), + is_registered=opt.get("is_registered", True), + is_banned=opt.get("is_banned", False), + ) + + imported_count += 1 + progress.advance(task) + + except Exception as e: + console.print(f"[red]处理 item {idx} 时出错: {e}[/red]") + error_count += 1 + progress.advance(task) + continue + + # 输出统计 + + console.print(f"\n[green]✓ 成功导入 {imported_count} 个表情包[/green]") + console.print(f" [yellow]跳过 {skipped_count} 个[/yellow]") + if error_count > 0: + console.print(f" [red]错误 {error_count} 个[/red]") + + return error_count == 0 + + except Exception as e: + console.print(f"[red]导入 items 失败: {e}[/red]") + import traceback + traceback.print_exc() + return False + + +def print_header(): + """打印欢迎信息""" + console.print("\n[bold cyan]" + "=" * 70 + "[/bold cyan]") + console.print("[bold cyan]" + " " * 20 + "MaiBot 表情包打包工具" + "[/bold cyan]") + console.print("[bold cyan]" + " " * 25 + ".mmipkg 格式" + "[/bold cyan]") + console.print("[bold cyan]" + "=" * 70 + "[/bold cyan]") + + +def print_menu(): + """打印主菜单""" + console.print("\n[yellow]请选择操作:[/yellow]") + console.print(" [1] [bold]导出表情包[/bold] (从数据库导出到 .mmipkg 文件)") + console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)") + console.print(" [0] [bold]退出[/bold]") + console.print() +def get_input(prompt: str, default: Optional[str] = None, + choices: Optional[List[str]] = None) -> str: + """获取用户输入""" + if default: + prompt = f"{prompt} (默认: {default})" + + while True: + try: + value = input(f"{prompt}: ").strip() + + if not value: + if default: + return default + console.print(" [yellow]⚠ 输入不能为空,请重新输入[/yellow]") + continue + + if choices and value not in choices: + console.print(f" [yellow]⚠ 无效的选择,请选择: {', '.join(choices)}[/yellow]") + continue + + return value + except KeyboardInterrupt: + console.print("\n[yellow]操作已取消[/yellow]") + raise + except EOFError: + console.print("\n[yellow]输入已结束[/yellow]") + if default: + return default + raise KeyboardInterrupt from None + + +def get_yes_no(prompt: str, default: bool = False) -> bool: + """获取是/否输入""" + default_str = "Y/n" if default else "y/N" + while True: + try: + value = input(f"{prompt} ({default_str}): ").strip().lower() + + if not value: + return default + + if value in ('y', 'yes', '是'): + return True + elif value in ('n', 'no', '否'): + return False + else: + console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]") + except KeyboardInterrupt: + console.print("\n[yellow]操作已取消[/yellow]") + raise + except EOFError: + return default + + +def get_int(prompt: str, default: int, min_val: int = 1, max_val: int = 100) -> int: + """获取整数输入""" + while True: + try: + value = input(f"{prompt} (默认: {default}, 范围: {min_val}-{max_val}): ").strip() + + if not value: + return default + + try: + num = int(value) + if min_val <= num <= max_val: + return num + else: + console.print(f" [yellow]⚠ 请输入 {min_val} 到 {max_val} 之间的数字[/yellow]") + except ValueError: + console.print(" [yellow]⚠ 请输入有效的数字[/yellow]") + except KeyboardInterrupt: + console.print("\n[yellow]操作已取消[/yellow]") + raise + except EOFError: + return default + + +def print_compression_level_info(): + """打印压缩级别说明""" + console.print("\n [cyan]压缩级别说明:[/cyan]") + console.print(" 1-3: 快速压缩,文件稍大") + console.print(" 4-9: 平衡模式(推荐)") + console.print(" 10-15: 高压缩,速度较慢") + console.print(" 16-22: 极限压缩,速度很慢") + + +def print_import_mode_selection(): + """打印导入模式选择菜单""" + console.print("\n[yellow]请选择导入模式:[/yellow]") + console.print(" [1] 自动扫描并导入 data/import_emoji 文件夹中的所有 .mmipkg 文件") + console.print(" [2] 手动指定文件路径导入") + + +def interactive_export(): + """交互式导出""" + console.print("\n[cyan]" + "-" * 70 + "[/cyan]") + console.print("[bold]导出表情包到 .mmipkg 文件[/bold]") + console.print("[cyan]" + "-" * 70 + "[/cyan]") + + # 检查数据库 + try: + if db.is_closed(): + db.connect() + + emoji_count = Emoji.select().where(Emoji.is_registered).count() + console.print(f"\n[green]✓ 找到 {emoji_count} 个已注册的表情包[/green]") + + if emoji_count == 0: + console.print("[red]✗ 数据库中没有已注册的表情包,无法导出[/red]") + return False + except Exception as e: + console.print(f"[red]✗ 数据库连接失败: {e}[/red]") + return False + finally: + if not db.is_closed(): + db.close() + + # 获取输出文件路径 + console.print("\n[yellow]1. 输出文件设置[/yellow]") + default_filename = f"maibot_emojis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mmipkg" + output_path = get_input(" 输出文件路径", default_filename) + + # 确保有 .mmipkg 扩展名 + if not output_path.endswith('.mmipkg'): + output_path += '.mmipkg' + + # 获取包名称 + default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}" + pack_name = get_input(" 包名称", default_pack_name) + + # 自定义 manifest + console.print("\n[yellow]2. 包信息设置(可选)[/yellow]") + if get_yes_no(" 是否添加包的作者和介绍信息", False): + custom_manifest = { + "author": author + } if (author := input(" 作者名称(可选): ").strip()) else {} + + # 介绍信息 + console.print(" 包介绍(限制 100 字以内):") + if description := input(" > ").strip(): + if len(description) > 100: + console.print(f" [yellow]⚠ 介绍过长({len(description)} 字),已截断至 100 字[/yellow]") + description = description[:100] + custom_manifest["description"] = description + + if not custom_manifest: + custom_manifest = None + else: + console.print(" [green]✓ 已添加包信息[/green]") + else: + custom_manifest = None + + # 压缩设置 + console.print("\n[yellow]3. 压缩设置[/yellow]") + use_compression = get_yes_no(" 使用 Zstd 压缩", True) + + zstd_level = 3 + if use_compression: + print_compression_level_info() + zstd_level = get_int(" 选择压缩级别", 3, 1, 22) + + # 重新编码设置 + console.print("\n[yellow]4. 图片编码设置[/yellow]") + if get_yes_no(" 是否重新编码图片(可显著减小文件大小)", False): + console.print("\n [cyan]可用格式:[/cyan]") + console.print(" webp: 推荐,体积小且支持透明度") + console.print(" jpeg: 最小体积,但不支持透明度") + console.print(" png: 无损,文件较大") + reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png']) + + quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80 + else: + reencode = None + quality = 80 + + # 确认导出 + console.print("\n[cyan]" + "-" * 70 + "[/cyan]") + console.print("[bold]导出配置:[/bold]") + console.print(f" 输出文件: {output_path}") + console.print(f" 包名称: {pack_name or '自动生成'}") + if custom_manifest: + if "author" in custom_manifest: + console.print(f" 作者: {custom_manifest['author']}") + if "description" in custom_manifest: + console.print(f" 介绍: {custom_manifest['description']}") + compression_info = f"是 (级别 {zstd_level})" if use_compression else "否" + console.print(f" 压缩: {compression_info}") + console.print(f" 重新编码: {reencode or '否'}") + if reencode: + console.print(f" 编码质量: {quality}") + console.print(f" 表情包数量: {emoji_count}") + console.print("[cyan]" + "-" * 70 + "[/cyan]") + + if not get_yes_no("\n确认导出", True): + console.print("[red]✗ 已取消导出[/red]") + return False + + # 开始导出 + console.print("\n[cyan]开始导出...[/cyan]") + packer = MMIPKGPacker( + use_compression=use_compression, + zstd_level=zstd_level, + reencode=reencode, + reencode_quality=quality + ) + + success = packer.pack_from_db(output_path, pack_name, custom_manifest) + + if success: + console.print(f"\n[green]✓ 导出成功: {output_path}[/green]") + else: + console.print("\n[red]✗ 导出失败[/red]") + + return success + + +def interactive_import(): + """交互式导入""" + console.print("\n[cyan]" + "-" * 70 + "[/cyan]") + console.print("[bold]从 .mmipkg 文件导入表情包[/bold]") + console.print("[cyan]" + "-" * 70 + "[/cyan]") + + # 选择导入模式 + print_import_mode_selection() + import_mode = get_input("请选择", "1", ['1', '2']) + + input_files = [] + + if import_mode == '1': + # 自动扫描模式 + import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji") + os.makedirs(import_dir, exist_ok=True) + + console.print(f"\n[cyan]扫描目录: {import_dir}[/cyan]") + + # 查找所有 .mmipkg 文件 + for file in os.listdir(import_dir): + if file.endswith('.mmipkg'): + file_path = os.path.join(import_dir, file) + if os.path.isfile(file_path): + input_files.append(file_path) + + if not input_files: + console.print("[red]✗ 目录中没有找到 .mmipkg 文件[/red]") + console.print(f" 请将表情包文件放入: {import_dir}") + return False + + console.print(f"\n[green]找到 {len(input_files)} 个文件:[/green]") + for i, file_path in enumerate(input_files, 1): + file_size = os.path.getsize(file_path) / 1024 / 1024 + console.print(f" [{i}] {os.path.basename(file_path)} ({file_size:.2f} MB)") + + if not get_yes_no(f"\n确认导入这 {len(input_files)} 个文件", True): + console.print("[red]✗ 已取消导入[/red]") + return False + else: + # 手动输入模式 + console.print("\n[yellow]1. 输入文件设置[/yellow]") + input_path = get_input(" 输入文件路径 (.mmipkg)") + + if not os.path.exists(input_path): + console.print(f"[red]✗ 文件不存在: {input_path}[/red]") + return False + + input_files.append(input_path) + + # 获取输出目录 + console.print("\n[yellow]2. 输出目录设置[/yellow]") + default_output_dir = os.path.join(PROJECT_ROOT, "data", "emoji_registed") + output_dir = get_input(" 输出目录", default_output_dir) + + # 导入选项 + console.print("\n[yellow]3. 导入选项[/yellow]") + replace_existing = get_yes_no(" 替换已存在的表情包", False) + verify_sha = get_yes_no(" 验证 SHA256 完整性(推荐)", True) + + # 批量大小 + console.print("\n[yellow]4. 性能设置[/yellow]") + console.print(" [cyan]批量大小说明:[/cyan]") + console.print(" 100-500: 默认,适合大多数情况") + console.print(" 500-1000: 快速导入大量表情包") + console.print(" 1000+: 极速模式,但内存占用更高") + batch_size = get_int(" 批量提交大小", 500, 100, 5000) + + # 确认导入 + console.print("\n[cyan]" + "-" * 70 + "[/cyan]") + console.print("[bold]导入配置:[/bold]") + console.print(f" 导入模式: {'自动扫描' if import_mode == '1' else '手动指定'}") + console.print(f" 文件数量: {len(input_files)}") + console.print(f" 输出目录: {output_dir}") + console.print(f" 替换已存在: {'是' if replace_existing else '否'}") + console.print(f" SHA256 验证: {'是' if verify_sha else '否'}") + console.print(f" 批量大小: {batch_size}") + console.print("[cyan]" + "-" * 70 + "[/cyan]") + + if not get_yes_no("\n确认导入", True): + console.print("[red]✗ 已取消导入[/red]") + return False + + # 开始导入 + unpacker = MMIPKGUnpacker(verify_sha=verify_sha) + + total_success = 0 + total_failed = 0 + + # 使用进度条处理多个文件 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console + ) as progress: + task = progress.add_task("[cyan]导入文件...", total=len(input_files)) + + for i, input_path in enumerate(input_files, 1): + progress.update(task, description=f"[cyan]导入 [{i}/{len(input_files)}]: {os.path.basename(input_path)}") + + console.print(f"\n[bold]{'=' * 70}[/bold]") + console.print(f"[bold]导入文件 [{i}/{len(input_files)}]: {os.path.basename(input_path)}[/bold]") + console.print(f"[bold]{'=' * 70}[/bold]") + + success = unpacker.import_to_db( + input_path, + output_dir=output_dir, + replace_existing=replace_existing, + batch_size=batch_size + ) + + if success: + total_success += 1 + else: + total_failed += 1 + + progress.advance(task) + + # 总结 + console.print(f"\n[bold]{'=' * 70}[/bold]") + console.print("[bold]导入总结:[/bold]") + console.print(f" [green]成功: {total_success} 个文件[/green]") + if total_failed > 0: + console.print(f" [red]失败: {total_failed} 个文件[/red]") + console.print(f"[bold]{'=' * 70}[/bold]") + + return total_failed == 0 + + +def main(): + """主函数 - 交互式界面""" + print_header() + + try: + while True: + print_menu() + try: + choice = get_input("请选择", "1", ['0', '1', '2']) + except KeyboardInterrupt: + console.print("\n[green]再见![/green]") + return 0 + + if choice == '0': + console.print("\n[green]再见![/green]") + return 0 + + elif choice == '1': + try: + interactive_export() + except KeyboardInterrupt: + console.print("\n\n[yellow]✗ 操作已取消[/yellow]") + except Exception as e: + console.print(f"\n[red]✗ 发生错误: {e}[/red]") + import traceback + traceback.print_exc() + + try: + input("\n按 Enter 键继续...") + except (KeyboardInterrupt, EOFError): + pass + + elif choice == '2': + try: + interactive_import() + except KeyboardInterrupt: + console.print("\n\n[yellow]✗ 操作已取消[/yellow]") + except Exception as e: + console.print(f"\n[red]✗ 发生错误: {e}[/red]") + import traceback + traceback.print_exc() + + try: + input("\n按 Enter 键继续...") + except (KeyboardInterrupt, EOFError): + pass + except KeyboardInterrupt: + console.print("\n[green]再见![/green]") + return 0 + + return 0 + + +if __name__ == '__main__': + sys.exit(main())