From 44f427dc64e7bfe3684445f7dd96f45029253dc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 19 Nov 2025 23:35:14 +0800 Subject: [PATCH] Ruff fix --- bot.py | 8 +- scripts/build_io_pairs.py | 2 - scripts/expression_scatter_analysis.py | 1 - scripts/mmipkg_tool.py | 531 ++++++++-------- src/chat/heart_flow/heartFC_chat.py | 1 - src/chat/knowledge/__init__.py | 2 + src/chat/knowledge/qa_manager.py | 5 +- src/chat/replyer/group_generator.py | 12 +- src/chat/replyer/private_generator.py | 8 +- src/chat/utils/chat_history_summarizer.py | 2 +- src/common/logger.py | 38 +- src/config/official_configs.py | 4 +- src/express/expression_learner.py | 6 +- src/express/expression_selector.py | 9 +- src/jargon/jargon_miner.py | 245 ++++---- src/llm_models/model_client/gemini_client.py | 2 +- src/llm_models/model_client/openai_client.py | 2 +- src/llm_models/payload_content/message.py | 4 +- src/llm_models/utils_model.py | 2 +- src/main.py | 11 +- src/memory_system/memory_retrieval.py | 471 +++++++------- src/memory_system/memory_utils.py | 45 +- src/memory_system/retrieval_tools/__init__.py | 1 + .../retrieval_tools/query_chat_history.py | 98 ++- .../retrieval_tools/query_lpmm_knowledge.py | 2 - .../retrieval_tools/query_person_info.py | 108 ++-- .../retrieval_tools/tool_registry.py | 24 +- src/person_info/person_info.py | 37 +- src/plugin_system/core/tool_use.py | 1 - src/webui/config_routes.py | 4 +- src/webui/emoji_routes.py | 280 ++++----- src/webui/expression_routes.py | 195 +++--- src/webui/git_mirror_service.py | 331 ++++------ src/webui/logs_ws.py | 43 +- src/webui/manager.py | 41 +- src/webui/person_routes.py | 171 +++--- src/webui/plugin_progress_ws.py | 31 +- src/webui/plugin_routes.py | 579 ++++++++---------- src/webui/routers/system.py | 47 +- src/webui/routes.py | 116 ++-- src/webui/statistics_routes.py | 213 +++---- src/webui/token_manager.py | 71 +-- 42 files changed, 1742 insertions(+), 2062 deletions(-) diff --git a/bot.py b/bot.py index 38894b29..7ba9af4b 100644 --- a/bot.py +++ b/bot.py @@ -1,7 +1,6 @@ import asyncio import hashlib import os -import sys import time import platform import traceback @@ -30,7 +29,7 @@ else: raise # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 -from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa +from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa initialize_logging() @@ -212,9 +211,10 @@ if __name__ == "__main__": # 创建事件循环 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + # 初始化 WebSocket 日志推送 from src.common.logger import initialize_ws_handler + initialize_ws_handler(loop) try: @@ -251,7 +251,7 @@ if __name__ == "__main__": print(f"关闭日志系统时出错: {e}") print("[主程序] 准备退出...") - + # 使用 os._exit() 强制退出,避免被阻塞 # 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的 os._exit(exit_code) diff --git a/scripts/build_io_pairs.py b/scripts/build_io_pairs.py index f934566a..944d7671 100644 --- a/scripts/build_io_pairs.py +++ b/scripts/build_io_pairs.py @@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) - - SECONDS_5_MINUTES = 5 * 60 diff --git a/scripts/expression_scatter_analysis.py b/scripts/expression_scatter_analysis.py index b022c94e..3cb22f71 100644 --- a/scripts/expression_scatter_analysis.py +++ b/scripts/expression_scatter_analysis.py @@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) - # 设置中文字体 plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"] plt.rcParams["axes.unicode_minus"] = False diff --git a/scripts/mmipkg_tool.py b/scripts/mmipkg_tool.py index 6f14ae5b..3b5369ac 100644 --- a/scripts/mmipkg_tool.py +++ b/scripts/mmipkg_tool.py @@ -57,8 +57,8 @@ from src.common.database.database import db from src.common.database.database_model import Emoji # 常量定义 -MAGIC = b'MMIP' -FOOTER_MAGIC = b'MMFF' +MAGIC = b"MMIP" +FOOTER_MAGIC = b"MMFF" VERSION = 1 FOOTER_VERSION = 1 @@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB # 支持的图片格式 -SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'} +SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"} # 创建控制台对象 console = Console() @@ -75,6 +75,7 @@ console = Console() class MMIPKGError(Exception): """MMIPKG 相关错误""" + pass @@ -97,56 +98,56 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]: try: with Image.open(file_path) as img: width, height = img.size - format_lower = img.format.lower() if img.format else 'unknown' + format_lower = img.format.lower() if img.format else "unknown" mime_map = { - 'jpeg': 'image/jpeg', - 'jpg': 'image/jpeg', - 'png': 'image/png', - 'gif': 'image/gif', - 'webp': 'image/webp', - 'avif': 'image/avif', - 'bmp': 'image/bmp' + "jpeg": "image/jpeg", + "jpg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "webp": "image/webp", + "avif": "image/avif", + "bmp": "image/bmp", } - mime_type = mime_map.get(format_lower, f'image/{format_lower}') + mime_type = mime_map.get(format_lower, f"image/{format_lower}") return width, height, mime_type except Exception as e: print(f"警告: 无法读取图片信息 {file_path}: {e}") - return 0, 0, 'image/unknown' + return 0, 0, "image/unknown" -def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes: +def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes: """重新编码图片""" try: with Image.open(file_path) as img: # 转换为 RGB(如果需要) - if img.mode in ('RGBA', 'LA', 'P'): - if output_format.lower() == 'jpeg': + if img.mode in ("RGBA", "LA", "P"): + if output_format.lower() == "jpeg": # JPEG 不支持透明度,转为白色背景 - background = Image.new('RGB', img.size, (255, 255, 255)) - if img.mode == 'P': - img = img.convert('RGBA') - background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) + background = Image.new("RGB", img.size, (255, 255, 255)) + if img.mode == "P": + img = img.convert("RGBA") + background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None) img = background - elif output_format.lower() == 'webp': + elif output_format.lower() == "webp": # WebP 支持透明度 - if img.mode == 'P': - img = img.convert('RGBA') - elif img.mode not in ('RGB', 'RGBA'): - img = img.convert('RGB') - + if img.mode == "P": + img = img.convert("RGBA") + elif img.mode not in ("RGB", "RGBA"): + img = img.convert("RGB") + # 编码图片 output = io.BytesIO() - save_kwargs = {'format': output_format.upper()} - - if output_format.lower() in {'jpeg', 'jpg'}: - save_kwargs['quality'] = quality - save_kwargs['optimize'] = True - elif output_format.lower() == 'webp': - save_kwargs['quality'] = quality - save_kwargs['method'] = 6 # 更好的压缩 - elif output_format.lower() == 'png': - save_kwargs['optimize'] = True - + save_kwargs = {"format": output_format.upper()} + + if output_format.lower() in {"jpeg", "jpg"}: + save_kwargs["quality"] = quality + save_kwargs["optimize"] = True + elif output_format.lower() == "webp": + save_kwargs["quality"] = quality + save_kwargs["method"] = 6 # 更好的压缩 + elif output_format.lower() == "png": + save_kwargs["optimize"] = True + img.save(output, **save_kwargs) return output.getvalue() except Exception as e: @@ -155,25 +156,28 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8 class MMIPKGPacker: """MMIPKG 打包器""" - - def __init__(self, - use_compression: bool = True, - zstd_level: int = 3, - reencode: Optional[str] = None, - reencode_quality: int = 80): + + def __init__( + self, + use_compression: bool = True, + zstd_level: int = 3, + reencode: Optional[str] = None, + reencode_quality: int = 80, + ): self.use_compression = use_compression and zstd is not None self.zstd_level = zstd_level self.reencode = reencode self.reencode_quality = reencode_quality - + if use_compression and zstd is None: print("警告: zstandard 未安装,将不使用压缩") self.use_compression = False - - def pack_from_db(self, output_path: str, pack_name: Optional[str] = None, - custom_manifest: Optional[Dict] = None) -> bool: + + def pack_from_db( + self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None + ) -> bool: """从数据库导出已注册的表情包 - + Args: output_path: 输出文件路径 pack_name: 包名称 @@ -183,21 +187,21 @@ class MMIPKGPacker: # 连接数据库 if db.is_closed(): db.connect() - + # 查询所有已注册的表情包 emojis = Emoji.select().where(Emoji.is_registered) emoji_count = emojis.count() - + if emoji_count == 0: print("错误: 数据库中没有已注册的表情包") return False - + print(f"找到 {emoji_count} 个已注册的表情包") - + # 准备 items items = [] image_data_list = [] - + # 使用进度条处理表情包 with Progress( SpinnerColumn(), @@ -205,37 +209,39 @@ class MMIPKGPacker: BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeElapsedColumn(), - console=console + console=console, ) as progress: task = progress.add_task("[cyan]扫描表情包...", total=emoji_count) - + for idx, emoji in enumerate(emojis, 1): - progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}") - + progress.update( + task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}" + ) + # 检查文件是否存在 if not os.path.exists(emoji.full_path): console.print(" [yellow]警告: 文件不存在,跳过[/yellow]") progress.advance(task) continue - + # 读取或重新编码图片 if self.reencode: try: img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality) except Exception as e: console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]") - with open(emoji.full_path, 'rb') as f: + with open(emoji.full_path, "rb") as f: img_bytes = f.read() else: - with open(emoji.full_path, 'rb') as f: + with open(emoji.full_path, "rb") as f: img_bytes = f.read() - + # 计算 SHA256 img_sha = calculate_sha256(img_bytes) - + # 获取图片信息 width, height, mime_type = get_image_info(emoji.full_path) - + # 构建 item(使用短字段名) filename = os.path.basename(emoji.full_path) item = { @@ -259,95 +265,97 @@ class MMIPKGPacker: "emoji_hash": emoji.emoji_hash or "", "is_registered": True, "is_banned": emoji.is_banned or False, - } + }, } - + items.append(item) image_data_list.append(img_bytes) progress.advance(task) - + if not items: print("错误: 没有有效的表情包可以打包") return False - + print(f"找到 {len(items)} 个表情包可以打包...") - + # 准备打包 pack_id = str(uuid.uuid4()) if pack_name is None: pack_name = f"MaiBot_Emojis_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - + manifest = { "p": pack_id, # pack_id "n": pack_name, # pack_name "t": datetime.now().isoformat(), # created_at - "a": items # items array + "a": items, # items array } - + # 添加自定义字段 if custom_manifest: for key, value in custom_manifest.items(): if key not in manifest: # 不覆盖核心字段 manifest[key] = value - + # 序列化 manifest manifest_bytes = msgpack.packb(manifest, use_bin_type=True) manifest_len = len(manifest_bytes) - + # 计算 payload 大小 payload_size = 4 + manifest_len # manifest_len + manifest_bytes for img_bytes in image_data_list: payload_size += 4 + len(img_bytes) # img_len + img_bytes - + print(f"Manifest 大小: {manifest_len / 1024:.2f} KB") print(f"Payload 未压缩大小: {payload_size / 1024 / 1024:.2f} MB") - + # 写入文件 return self._write_package(output_path, manifest_bytes, image_data_list, payload_size) - + except Exception as e: print(f"打包失败: {e}") import traceback + traceback.print_exc() return False finally: if not db.is_closed(): db.close() - - def _write_package(self, output_path: str, manifest_bytes: bytes, - image_data_list: List[bytes], payload_size: int) -> bool: + + def _write_package( + self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int + ) -> bool: """写入打包文件""" try: - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: # 写入 Header (32 bytes) flags = 0x01 if self.use_compression else 0x00 header = MAGIC # 4 bytes - header += struct.pack('B', VERSION) # 1 byte - header += struct.pack('B', flags) # 1 byte - header += b'\x00\x00' # 2 bytes reserved - header += struct.pack('>Q', payload_size) # 8 bytes - header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes - header += b'\x00' * 8 # 8 bytes reserved - + header += struct.pack("B", VERSION) # 1 byte + header += struct.pack("B", flags) # 1 byte + header += b"\x00\x00" # 2 bytes reserved + header += struct.pack(">Q", payload_size) # 8 bytes + header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes + header += b"\x00" * 8 # 8 bytes reserved + assert len(header) == 32, f"Header size mismatch: {len(header)}" f.write(header) - + # 准备 payload 并计算 SHA256 payload_sha = hashlib.sha256() - + # 写入 payload(可能压缩) if self.use_compression: console.print(f"[cyan]使用 Zstd 压缩 (level={self.zstd_level})...[/cyan]") compressor = zstd.ZstdCompressor(level=self.zstd_level) - + with compressor.stream_writer(f, closefd=False) as writer: # 写入 manifest - manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) + manifest_len_bytes = struct.pack(">I", len(manifest_bytes)) writer.write(manifest_len_bytes) writer.write(manifest_bytes) payload_sha.update(manifest_len_bytes) payload_sha.update(manifest_bytes) - + # 使用进度条写入所有图片 with Progress( SpinnerColumn(), @@ -355,13 +363,13 @@ class MMIPKGPacker: BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), - console=console + console=console, ) as progress: task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list)) - + for idx, img_bytes in enumerate(image_data_list, 1): progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}") - img_len_bytes = struct.pack('>I', len(img_bytes)) + img_len_bytes = struct.pack(">I", len(img_bytes)) writer.write(img_len_bytes) writer.write(img_bytes) payload_sha.update(img_len_bytes) @@ -370,12 +378,12 @@ class MMIPKGPacker: else: # 不压缩,直接写入 # 写入 manifest - manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) + manifest_len_bytes = struct.pack(">I", len(manifest_bytes)) f.write(manifest_len_bytes) f.write(manifest_bytes) payload_sha.update(manifest_len_bytes) payload_sha.update(manifest_bytes) - + # 使用进度条写入所有图片 with Progress( SpinnerColumn(), @@ -383,29 +391,29 @@ class MMIPKGPacker: BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), - console=console + console=console, ) as progress: task = progress.add_task("[green]写入图片...", total=len(image_data_list)) - + for idx, img_bytes in enumerate(image_data_list, 1): progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}") - img_len_bytes = struct.pack('>I', len(img_bytes)) + img_len_bytes = struct.pack(">I", len(img_bytes)) f.write(img_len_bytes) f.write(img_bytes) payload_sha.update(img_len_bytes) payload_sha.update(img_bytes) progress.advance(task) - + # 写入 Footer (40 bytes) file_sha256 = payload_sha.digest() footer = FOOTER_MAGIC # 4 bytes footer += file_sha256 # 32 bytes - footer += struct.pack('B', FOOTER_VERSION) # 1 byte - footer += b'\x00' * 3 # 3 bytes reserved - + footer += struct.pack("B", FOOTER_VERSION) # 1 byte + footer += b"\x00" * 3 # 3 bytes reserved + assert len(footer) == 40, f"Footer size mismatch: {len(footer)}" f.write(footer) - + file_size = f.tell() print("\n打包完成!") print(f"输出文件: {output_path}") @@ -413,100 +421,100 @@ class MMIPKGPacker: if self.use_compression: ratio = (1 - file_size / (payload_size + 32 + 40)) * 100 print(f"压缩率: {ratio:.1f}%") - + return True - + except Exception as e: print(f"写入文件失败: {e}") import traceback + traceback.print_exc() return False class MMIPKGUnpacker: """MMIPKG 解包器""" - + def __init__(self, verify_sha: bool = True): self.verify_sha = verify_sha - - def import_to_db(self, package_path: str, - output_dir: Optional[str] = None, - replace_existing: bool = False, - batch_size: int = 500) -> bool: + + def import_to_db( + self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500 + ) -> bool: """导入到数据库""" try: if not os.path.exists(package_path): print(f"错误: 文件不存在: {package_path}") return False - + # 连接数据库 if db.is_closed(): db.connect() - + # 如果未指定输出目录,使用默认的已注册表情包目录 if output_dir is None: output_dir = os.path.join(PROJECT_ROOT, "data", "emoji_registed") - + os.makedirs(output_dir, exist_ok=True) - + print(f"正在读取包: {package_path}") - - with open(package_path, 'rb') as f: + + with open(package_path, "rb") as f: # 读取 Header header = f.read(32) if len(header) != 32: raise MMIPKGError("Header 大小不正确") - + magic = header[:4] if magic != MAGIC: raise MMIPKGError(f"无效的 MAGIC: {magic}") - - version = struct.unpack('B', header[4:5])[0] + + version = struct.unpack("B", header[4:5])[0] if version != VERSION: print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配") - - flags = struct.unpack('B', header[5:6])[0] + + flags = struct.unpack("B", header[5:6])[0] is_compressed = bool(flags & 0x01) - - payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0] - manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0] - + + payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0] + manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0] + # 安全检查 if manifest_uncompressed_len > MAX_MANIFEST_SIZE: raise MMIPKGError(f"Manifest 过大: {manifest_uncompressed_len} bytes") if payload_uncompressed_len > MAX_PAYLOAD_SIZE: raise MMIPKGError(f"Payload 过大: {payload_uncompressed_len} bytes") - + print(f"压缩: {'是' if is_compressed else '否'}") print(f"Payload 大小: {payload_uncompressed_len / 1024 / 1024:.2f} MB") - + # 读取 payload payload_start = f.tell() - + # 找到 footer 位置 f.seek(-40, 2) # 从文件末尾向前 40 bytes footer = f.read(40) - + if footer[:4] != FOOTER_MAGIC: raise MMIPKGError("无效的 Footer MAGIC") - + expected_sha = footer[4:36] - + # 回到 payload 开始 f.seek(payload_start) - + # 读取整个 payload(用于计算 SHA) footer_start = os.path.getsize(package_path) - 40 payload_data_size = footer_start - payload_start - + # 解压或直接读取 if is_compressed: if zstd is None: raise MMIPKGError("需要 zstandard 库来解压此包") - + print("解压 payload...") compressed_data = f.read(payload_data_size) - + # 使用流式解压,不需要预知解压后大小 decompressor = zstd.ZstdDecompressor() try: @@ -519,60 +527,63 @@ class MMIPKGUnpacker: # 方法2:如果流式失败,尝试直接解压(兼容旧格式) print(f" 流式解压失败,尝试直接解压: {e}") try: - payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len) + payload_data = decompressor.decompress( + compressed_data, max_output_size=payload_uncompressed_len + ) except Exception as e2: raise MMIPKGError(f"解压失败: {e2}") from e2 else: payload_data = f.read(payload_data_size) - + # 验证 SHA256 actual_sha = calculate_sha256(payload_data) if self.verify_sha and actual_sha != expected_sha: raise MMIPKGError("SHA256 校验失败!") if self.verify_sha: print("✓ SHA256 校验通过") - + # 解析 payload payload_stream = io.BytesIO(payload_data) - + # 读取 manifest manifest_len_bytes = payload_stream.read(4) - manifest_len = struct.unpack('>I', manifest_len_bytes)[0] + manifest_len = struct.unpack(">I", manifest_len_bytes)[0] manifest_bytes = payload_stream.read(manifest_len) manifest = msgpack.unpackb(manifest_bytes, raw=False) - + pack_id = manifest.get("p", "unknown") pack_name = manifest.get("n", "unknown") created_at = manifest.get("t", "unknown") items = manifest.get("a", []) - + print("\n包信息:") print(f" ID: {pack_id}") print(f" 名称: {pack_name}") print(f" 创建时间: {created_at}") print(f" 表情包数量: {len(items)}") - + # 导入表情包 - return self._import_items(payload_stream, items, output_dir, - replace_existing, batch_size) - + return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size) + except Exception as e: print(f"导入失败: {e}") import traceback + traceback.print_exc() return False finally: if not db.is_closed(): db.close() - - def _import_items(self, payload_stream: BinaryIO, items: List[Dict], - output_dir: str, replace_existing: bool, batch_size: int) -> bool: + + def _import_items( + self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int + ) -> bool: """导入 items 到数据库""" try: imported_count = 0 skipped_count = 0 error_count = 0 - + # 开始事务,使用进度条 with db.atomic(): with Progress( @@ -581,14 +592,14 @@ class MMIPKGUnpacker: BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), - console=console + console=console, ) as progress: task = progress.add_task("[cyan]导入表情包...", total=len(items)) - + for idx, item in enumerate(items, 1): try: progress.update(task, description=f"[cyan]导入 {idx}/{len(items)}") - + # 读取图片数据 img_len_bytes = payload_stream.read(4) if len(img_len_bytes) != 4: @@ -596,16 +607,16 @@ class MMIPKGUnpacker: error_count += 1 progress.advance(task) continue - - img_len = struct.unpack('>I', img_len_bytes)[0] + + img_len = struct.unpack(">I", img_len_bytes)[0] img_bytes = payload_stream.read(img_len) - + if len(img_bytes) != img_len: console.print(f"[red]错误: 图片数据不完整 (item {idx})[/red]") error_count += 1 progress.advance(task) continue - + # 验证图片 SHA if self.verify_sha and (expected_sha := item.get("h")): actual_sha = calculate_sha256(img_bytes) @@ -614,24 +625,24 @@ class MMIPKGUnpacker: error_count += 1 progress.advance(task) continue - + # 获取元数据 opt = item.get("opt", {}) # 使用 or 提供回退值,如果 emoji_hash 为空则使用后续计算的值 emoji_hash = opt.get("emoji_hash") or calculate_sha256(img_bytes).hex() - + # 检查是否已存在 existing = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) - + if existing and not replace_existing: skipped_count += 1 progress.advance(task) continue - + # 保存图片文件 filename = item.get("fn", f"{emoji_hash[:8]}.{opt.get('format', 'png')}") file_path = os.path.join(output_dir, filename) - + # 如果文件已存在且不替换,生成新文件名 if os.path.exists(file_path) and not replace_existing: base, ext = os.path.splitext(filename) @@ -640,14 +651,14 @@ class MMIPKGUnpacker: filename = f"{base}_{counter}{ext}" file_path = os.path.join(output_dir, filename) counter += 1 - - with open(file_path, 'wb') as img_file: + + with open(file_path, "wb") as img_file: img_file.write(img_bytes) - + # 准备数据库记录 current_time = time.time() emotion_str = opt.get("emotion", "") - + if existing and replace_existing: # 更新现有记录 - 恢复完整的数据库信息 existing.full_path = file_path @@ -678,28 +689,29 @@ class MMIPKGUnpacker: is_registered=opt.get("is_registered", True), is_banned=opt.get("is_banned", False), ) - + imported_count += 1 progress.advance(task) - + except Exception as e: console.print(f"[red]处理 item {idx} 时出错: {e}[/red]") error_count += 1 progress.advance(task) continue - + # 输出统计 - + console.print(f"\n[green]✓ 成功导入 {imported_count} 个表情包[/green]") console.print(f" [yellow]跳过 {skipped_count} 个[/yellow]") if error_count > 0: console.print(f" [red]错误 {error_count} 个[/red]") - + return error_count == 0 - + except Exception as e: console.print(f"[red]导入 items 失败: {e}[/red]") import traceback + traceback.print_exc() return False @@ -719,26 +731,27 @@ def print_menu(): console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)") console.print(" [0] [bold]退出[/bold]") console.print() -def get_input(prompt: str, default: Optional[str] = None, - choices: Optional[List[str]] = None) -> str: + + +def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str: """获取用户输入""" if default: prompt = f"{prompt} (默认: {default})" - + while True: try: value = input(f"{prompt}: ").strip() - + if not value: if default: return default console.print(" [yellow]⚠ 输入不能为空,请重新输入[/yellow]") continue - + if choices and value not in choices: console.print(f" [yellow]⚠ 无效的选择,请选择: {', '.join(choices)}[/yellow]") continue - + return value except KeyboardInterrupt: console.print("\n[yellow]操作已取消[/yellow]") @@ -756,13 +769,13 @@ def get_yes_no(prompt: str, default: bool = False) -> bool: while True: try: value = input(f"{prompt} ({default_str}): ").strip().lower() - + if not value: return default - - if value in ('y', 'yes', '是'): + + if value in ("y", "yes", "是"): return True - elif value in ('n', 'no', '否'): + elif value in ("n", "no", "否"): return False else: console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]") @@ -778,10 +791,10 @@ def get_int(prompt: str, default: int, min_val: int = 1, max_val: int = 100) -> while True: try: value = input(f"{prompt} (默认: {default}, 范围: {min_val}-{max_val}): ").strip() - + if not value: return default - + try: num = int(value) if min_val <= num <= max_val: @@ -818,15 +831,15 @@ def interactive_export(): console.print("\n[cyan]" + "-" * 70 + "[/cyan]") console.print("[bold]导出表情包到 .mmipkg 文件[/bold]") console.print("[cyan]" + "-" * 70 + "[/cyan]") - + # 检查数据库 try: if db.is_closed(): db.connect() - + emoji_count = Emoji.select().where(Emoji.is_registered).count() console.print(f"\n[green]✓ 找到 {emoji_count} 个已注册的表情包[/green]") - + if emoji_count == 0: console.print("[red]✗ 数据库中没有已注册的表情包,无法导出[/red]") return False @@ -836,27 +849,25 @@ def interactive_export(): finally: if not db.is_closed(): db.close() - + # 获取输出文件路径 console.print("\n[yellow]1. 输出文件设置[/yellow]") default_filename = f"maibot_emojis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mmipkg" output_path = get_input(" 输出文件路径", default_filename) - + # 确保有 .mmipkg 扩展名 - if not output_path.endswith('.mmipkg'): - output_path += '.mmipkg' - + if not output_path.endswith(".mmipkg"): + output_path += ".mmipkg" + # 获取包名称 default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}" pack_name = get_input(" 包名称", default_pack_name) - + # 自定义 manifest console.print("\n[yellow]2. 包信息设置(可选)[/yellow]") if get_yes_no(" 是否添加包的作者和介绍信息", False): - custom_manifest = { - "author": author - } if (author := input(" 作者名称(可选): ").strip()) else {} - + custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {} + # 介绍信息 console.print(" 包介绍(限制 100 字以内):") if description := input(" > ").strip(): @@ -864,23 +875,23 @@ def interactive_export(): console.print(f" [yellow]⚠ 介绍过长({len(description)} 字),已截断至 100 字[/yellow]") description = description[:100] custom_manifest["description"] = description - + if not custom_manifest: custom_manifest = None else: console.print(" [green]✓ 已添加包信息[/green]") else: custom_manifest = None - + # 压缩设置 console.print("\n[yellow]3. 压缩设置[/yellow]") use_compression = get_yes_no(" 使用 Zstd 压缩", True) - + zstd_level = 3 if use_compression: print_compression_level_info() zstd_level = get_int(" 选择压缩级别", 3, 1, 22) - + # 重新编码设置 console.print("\n[yellow]4. 图片编码设置[/yellow]") if get_yes_no(" 是否重新编码图片(可显著减小文件大小)", False): @@ -888,13 +899,13 @@ def interactive_export(): console.print(" webp: 推荐,体积小且支持透明度") console.print(" jpeg: 最小体积,但不支持透明度") console.print(" png: 无损,文件较大") - reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png']) - - quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80 + reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"]) + + quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80 else: reencode = None quality = 80 - + # 确认导出 console.print("\n[cyan]" + "-" * 70 + "[/cyan]") console.print("[bold]导出配置:[/bold]") @@ -912,27 +923,24 @@ def interactive_export(): console.print(f" 编码质量: {quality}") console.print(f" 表情包数量: {emoji_count}") console.print("[cyan]" + "-" * 70 + "[/cyan]") - + if not get_yes_no("\n确认导出", True): console.print("[red]✗ 已取消导出[/red]") return False - + # 开始导出 console.print("\n[cyan]开始导出...[/cyan]") packer = MMIPKGPacker( - use_compression=use_compression, - zstd_level=zstd_level, - reencode=reencode, - reencode_quality=quality + use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality ) - + success = packer.pack_from_db(output_path, pack_name, custom_manifest) - + if success: console.print(f"\n[green]✓ 导出成功: {output_path}[/green]") else: console.print("\n[red]✗ 导出失败[/red]") - + return success @@ -941,37 +949,37 @@ def interactive_import(): console.print("\n[cyan]" + "-" * 70 + "[/cyan]") console.print("[bold]从 .mmipkg 文件导入表情包[/bold]") console.print("[cyan]" + "-" * 70 + "[/cyan]") - + # 选择导入模式 print_import_mode_selection() - import_mode = get_input("请选择", "1", ['1', '2']) - + import_mode = get_input("请选择", "1", ["1", "2"]) + input_files = [] - - if import_mode == '1': + + if import_mode == "1": # 自动扫描模式 import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji") os.makedirs(import_dir, exist_ok=True) - + console.print(f"\n[cyan]扫描目录: {import_dir}[/cyan]") - + # 查找所有 .mmipkg 文件 for file in os.listdir(import_dir): - if file.endswith('.mmipkg'): + if file.endswith(".mmipkg"): file_path = os.path.join(import_dir, file) if os.path.isfile(file_path): input_files.append(file_path) - + if not input_files: console.print("[red]✗ 目录中没有找到 .mmipkg 文件[/red]") console.print(f" 请将表情包文件放入: {import_dir}") return False - + console.print(f"\n[green]找到 {len(input_files)} 个文件:[/green]") for i, file_path in enumerate(input_files, 1): file_size = os.path.getsize(file_path) / 1024 / 1024 console.print(f" [{i}] {os.path.basename(file_path)} ({file_size:.2f} MB)") - + if not get_yes_no(f"\n确认导入这 {len(input_files)} 个文件", True): console.print("[red]✗ 已取消导入[/red]") return False @@ -979,23 +987,23 @@ def interactive_import(): # 手动输入模式 console.print("\n[yellow]1. 输入文件设置[/yellow]") input_path = get_input(" 输入文件路径 (.mmipkg)") - + if not os.path.exists(input_path): console.print(f"[red]✗ 文件不存在: {input_path}[/red]") return False - + input_files.append(input_path) - + # 获取输出目录 console.print("\n[yellow]2. 输出目录设置[/yellow]") default_output_dir = os.path.join(PROJECT_ROOT, "data", "emoji_registed") output_dir = get_input(" 输出目录", default_output_dir) - + # 导入选项 console.print("\n[yellow]3. 导入选项[/yellow]") replace_existing = get_yes_no(" 替换已存在的表情包", False) verify_sha = get_yes_no(" 验证 SHA256 完整性(推荐)", True) - + # 批量大小 console.print("\n[yellow]4. 性能设置[/yellow]") console.print(" [cyan]批量大小说明:[/cyan]") @@ -1003,7 +1011,7 @@ def interactive_import(): console.print(" 500-1000: 快速导入大量表情包") console.print(" 1000+: 极速模式,但内存占用更高") batch_size = get_int(" 批量提交大小", 500, 100, 5000) - + # 确认导入 console.print("\n[cyan]" + "-" * 70 + "[/cyan]") console.print("[bold]导入配置:[/bold]") @@ -1014,17 +1022,17 @@ def interactive_import(): console.print(f" SHA256 验证: {'是' if verify_sha else '否'}") console.print(f" 批量大小: {batch_size}") console.print("[cyan]" + "-" * 70 + "[/cyan]") - + if not get_yes_no("\n确认导入", True): console.print("[red]✗ 已取消导入[/red]") return False - + # 开始导入 unpacker = MMIPKGUnpacker(verify_sha=verify_sha) - + total_success = 0 total_failed = 0 - + # 使用进度条处理多个文件 with Progress( SpinnerColumn(), @@ -1032,31 +1040,28 @@ def interactive_import(): BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeElapsedColumn(), - console=console + console=console, ) as progress: task = progress.add_task("[cyan]导入文件...", total=len(input_files)) - + for i, input_path in enumerate(input_files, 1): progress.update(task, description=f"[cyan]导入 [{i}/{len(input_files)}]: {os.path.basename(input_path)}") - + console.print(f"\n[bold]{'=' * 70}[/bold]") console.print(f"[bold]导入文件 [{i}/{len(input_files)}]: {os.path.basename(input_path)}[/bold]") console.print(f"[bold]{'=' * 70}[/bold]") - + success = unpacker.import_to_db( - input_path, - output_dir=output_dir, - replace_existing=replace_existing, - batch_size=batch_size + input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size ) - + if success: total_success += 1 else: total_failed += 1 - + progress.advance(task) - + # 总结 console.print(f"\n[bold]{'=' * 70}[/bold]") console.print("[bold]导入总结:[/bold]") @@ -1064,28 +1069,28 @@ def interactive_import(): if total_failed > 0: console.print(f" [red]失败: {total_failed} 个文件[/red]") console.print(f"[bold]{'=' * 70}[/bold]") - + return total_failed == 0 def main(): """主函数 - 交互式界面""" print_header() - + try: while True: print_menu() try: - choice = get_input("请选择", "1", ['0', '1', '2']) + choice = get_input("请选择", "1", ["0", "1", "2"]) except KeyboardInterrupt: console.print("\n[green]再见![/green]") return 0 - - if choice == '0': + + if choice == "0": console.print("\n[green]再见![/green]") return 0 - - elif choice == '1': + + elif choice == "1": try: interactive_export() except KeyboardInterrupt: @@ -1093,14 +1098,15 @@ def main(): except Exception as e: console.print(f"\n[red]✗ 发生错误: {e}[/red]") import traceback + traceback.print_exc() - + try: input("\n按 Enter 键继续...") except (KeyboardInterrupt, EOFError): pass - - elif choice == '2': + + elif choice == "2": try: interactive_import() except KeyboardInterrupt: @@ -1108,8 +1114,9 @@ def main(): except Exception as e: console.print(f"\n[red]✗ 发生错误: {e}[/red]") import traceback + traceback.print_exc() - + try: input("\n按 Enter 键继续...") except (KeyboardInterrupt, EOFError): @@ -1117,9 +1124,9 @@ def main(): except KeyboardInterrupt: console.print("\n[green]再见![/green]") return 0 - + return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 707c44bd..dbbd39ef 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -333,7 +333,6 @@ class HeartFChatting: # 重置连续 no_reply 计数 self.consecutive_no_reply_count = 0 reason = "" - await database_api.store_action_info( chat_stream=self.chat_stream, diff --git a/src/chat/knowledge/__init__.py b/src/chat/knowledge/__init__.py index b9c96708..a570277c 100644 --- a/src/chat/knowledge/__init__.py +++ b/src/chat/knowledge/__init__.py @@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data") qa_manager = None inspire_manager = None + def get_qa_manager(): return qa_manager + def lpmm_start_up(): # sourcery skip: extract-duplicate-method # 检查LPMM知识库是否启用 if global_config.lpmm_knowledge.enable: diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 66a6e4d1..9988dc22 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -128,11 +128,10 @@ class QAManager: selected_knowledge = knowledge[:limit] formatted_knowledge = [ - f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" - for i, k in enumerate(selected_knowledge) + f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge) ] # if max_score is not None: - # formatted_knowledge.insert(0, f"最高相关系数:{max_score}") + # formatted_knowledge.insert(0, f"最高相关系数:{max_score}") found_knowledge = "\n".join(formatted_knowledge) if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 3019ca0d..cc948db1 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -226,7 +226,9 @@ class DefaultReplyer: traceback.print_exc() return False, llm_response - async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]: + async def build_expression_habits( + self, chat_history: str, target: str, reply_reason: str = "" + ) -> Tuple[str, List[int]]: # sourcery skip: for-append-to-extend """构建表达习惯块 @@ -1094,10 +1096,10 @@ class DefaultReplyer: if not global_config.lpmm_knowledge.enable: logger.debug("LPMM知识库未启用,跳过获取知识库内容") return "" - + if global_config.lpmm_knowledge.lpmm_mode == "agent": return "" - + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) bot_name = global_config.bot.nickname @@ -1115,10 +1117,10 @@ class DefaultReplyer: model_config=model_config.model_task_config.tool_use, tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], ) - + # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") - + if tool_calls: result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) end_time = time.time() diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 76eedf8c..74b04f6e 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -241,7 +241,9 @@ class PrivateReplyer: return f"{sender_relation}" - async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]: + async def build_expression_habits( + self, chat_history: str, target: str, reply_reason: str = "" + ) -> Tuple[str, List[int]]: # sourcery skip: for-append-to-extend """构建表达习惯块 @@ -1032,10 +1034,10 @@ class PrivateReplyer: if not global_config.lpmm_knowledge.enable: logger.debug("LPMM知识库未启用,跳过获取知识库内容") return "" - + if global_config.lpmm_knowledge.lpmm_mode == "agent": return "" - + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) bot_name = global_config.bot.nickname diff --git a/src/chat/utils/chat_history_summarizer.py b/src/chat/utils/chat_history_summarizer.py index 36bb5ff0..f471781d 100644 --- a/src/chat/utils/chat_history_summarizer.py +++ b/src/chat/utils/chat_history_summarizer.py @@ -106,7 +106,7 @@ class ChatHistorySummarizer: await self._check_and_package(current_time) self.last_check_time = current_time return - + logger.info( f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" ) diff --git a/src/common/logger.py b/src/common/logger.py index 9e7e08d4..4cf40398 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -72,16 +72,16 @@ def get_ws_handler(): def initialize_ws_handler(loop): """初始化 WebSocket handler 的事件循环 - + Args: loop: asyncio 事件循环 """ handler = get_ws_handler() handler.set_loop(loop) - + # 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同) handler.setFormatter(file_formatter) - + # 添加到根日志记录器 root_logger = logging.getLogger() if handler not in root_logger.handlers: @@ -177,43 +177,44 @@ class TimestampedFileHandler(logging.Handler): class WebSocketLogHandler(logging.Handler): """WebSocket 日志处理器 - 将日志实时推送到前端""" - + _log_counter = 0 # 类级别计数器,确保 ID 唯一性 - + def __init__(self, loop=None): super().__init__() self.loop = loop self._initialized = False - + def set_loop(self, loop): """设置事件循环""" self.loop = loop self._initialized = True - + def emit(self, record): """发送日志到 WebSocket 客户端""" if not self._initialized or self.loop is None: return - + try: # 获取格式化后的消息 # 对于 structlog,formatted message 包含完整的日志信息 formatted_msg = self.format(record) if self.formatter else record.getMessage() - + # 如果是 JSON 格式(文件格式化器),解析它 message = formatted_msg try: import json + log_dict = json.loads(formatted_msg) - message = log_dict.get('event', formatted_msg) + message = log_dict.get("event", formatted_msg) except (json.JSONDecodeError, ValueError): # 不是 JSON,直接使用消息 message = formatted_msg - + # 生成唯一 ID: 时间戳毫秒 + 自增计数器 WebSocketLogHandler._log_counter += 1 log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}" - + # 格式化日志数据 log_data = { "id": log_id, @@ -222,20 +223,17 @@ class WebSocketLogHandler(logging.Handler): "module": record.name, "message": message, } - + # 异步广播日志(不阻塞日志记录) try: import asyncio from src.webui.logs_ws import broadcast_log - - asyncio.run_coroutine_threadsafe( - broadcast_log(log_data), - self.loop - ) + + asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop) except Exception: # WebSocket 推送失败不影响日志记录 pass - + except Exception: # 不要让 WebSocket 错误影响日志系统 self.handleError(record) @@ -255,7 +253,7 @@ def close_handlers(): if _console_handler: _console_handler.close() _console_handler = None - + if _ws_handler: _ws_handler.close() _ws_handler = None diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 795b38cb..454f6976 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -647,7 +647,7 @@ class LPMMKnowledgeConfig(ConfigBase): enable: bool = True """是否启用LPMM知识库""" - + lpmm_mode: Literal["classic", "agent"] = "classic" """LPMM知识库模式,可选:classic经典模式,agent 模式,结合最新的记忆一同使用""" @@ -690,4 +690,4 @@ class JargonConfig(ConfigBase): """Jargon配置类""" all_global: bool = False - """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id""" \ No newline at end of file + """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id""" diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py index 3bc14142..57c03e77 100644 --- a/src/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -467,11 +467,7 @@ class ExpressionLearner: up_content: str, current_time: float, ) -> None: - expr_obj = ( - Expression.select() - .where((Expression.chat_id == self.chat_id) & (Expression.style == style)) - .first() - ) + expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first() if expr_obj: await self._update_existing_expression( diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py index 66c49def..b4e25f36 100644 --- a/src/express/expression_selector.py +++ b/src/express/expression_selector.py @@ -42,8 +42,6 @@ def init_prompt(): Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") - - class ExpressionSelector: def __init__(self): self.llm_model = LLMRequest( @@ -238,9 +236,9 @@ class ExpressionSelector: else: target_message_str = "" target_message_extra_block = "" - + chat_context = f"以下是正在进行的聊天内容:{chat_info}" - + # 构建reply_reason块 if reply_reason: reply_reason_block = f"你的回复理由是:{reply_reason}" @@ -262,9 +260,8 @@ class ExpressionSelector: # 4. 调用LLM content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) - # print(prompt) - + if not content: logger.warning("LLM返回空结果") return [], [] diff --git a/src/jargon/jargon_miner.py b/src/jargon/jargon_miner.py index 59d79df5..db1b79f9 100644 --- a/src/jargon/jargon_miner.py +++ b/src/jargon/jargon_miner.py @@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool: target = content.strip().lower() nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower() - alias_names = [ - str(alias or "").strip().lower() - for alias in getattr(bot_config, "alias_names", []) or [] - ] + alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []] candidates = [name for name in [nickname, *alias_names] if name] @@ -149,7 +146,7 @@ async def _enrich_raw_content_if_needed( ) -> List[str]: """ 检查raw_content是否只包含黑话本身,如果是,则获取该消息的前三条消息作为原始内容 - + Args: content: 黑话内容 raw_content_list: 原始raw_content列表 @@ -157,22 +154,22 @@ async def _enrich_raw_content_if_needed( messages: 当前时间窗口内的消息列表 extraction_start_time: 提取开始时间 extraction_end_time: 提取结束时间 - + Returns: 处理后的raw_content列表 """ enriched_list = [] - + for raw_content in raw_content_list: # 检查raw_content是否只包含黑话本身(去除空白字符后比较) raw_content_clean = raw_content.strip() content_clean = content.strip() - + # 如果raw_content只包含黑话本身(可能有一些标点或空白),则尝试获取上下文 # 去除所有空白字符后比较,确保只包含黑话本身 raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "") content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "") - + if raw_content_normalized == content_normalized: # 在消息列表中查找只包含该黑话的消息(去除空白后比较) target_message = None @@ -183,22 +180,20 @@ async def _enrich_raw_content_if_needed( if msg_content_normalized == content_normalized: target_message = msg break - + if target_message and target_message.time: # 获取该消息的前三条消息 try: previous_messages = get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=target_message.time, - limit=3 + chat_id=chat_id, timestamp=target_message.time, limit=3 ) - + if previous_messages: # 将前三条消息和当前消息一起格式化 context_messages = previous_messages + [target_message] # 按时间排序 context_messages.sort(key=lambda x: x.time or 0) - + # 格式化为可读消息 formatted_context, _ = await build_readable_messages_with_list( context_messages, @@ -206,7 +201,7 @@ async def _enrich_raw_content_if_needed( timestamp_mode="relative", truncate=False, ) - + if formatted_context.strip(): enriched_list.append(formatted_context.strip()) logger.warning(f"为黑话 {content} 补充了上下文消息") @@ -226,7 +221,7 @@ async def _enrich_raw_content_if_needed( else: # raw_content包含更多内容,直接使用 enriched_list.append(raw_content) - + return enriched_list @@ -240,31 +235,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool: # 如果已完成所有推断,不再推断 if jargon_obj.is_complete: return False - + count = jargon_obj.count or 0 last_inference = jargon_obj.last_inference_count or 0 - + # 阈值列表:3,6, 10, 20, 40, 60, 100 - thresholds = [3,6, 10, 20, 40, 60, 100] - + thresholds = [3, 6, 10, 20, 40, 60, 100] + if count < thresholds[0]: return False - + # 如果count没有超过上次判定值,不需要判定 if count <= last_inference: return False - + # 找到第一个大于last_inference的阈值 next_threshold = None for threshold in thresholds: if threshold > last_inference: next_threshold = threshold break - + # 如果没有找到下一个阈值,说明已经超过100,不应该再推断 if next_threshold is None: return False - + # 检查count是否达到或超过这个阈值 return count >= next_threshold @@ -275,13 +270,13 @@ class JargonMiner: self.last_learning_time: float = time.time() # 频率控制,可按需调整 self.min_messages_for_learning: int = 10 - self.min_learning_interval: float = 20 + self.min_learning_interval: float = 20 self.llm = LLMRequest( model_set=model_config.model_task_config.utils, request_type="jargon.extract", ) - + # 初始化stream_name作为类属性,避免重复提取 chat_manager = get_chat_manager() stream_name = chat_manager.get_stream_name(self.chat_id) @@ -306,17 +301,19 @@ class JargonMiner: try: content = jargon_obj.content raw_content_str = jargon_obj.raw_content or "" - + # 解析raw_content列表 raw_content_list = [] if raw_content_str: try: - raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str + raw_content_list = ( + json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str + ) if not isinstance(raw_content_list, list): raw_content_list = [raw_content_list] if raw_content_list else [] except (json.JSONDecodeError, TypeError): raw_content_list = [raw_content_str] if raw_content_str else [] - + if not raw_content_list: logger.warning(f"jargon {content} 没有raw_content,跳过推断") return @@ -328,12 +325,12 @@ class JargonMiner: content=content, raw_content_list=raw_content_text, ) - + response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3) if not response1: logger.warning(f"jargon {content} 推断1失败:无响应") return - + # 解析推断1结果 inference1 = None try: @@ -349,7 +346,7 @@ class JargonMiner: except Exception as e: logger.error(f"jargon {content} 推断1解析失败: {e}") return - + # 检查推断1是否表示信息不足无法推断 no_info = inference1.get("no_info", False) meaning1 = inference1.get("meaning", "").strip() @@ -360,18 +357,17 @@ class JargonMiner: jargon_obj.save() return - # 步骤2: 仅基于content推断 prompt2 = await global_prompt_manager.format_prompt( "jargon_inference_content_only_prompt", content=content, ) - + response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3) if not response2: logger.warning(f"jargon {content} 推断2失败:无响应") return - + # 解析推断2结果 inference2 = None try: @@ -387,13 +383,12 @@ class JargonMiner: except Exception as e: logger.error(f"jargon {content} 推断2解析失败: {e}") return - logger.info(f"jargon {content} 推断2提示词: {prompt2}") logger.info(f"jargon {content} 推断2结果: {response2}") logger.info(f"jargon {content} 推断1提示词: {prompt1}") logger.info(f"jargon {content} 推断1结果: {response1}") - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 推断2提示词: {prompt2}") logger.info(f"jargon {content} 推断2结果: {response2}") @@ -404,22 +399,22 @@ class JargonMiner: logger.debug(f"jargon {content} 推断2结果: {response2}") logger.debug(f"jargon {content} 推断1提示词: {prompt1}") logger.debug(f"jargon {content} 推断1结果: {response1}") - + # 步骤3: 比较两个推断结果 prompt3 = await global_prompt_manager.format_prompt( "jargon_compare_inference_prompt", inference1=json.dumps(inference1, ensure_ascii=False), inference2=json.dumps(inference2, ensure_ascii=False), ) - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 比较提示词: {prompt3}") - + response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3) if not response3: logger.warning(f"jargon {content} 比较失败:无响应") return - + # 解析比较结果 comparison = None try: @@ -439,7 +434,7 @@ class JargonMiner: # 判断是否为黑话 is_similar = comparison.get("is_similar", False) is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 - + # 更新数据库记录 jargon_obj.is_jargon = is_jargon if is_jargon: @@ -448,17 +443,19 @@ class JargonMiner: else: # 不是黑话,也记录含义(使用推断2的结果,因为含义明确) jargon_obj.meaning = inference2.get("meaning", "") - + # 更新最后一次判定的count值,避免重启后重复判定 jargon_obj.last_inference_count = jargon_obj.count or 0 - + # 如果count>=100,标记为完成,不再进行推断 if (jargon_obj.count or 0) >= 100: jargon_obj.is_complete = True - + jargon_obj.save() - logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}") - + logger.debug( + f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" + ) + # 固定输出推断结果,格式化为可读形式 if is_jargon: # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx @@ -471,10 +468,11 @@ class JargonMiner: else: # 不是黑话,输出格式:[聊天名]xxx 不是黑话 logger.info(f"[{self.stream_name}]{content} 不是黑话") - + except Exception as e: logger.error(f"jargon推断失败: {e}") import traceback + traceback.print_exc() def should_trigger(self) -> bool: @@ -502,7 +500,7 @@ class JargonMiner: # 记录本次提取的时间窗口,避免重复提取 extraction_start_time = self.last_learning_time extraction_end_time = time.time() - + # 拉取学习窗口内的消息 messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, @@ -525,7 +523,7 @@ class JargonMiner: response, _ = await self.llm.generate_response_async(prompt, temperature=0.2) if not response: return - + if global_config.debug.show_jargon_prompt: logger.info(f"jargon提取提示词: {prompt}") logger.info(f"jargon提取结果: {response}") @@ -555,7 +553,7 @@ class JargonMiner: continue content = str(item.get("content", "")).strip() raw_content_value = item.get("raw_content", "") - + # 处理raw_content:可能是字符串或列表 raw_content_list = [] if isinstance(raw_content_value, list): @@ -566,15 +564,12 @@ class JargonMiner: raw_content_str = raw_content_value.strip() if raw_content_str: raw_content_list = [raw_content_str] - + if content and raw_content_list: if _contains_bot_self_name(content): logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}") continue - entries.append({ - "content": content, - "raw_content": raw_content_list - }) + entries.append({"content": content, "raw_content": raw_content_list}) except Exception as e: logger.error(f"解析jargon JSON失败: {e}; 原始: {response}") return @@ -591,13 +586,13 @@ class JargonMiner: if content_key not in seen: seen.add(content_key) uniq_entries.append(entry) - + saved = 0 updated = 0 for entry in uniq_entries: content = entry["content"] raw_content_list = entry["raw_content"] # 已经是列表 - + # 检查并补充raw_content:如果只包含黑话本身,则获取前三条消息作为上下文 raw_content_list = await _enrich_raw_content_if_needed( content=content, @@ -607,60 +602,53 @@ class JargonMiner: extraction_start_time=extraction_start_time, extraction_end_time=extraction_end_time, ) - + try: # 根据all_global配置决定查询逻辑 if global_config.jargon.all_global: # 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的) - query = ( - Jargon.select() - .where(Jargon.content == content) - ) + query = Jargon.select().where(Jargon.content == content) else: # 关闭all_global:只查询chat_id匹配的记录(不考虑is_global) - query = ( - Jargon.select() - .where( - (Jargon.chat_id == self.chat_id) & - (Jargon.content == content) - ) - ) - + query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content)) + if query.exists(): obj = query.get() try: obj.count = (obj.count or 0) + 1 except Exception: obj.count = 1 - + # 合并raw_content列表:读取现有列表,追加新值,去重 existing_raw_content = [] if obj.raw_content: try: - existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + existing_raw_content = ( + json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + ) if not isinstance(existing_raw_content, list): existing_raw_content = [existing_raw_content] if existing_raw_content else [] except (json.JSONDecodeError, TypeError): existing_raw_content = [obj.raw_content] if obj.raw_content else [] - + # 合并并去重 merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list)) obj.raw_content = json.dumps(merged_list, ensure_ascii=False) - + # 开启all_global时,确保记录标记为is_global=True if global_config.jargon.all_global: obj.is_global = True # 关闭all_global时,保持原有is_global不变(不修改) - + obj.save() - + # 检查是否需要推断(达到阈值且超过上次判定值) if _should_infer_meaning(obj): # 异步触发推断,不阻塞主流程 # 重新加载对象以确保数据最新 jargon_id = obj.id asyncio.create_task(self._infer_meaning_by_id(jargon_id)) - + updated += 1 else: # 没找到匹配记录,创建新记录 @@ -670,13 +658,13 @@ class JargonMiner: else: # 关闭all_global:新记录is_global=False is_global_new = False - + Jargon.create( content=content, raw_content=json.dumps(raw_content_list, ensure_ascii=False), chat_id=self.chat_id, is_global=is_global_new, - count=1 + count=1, ) saved += 1 except Exception as e: @@ -688,13 +676,13 @@ class JargonMiner: # 收集所有提取的jargon内容 jargon_list = [entry["content"] for entry in uniq_entries] jargon_str = ",".join(jargon_list) - + # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") - + # 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口 self.last_learning_time = extraction_end_time - + if saved or updated: logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}") except Exception as e: @@ -720,15 +708,11 @@ async def extract_and_store_jargon(chat_id: str) -> None: def search_jargon( - keyword: str, - chat_id: Optional[str] = None, - limit: int = 10, - case_sensitive: bool = False, - fuzzy: bool = True + keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True ) -> List[Dict[str, str]]: """ 搜索jargon,支持大小写不敏感和模糊搜索 - + Args: keyword: 搜索关键词 chat_id: 可选的聊天ID @@ -737,21 +721,18 @@ def search_jargon( limit: 返回结果数量限制,默认10 case_sensitive: 是否大小写敏感,默认False(不敏感) fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) - + Returns: List[Dict[str, str]]: 包含content, meaning的字典列表 """ if not keyword or not keyword.strip(): return [] - + keyword = keyword.strip() - + # 构建查询 - query = Jargon.select( - Jargon.content, - Jargon.meaning - ) - + query = Jargon.select(Jargon.content, Jargon.meaning) + # 构建搜索条件 if case_sensitive: # 大小写敏感 @@ -760,7 +741,7 @@ def search_jargon( search_condition = Jargon.content.contains(keyword) else: # 精确匹配 - search_condition = (Jargon.content == keyword) + search_condition = Jargon.content == keyword else: # 大小写不敏感 if fuzzy: @@ -768,10 +749,10 @@ def search_jargon( search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) else: # 精确匹配(使用LOWER函数) - search_condition = (fn.LOWER(Jargon.content) == keyword.lower()) - + search_condition = fn.LOWER(Jargon.content) == keyword.lower() + query = query.where(search_condition) - + # 根据all_global配置决定查询逻辑 if global_config.jargon.all_global: # 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id) @@ -779,35 +760,28 @@ def search_jargon( else: # 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon if chat_id: - query = query.where( - (Jargon.chat_id == chat_id) | Jargon.is_global - ) - + query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global) + # 只返回有meaning的记录 - query = query.where( - (Jargon.meaning.is_null(False)) & (Jargon.meaning != "") - ) - + query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) + # 按count降序排序,优先返回出现频率高的 query = query.order_by(Jargon.count.desc()) - + # 限制结果数量 query = query.limit(limit) - + # 执行查询并返回结果 results = [] for jargon in query: - results.append({ - "content": jargon.content or "", - "meaning": jargon.meaning or "" - }) - + results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) + return results async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None: """将黑话存入jargon系统 - + Args: jargon_keyword: 黑话关键词 answer: 答案内容(将概括为raw_content) @@ -820,53 +794,52 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st 答案:{answer} 只输出概括后的内容,不要输出其他内容:""" - + success, summary, _, _ = await llm_api.generate_with_model( summary_prompt, model_config=model_config.model_task_config.utils_small, request_type="memory.summarize_jargon", ) - + logger.info(f"概括答案提示: {summary_prompt}") logger.info(f"概括答案: {summary}") - + if not success: logger.warning(f"概括答案失败,使用原始答案: {summary}") summary = answer[:100] # 截取前100字符作为备用 - + raw_content = summary.strip()[:200] # 限制长度 - + # 检查是否已存在 if global_config.jargon.all_global: query = Jargon.select().where(Jargon.content == jargon_keyword) else: - query = Jargon.select().where( - (Jargon.chat_id == chat_id) & - (Jargon.content == jargon_keyword) - ) - + query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword)) + if query.exists(): # 更新现有记录 obj = query.get() obj.count = (obj.count or 0) + 1 - + # 合并raw_content列表 existing_raw_content = [] if obj.raw_content: try: - existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + existing_raw_content = ( + json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + ) if not isinstance(existing_raw_content, list): existing_raw_content = [existing_raw_content] if existing_raw_content else [] except (json.JSONDecodeError, TypeError): existing_raw_content = [obj.raw_content] if obj.raw_content else [] - + # 合并并去重 merged_list = list(dict.fromkeys(existing_raw_content + [raw_content])) obj.raw_content = json.dumps(merged_list, ensure_ascii=False) - + if global_config.jargon.all_global: obj.is_global = True - + obj.save() logger.info(f"更新jargon记录: {jargon_keyword}") else: @@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st raw_content=json.dumps([raw_content], ensure_ascii=False), chat_id=chat_id, is_global=is_global_new, - count=1 + count=1, ) logger.info(f"创建新jargon记录: {jargon_keyword}") - + except Exception as e: logger.error(f"存储jargon失败: {e}") - - diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 444c5671..b3fafca0 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -147,7 +147,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar param_type_value = tool_option_param.param_type.value if param_type_value == "bool": param_type_value = "boolean" - + return_dict: dict[str, Any] = { "type": param_type_value, "description": tool_option_param.description, diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 01e12588..f573d33e 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any] param_type_value = tool_option_param.param_type.value if param_type_value == "bool": param_type_value = "boolean" - + return_dict: dict[str, Any] = { "type": param_type_value, "description": tool_option_param.description, diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index ddcdf57f..960de08b 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -116,9 +116,7 @@ class MessageBuilder: 构建消息对象 :return: Message对象 """ - if len(self.__content) == 0 and not ( - self.__role == RoleType.Assistant and self.__tool_calls - ): + if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls): raise ValueError("内容不能为空") if self.__role == RoleType.Tool and self.__tool_call_id is None: raise ValueError("Tool角色的工具调用ID不能为空") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 28f63c11..1ed74e03 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -166,7 +166,7 @@ class LLMRequest: time_cost=time.time() - start_time, ) return content or "", (reasoning_content, model_info.name, tool_calls) - + async def generate_response_with_message_async( self, message_factory: Callable[[BaseClient], List[Message]], diff --git a/src/main.py b/src/main.py index a75d4d26..b442f29d 100644 --- a/src/main.py +++ b/src/main.py @@ -36,10 +36,10 @@ class MainSystem: # 使用消息API替代直接的FastAPI实例 self.app: MessageServer = get_global_api() self.server: Server = get_global_server() - + # 注册 WebUI API 路由 self._register_webui_routes() - + # 设置 WebUI(开发/生产模式) self._setup_webui() @@ -47,6 +47,7 @@ class MainSystem: """注册 WebUI API 路由""" try: from src.webui.routes import router as webui_router + self.server.register_router(webui_router) logger.info("WebUI API 路由已注册") except Exception as e: @@ -55,15 +56,17 @@ class MainSystem: def _setup_webui(self): """设置 WebUI(根据环境变量决定模式)""" import os + webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true" if not webui_enabled: logger.info("WebUI 已禁用") return - + webui_mode = os.getenv("WEBUI_MODE", "production").lower() - + try: from src.webui.manager import setup_webui + setup_webui(mode=webui_mode) except Exception as e: logger.error(f"设置 WebUI 失败: {e}") diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index e1562b6e..e9187eb0 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -17,26 +17,23 @@ from src.llm_models.payload_content.message import MessageBuilder, RoleType, Mes logger = get_logger("memory_retrieval") THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 3600 # 未找到答案记录保留时长 -THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率 +THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率 _last_not_found_cleanup_ts: float = 0.0 def _cleanup_stale_not_found_thinking_back() -> None: """定期清理过期的未找到答案记录""" global _last_not_found_cleanup_ts - + now = time.time() if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS: return - + threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS try: deleted_rows = ( ThinkingBack.delete() - .where( - (ThinkingBack.found_answer == 0) & - (ThinkingBack.update_time < threshold_time) - ) + .where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time)) .execute() ) if deleted_rows: @@ -45,11 +42,12 @@ def _cleanup_stale_not_found_thinking_back() -> None: except Exception as e: logger.error(f"清理未找到答案的thinking_back记录失败: {e}") + def init_memory_retrieval_prompt(): """初始化记忆检索相关的 prompt 模板和工具""" # 首先注册所有工具 init_all_tools() - + # 第一步:问题生成prompt Prompt( """ @@ -102,7 +100,7 @@ def init_memory_retrieval_prompt(): """, name="memory_retrieval_question_prompt", ) - + # 第二步:ReAct Agent prompt(使用function calling,要求先思考再行动) Prompt( """你的名字是{bot_name}。现在是{time_now}。 @@ -140,7 +138,7 @@ def init_memory_retrieval_prompt(): """, name="memory_retrieval_react_prompt_head", ) - + # 额外,如果最后一轮迭代:ReAct Agent prompt(使用function calling,要求先思考再行动) Prompt( """你的名字是{bot_name}。现在是{time_now}。 @@ -170,10 +168,10 @@ def init_memory_retrieval_prompt(): def _parse_react_response(response: str) -> Optional[Dict[str, Any]]: """解析ReAct Agent的响应 - + Args: response: LLM返回的响应 - + Returns: Dict[str, Any]: 解析后的动作信息,如果解析失败返回None 格式: {"thought": str, "actions": List[Dict[str, Any]]} @@ -183,90 +181,75 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]: # 尝试提取JSON(可能包含在```json代码块中) json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: json_str = matches[0] else: # 尝试直接解析整个响应 json_str = response.strip() - + # 修复可能的JSON错误 repaired_json = repair_json(json_str) - + # 解析JSON action_info = json.loads(repaired_json) - + if not isinstance(action_info, dict): logger.warning(f"解析的JSON不是对象格式: {action_info}") return None - + # 确保actions字段存在且为列表 if "actions" not in action_info: logger.warning(f"响应中缺少actions字段: {action_info}") return None - + if not isinstance(action_info["actions"], list): logger.warning(f"actions字段不是数组格式: {action_info['actions']}") return None - + # 确保actions不为空 if len(action_info["actions"]) == 0: logger.warning("actions数组为空") return None - + return action_info - + except Exception as e: logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...") return None -async def _retrieve_concepts_with_jargon( - concepts: List[str], - chat_id: str -) -> str: +async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str: """对概念列表进行jargon检索 - + Args: concepts: 概念列表 chat_id: 聊天ID - + Returns: str: 检索结果字符串 """ if not concepts: return "" - + from src.jargon.jargon_miner import search_jargon - + results = [] for concept in concepts: concept = concept.strip() if not concept: continue - + # 先尝试精确匹配 - jargon_results = search_jargon( - keyword=concept, - chat_id=chat_id, - limit=10, - case_sensitive=False, - fuzzy=False - ) - + jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False) + is_fuzzy_match = False - + # 如果精确匹配未找到,尝试模糊搜索 if not jargon_results: - jargon_results = search_jargon( - keyword=concept, - chat_id=chat_id, - limit=10, - case_sensitive=False, - fuzzy=True - ) + jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True) is_fuzzy_match = True - + if jargon_results: # 找到结果 if is_fuzzy_match: @@ -291,28 +274,24 @@ async def _retrieve_concepts_with_jargon( else: # 未找到,不返回占位信息,只记录日志 logger.info(f"在jargon库中未找到匹配: {concept}") - + if results: return "【概念检索结果】\n" + "\n".join(results) + "\n" return "" async def _react_agent_solve_question( - question: str, - chat_id: str, - max_iterations: int = 5, - timeout: float = 30.0, - initial_info: str = "" + question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = "" ) -> Tuple[bool, str, List[Dict[str, Any]], bool]: """使用ReAct架构的Agent来解决问题 - + Args: question: 要回答的问题 chat_id: 聊天ID max_iterations: 最大迭代次数 timeout: 超时时间(秒) initial_info: 初始信息(如概念检索结果),将作为collected_info的初始值 - + Returns: Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) """ @@ -321,34 +300,35 @@ async def _react_agent_solve_question( thinking_steps = [] is_timeout = False conversation_messages: List[Message] = [] - + for iteration in range(max_iterations): # 检查超时 if time.time() - start_time > timeout: logger.warning(f"ReAct Agent超时,已迭代{iteration}次") is_timeout = True break - + # 获取工具注册器 tool_registry = get_tool_registry() - + # 获取bot_name bot_name = global_config.bot.nickname - + # 获取当前时间 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - + # 计算剩余迭代次数 current_iteration = iteration + 1 remaining_iterations = max_iterations - current_iteration is_final_iteration = current_iteration >= max_iterations - - + if is_final_iteration: # 最后一次迭代,使用最终prompt tool_definitions = [] - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)") - + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)" + ) + prompt = await global_prompt_manager.format_prompt( "memory_retrieval_react_final_prompt", bot_name=bot_name, @@ -359,7 +339,7 @@ async def _react_agent_solve_question( remaining_iterations=remaining_iterations, max_iterations=max_iterations, ) - + logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}") success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools( prompt, @@ -370,7 +350,9 @@ async def _react_agent_solve_question( else: # 非最终迭代,使用head_prompt tool_definitions = tool_registry.get_tool_definitions() - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}") + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}" + ) head_prompt = await global_prompt_manager.format_prompt( "memory_retrieval_react_prompt_head", @@ -397,12 +379,12 @@ async def _react_agent_solve_question( messages.append(system_builder.build()) messages.extend(_conversation_messages) - + # 优化日志展示 - 合并所有消息到一条日志 log_lines = [] for idx, msg in enumerate(messages, 1): - role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role) - + role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + # 处理内容 - 显示完整内容,不截断 if isinstance(msg.content, str): full_content = msg.content @@ -415,37 +397,45 @@ async def _react_agent_solve_question( else: full_content = str(msg.content) content_type = "未知" - + # 构建单条消息的日志信息 msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================" - + if full_content: msg_info += f"\n{full_content}" - + if msg.tool_calls: msg_info += f"\n 工具调用: {len(msg.tool_calls)}个" for tool_call in msg.tool_calls: msg_info += f"\n - {tool_call}" - + if msg.tool_call_id: msg_info += f"\n 工具调用ID: {msg.tool_call_id}" - + log_lines.append(msg_info) - + # 合并所有消息为一条日志输出 logger.info(f"消息列表 (共{len(messages)}条):{''.join(log_lines)}") return messages - success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory( + ( + success, + response, + reasoning_content, + model_name, + tool_calls, + ) = await llm_api.generate_with_model_with_tools_by_message_factory( message_factory, model_config=model_config.model_task_config.tool_use, tool_options=tool_definitions, request_type="memory.react", ) - - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}") - + + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" + ) + if not success: logger.error(f"ReAct Agent LLM调用失败: {response}") break @@ -463,105 +453,108 @@ async def _react_agent_solve_question( assistant_builder.set_role(RoleType.Assistant) assistant_builder.add_text_content(response) assistant_message = assistant_builder.build() - + # 记录思考步骤 - step = { - "iteration": iteration + 1, - "thought": response, - "actions": [], - "observations": [] - } - + step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []} + # 优先从思考内容中提取found_answer或not_enough_info def extract_quoted_content(text, func_name, param_name): """从文本中提取函数调用中参数的值,支持单引号和双引号 - + Args: text: 要搜索的文本 func_name: 函数名,如 'found_answer' param_name: 参数名,如 'answer' - + Returns: 提取的参数值,如果未找到则返回None """ if not text: return None - + # 查找函数调用位置(不区分大小写) func_pattern = func_name.lower() text_lower = text.lower() func_pos = text_lower.find(func_pattern) if func_pos == -1: return None - + # 查找参数名和等号 - param_pattern = f'{param_name}=' + param_pattern = f"{param_name}=" param_pos = text_lower.find(param_pattern, func_pos) if param_pos == -1: return None - + # 跳过参数名、等号和空白 start_pos = param_pos + len(param_pattern) - while start_pos < len(text) and text[start_pos] in ' \t\n': + while start_pos < len(text) and text[start_pos] in " \t\n": start_pos += 1 - + if start_pos >= len(text): return None - + # 确定引号类型 quote_char = text[start_pos] if quote_char not in ['"', "'"]: return None - + # 查找匹配的结束引号(考虑转义) end_pos = start_pos + 1 while end_pos < len(text): if text[end_pos] == quote_char: # 检查是否是转义的引号 - if end_pos > start_pos + 1 and text[end_pos - 1] == '\\': + if end_pos > start_pos + 1 and text[end_pos - 1] == "\\": end_pos += 1 continue # 找到匹配的引号 - content = text[start_pos + 1:end_pos] + content = text[start_pos + 1 : end_pos] # 处理转义字符 - content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\') + content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\") return content end_pos += 1 - + return None - + # 从LLM的直接输出内容中提取found_answer或not_enough_info found_answer_content = None not_enough_info_reason = None - + # 只检查response(LLM的直接输出内容),不检查reasoning_content if response: - found_answer_content = extract_quoted_content(response, 'found_answer', 'answer') + found_answer_content = extract_quoted_content(response, "found_answer", "answer") if not found_answer_content: - not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason') - + not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason") + # 如果从输出内容中找到了答案,直接返回 if found_answer_content: step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}}) step["observations"] = ["从LLM输出内容中检测到found_answer"] thinking_steps.append(step) - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...") + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..." + ) return True, found_answer_content, thinking_steps, False - + if not_enough_info_reason: - step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}) + step["actions"].append( + {"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}} + ) step["observations"] = ["从LLM输出内容中检测到not_enough_info"] thinking_steps.append(step) - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...") + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..." + ) return False, not_enough_info_reason, thinking_steps, False - + if is_final_iteration: - step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}) + step["actions"].append( + {"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}} + ) step["observations"] = ["已到达最后一次迭代,无法找到答案"] thinking_steps.append(step) logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案") return False, "已到达最后一次迭代,无法找到答案", thinking_steps, False - + if assistant_message: conversation_messages.append(assistant_message) @@ -570,7 +563,7 @@ async def _react_agent_solve_question( thought_summary = reasoning_content or (response[:200] if response else "") if thought_summary: collected_info += f"\n[思考] {thought_summary}\n" - + # 处理工具调用 if not tool_calls: # 没有工具调用,说明LLM在思考中已经给出了答案(已在前面检查),或者需要继续查询 @@ -588,28 +581,31 @@ async def _react_agent_solve_question( step["observations"] = ["无响应且无工具调用"] thinking_steps.append(step) break - + # 处理工具调用 tool_tasks = [] - + for i, tool_call in enumerate(tool_calls): tool_name = tool_call.func_name tool_args = tool_call.args or {} - - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})") - + + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})" + ) + # 普通工具调用 tool = tool_registry.get_tool(tool_name) if tool: # 准备工具参数(需要添加chat_id如果工具需要) tool_params = tool_args.copy() - + # 如果工具函数签名需要chat_id,添加它 import inspect + sig = inspect.signature(tool.execute_func) if "chat_id" in sig.parameters: tool_params["chat_id"] = chat_id - + # 创建异步任务 async def execute_single_tool(tool_instance, params, tool_name_str, iter_num): try: @@ -620,23 +616,23 @@ async def _react_agent_solve_question( error_msg = f"工具执行失败: {str(e)}" logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}") return f"查询{tool_name_str}失败: {error_msg}" - + tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration)) step["actions"].append({"action_type": tool_name, "action_params": tool_args}) else: error_msg = f"未知的工具类型: {tool_name}" - logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}") + logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}") tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}"))) - + # 并行执行所有工具 if tool_tasks: observations = await asyncio.gather(*tool_tasks, return_exceptions=True) - + # 处理执行结果 for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)): if isinstance(observation, Exception): observation = f"工具执行异常: {str(observation)}" - logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}") + logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}") observation_text = observation if isinstance(observation, str) else str(observation) step["observations"].append(observation_text) @@ -648,14 +644,16 @@ async def _react_agent_solve_question( tool_builder.add_tool_call(tool_call_item.call_id) conversation_messages.append(tool_builder.build()) # logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行结果: {observation_text}") - + thinking_steps.append(step) - + # 达到最大迭代次数或超时,但Agent没有明确返回found_answer # 迭代超时应该直接视为not_enough_info,而不是使用已有信息 # 只有Agent明确返回found_answer时,才认为找到了答案 if collected_info: - logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}...") + logger.warning( + f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}..." + ) if is_timeout: logger.warning("ReAct Agent超时,直接视为not_enough_info") else: @@ -665,35 +663,32 @@ async def _react_agent_solve_question( def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str: """获取最近一段时间内的查询历史 - + Args: chat_id: 聊天ID time_window_seconds: 时间窗口(秒),默认10分钟 - + Returns: str: 格式化的查询历史字符串 """ try: current_time = time.time() start_time = current_time - time_window_seconds - + # 查询最近时间窗口内的记录,按更新时间倒序 records = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.update_time >= start_time) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time)) .order_by(ThinkingBack.update_time.desc()) .limit(5) # 最多返回5条最近的记录 ) - + if not records.exists(): return "" - + history_lines = [] history_lines.append("最近已查询的问题和结果:") - + for record in records: status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" answer_preview = "" @@ -703,15 +698,15 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) answer_preview = record.answer[:100] if len(record.answer) > 100: answer_preview += "..." - + history_lines.append(f"- 问题:{record.question}") history_lines.append(f" 状态:{status}") if answer_preview: history_lines.append(f" 答案:{answer_preview}") history_lines.append("") # 空行分隔 - + return "\n".join(history_lines) - + except Exception as e: logger.error(f"获取查询历史失败: {e}") return "" @@ -719,40 +714,40 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> List[str]: """获取最近一段时间内缓存的记忆(只返回找到答案的记录) - + Args: chat_id: 聊天ID time_window_seconds: 时间窗口(秒),默认300秒(5分钟) - + Returns: List[str]: 格式化的记忆列表,每个元素格式为 "问题:xxx\n答案:xxx" """ try: current_time = time.time() start_time = current_time - time_window_seconds - + # 查询最近时间窗口内找到答案的记录,按更新时间倒序 records = ( ThinkingBack.select() .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.update_time >= start_time) & - (ThinkingBack.found_answer == 1) + (ThinkingBack.chat_id == chat_id) + & (ThinkingBack.update_time >= start_time) + & (ThinkingBack.found_answer == 1) ) .order_by(ThinkingBack.update_time.desc()) .limit(5) # 最多返回5条最近的记录 ) - + if not records.exists(): return [] - + cached_memories = [] for record in records: if record.answer: cached_memories.append(f"问题:{record.question}\n答案:{record.answer}") - + return cached_memories - + except Exception as e: logger.error(f"获取缓存记忆失败: {e}") return [] @@ -760,11 +755,11 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]: """从thinking_back数据库中查询是否有现成的答案 - + Args: chat_id: 聊天ID question: 问题 - + Returns: Optional[Tuple[bool, str]]: 如果找到记录,返回(found_answer, answer),否则返回None found_answer: 是否找到答案(True表示found_answer=1,False表示found_answer=0) @@ -775,23 +770,20 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st # 按更新时间倒序,获取最新的记录 records = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.question == question) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question)) .order_by(ThinkingBack.update_time.desc()) .limit(1) ) - + if records.exists(): record = records.get() found_answer = bool(record.found_answer) answer = record.answer or "" logger.info(f"在thinking_back中找到记录,问题: {question[:50]}...,found_answer: {found_answer}") return found_answer, answer - + return None - + except Exception as e: logger.error(f"查询thinking_back失败: {e}") return None @@ -799,7 +791,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st async def _analyze_question_answer(question: str, answer: str, chat_id: str) -> None: """异步分析问题和答案的类别,并存储到相应系统 - + Args: question: 问题 answer: 答案 @@ -826,41 +818,42 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) -> }} 只输出JSON,不要输出其他内容:""" - + success, response, _, _ = await llm_api.generate_with_model( analysis_prompt, model_config=model_config.model_task_config.utils, request_type="memory.analyze_qa", ) - + if not success: logger.error(f"分析问题和答案失败: {response}") return - + # 解析JSON响应 try: json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: json_str = matches[0] else: json_str = response.strip() - + repaired_json = repair_json(json_str) analysis_result = json.loads(repaired_json) - + category = analysis_result.get("category", "").strip() - + if category == "黑话": # 处理黑话 jargon_keyword = analysis_result.get("jargon_keyword", "").strip() if jargon_keyword: from src.jargon.jargon_miner import store_jargon_from_answer + await store_jargon_from_answer(jargon_keyword, answer, chat_id) else: logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...") - + elif category == "人物信息": # 处理人物信息 # person_name = analysis_result.get("person_name", "").strip() @@ -871,28 +864,22 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) -> # else: # logger.warning(f"分析为人物信息但未提取到人物名称或记忆内容,问题: {question[:50]}...") pass # 功能暂时禁用 - + else: logger.info(f"问题和答案类别为'其他',不进行存储,问题: {question[:50]}...") - + except Exception as e: logger.error(f"解析分析结果失败: {e}, 响应: {response[:200]}...") - + except Exception as e: logger.error(f"分析问题和答案时发生异常: {e}") - def _store_thinking_back( - chat_id: str, - question: str, - context: str, - found_answer: bool, - answer: str, - thinking_steps: List[Dict[str, Any]] + chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]] ) -> None: """存储或更新思考过程到数据库(如果已存在则更新,否则创建) - + Args: chat_id: 聊天ID question: 问题 @@ -903,18 +890,15 @@ def _store_thinking_back( """ try: now = time.time() - + # 先查询是否已存在相同chat_id和问题的记录 existing = ( ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) & - (ThinkingBack.question == question) - ) + .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question)) .order_by(ThinkingBack.update_time.desc()) .limit(1) ) - + if existing.exists(): # 更新现有记录 record = existing.get() @@ -935,27 +919,22 @@ def _store_thinking_back( answer=answer, thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), create_time=now, - update_time=now + update_time=now, ) logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...") except Exception as e: logger.error(f"存储思考过程失败: {e}") -async def _process_single_question( - question: str, - chat_id: str, - context: str, - initial_info: str = "" -) -> Optional[str]: +async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]: """处理单个问题的查询(包含缓存检查逻辑) - + Args: question: 要查询的问题 chat_id: 聊天ID context: 上下文信息 initial_info: 初始信息(如概念检索结果),将传递给ReAct Agent - + Returns: Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None """ @@ -978,20 +957,20 @@ async def _process_single_question( logger.info(f"LPMM预查询未命中或未找到信息,问题: {question[:50]}...") except Exception as e: logger.error(f"LPMM预查询失败,问题: {question[:50]}... 错误: {e}") - + # 先检查thinking_back数据库中是否有现成答案 cached_result = _query_thinking_back(chat_id, question) should_requery = False - + if cached_result: cached_found_answer, cached_answer = cached_result - + if cached_found_answer: # found_answer == 1 (True) # found_answer == 1:20%概率重新查询 if random.random() < 0.5: should_requery = True logger.info(f"found_answer=1,触发20%概率重新查询,问题: {question[:50]}...") - + if not should_requery and cached_answer: logger.info(f"从thinking_back缓存中获取答案,问题: {question[:50]}...") return f"问题:{question}\n答案:{cached_answer}" @@ -1002,22 +981,22 @@ async def _process_single_question( # found_answer == 0:不使用缓存,直接重新查询 should_requery = True logger.info(f"thinking_back存在但未找到答案,忽略缓存重新查询,问题: {question[:50]}...") - + # 如果没有缓存答案或需要重新查询,使用ReAct Agent查询 if not cached_result or should_requery: if should_requery: logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...") else: logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...") - + found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( question=question, chat_id=chat_id, max_iterations=global_config.memory.max_agent_iterations, timeout=120.0, - initial_info=question_initial_info + initial_info=question_initial_info, ) - + # 存储到数据库(超时时不存储) if not is_timeout: _store_thinking_back( @@ -1026,16 +1005,16 @@ async def _process_single_question( context=context, found_answer=found_answer, answer=answer, - thinking_steps=thinking_steps + thinking_steps=thinking_steps, ) else: logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...") - + if found_answer and answer: # 创建异步任务分析问题和答案 asyncio.create_task(_analyze_question_answer(question, answer, chat_id)) return f"问题:{question}\n答案:{answer}" - + return None @@ -1048,30 +1027,30 @@ async def build_memory_retrieval_prompt( ) -> str: """构建记忆检索提示 使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案 - + Args: message: 聊天历史记录 sender: 发送者名称 target: 目标消息内容 chat_stream: 聊天流对象 tool_executor: 工具执行器(保留参数以兼容接口) - + Returns: str: 记忆检索结果字符串 """ start_time = time.time() - + logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}") try: time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) bot_name = global_config.bot.nickname chat_id = chat_stream.stream_id - + # 获取最近查询历史(最近1小时内的查询) recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0) if not recent_query_history: recent_query_history = "最近没有查询记录。" - + # 第一步:生成问题 question_prompt = await global_prompt_manager.format_prompt( "memory_retrieval_question_prompt", @@ -1082,25 +1061,25 @@ async def build_memory_retrieval_prompt( sender=sender, target_message=target, ) - + success, response, reasoning_content, model_name = await llm_api.generate_with_model( question_prompt, model_config=model_config.model_task_config.tool_use, request_type="memory.question", ) - + logger.info(f"记忆检索问题生成提示词: {question_prompt}") logger.info(f"记忆检索问题生成响应: {response}") - + if not success: logger.error(f"LLM生成问题失败: {response}") return "" - + # 解析概念列表和问题列表 concepts, questions = _parse_questions_json(response) logger.info(f"解析到 {len(concepts)} 个概念: {concepts}") logger.info(f"解析到 {len(questions)} 个问题: {questions}") - + # 对概念进行jargon检索,作为初始信息 initial_info = "" if concepts: @@ -1111,11 +1090,10 @@ async def build_memory_retrieval_prompt( logger.info(f"概念检索完成,结果: {concept_info[:200]}...") else: logger.info("概念检索未找到任何结果") - - + # 获取缓存的记忆(与question时使用相同的时间窗口和数量限制) cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0) - + if not questions: logger.debug("模型认为不需要检索记忆或解析失败") # 即使没有当次查询,也返回缓存的记忆和概念检索结果 @@ -1124,7 +1102,7 @@ async def build_memory_retrieval_prompt( all_results.append(initial_info.strip()) if cached_memories: all_results.extend(cached_memories) - + if all_results: retrieved_memory = "\n\n".join(all_results) end_time = time.time() @@ -1132,27 +1110,22 @@ async def build_memory_retrieval_prompt( return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" else: return "" - + logger.info(f"解析到 {len(questions)} 个问题: {questions}") - + # 第二步:并行处理所有问题(使用配置的最大迭代次数/120秒超时) max_iterations = global_config.memory.max_agent_iterations logger.info(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations},超时时间: 120秒") - + # 并行处理所有问题,将概念检索结果作为初始信息传递 question_tasks = [ - _process_single_question( - question=question, - chat_id=chat_id, - context=message, - initial_info=initial_info - ) + _process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info) for question in questions ] - + # 并行执行所有查询任务 results = await asyncio.gather(*question_tasks, return_exceptions=True) - + # 收集所有有效结果 all_results = [] current_questions = set() # 用于去重,避免缓存和当次查询重复 @@ -1165,7 +1138,7 @@ async def build_memory_retrieval_prompt( if result.startswith("问题:"): question = result.split("\n")[0].replace("问题:", "").strip() current_questions.add(question) - + # 将缓存的记忆添加到结果中(排除当次查询已包含的问题,避免重复) for cached_memory in cached_memories: if cached_memory.startswith("问题:"): @@ -1174,17 +1147,19 @@ async def build_memory_retrieval_prompt( if question not in current_questions: all_results.append(cached_memory) logger.debug(f"添加缓存记忆: {question[:50]}...") - + end_time = time.time() - + if all_results: retrieved_memory = "\n\n".join(all_results) - logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)") + logger.info( + f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)" + ) return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" else: logger.debug("所有问题均未找到答案,且无缓存记忆") return "" - + except Exception as e: logger.error(f"记忆检索时发生异常: {str(e)}") return "" @@ -1192,10 +1167,10 @@ async def build_memory_retrieval_prompt( def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]: """解析问题JSON,返回概念列表和问题列表 - + Args: response: LLM返回的响应 - + Returns: Tuple[List[str], List[str]]: (概念列表, 问题列表) """ @@ -1203,39 +1178,39 @@ def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]: # 尝试提取JSON(可能包含在```json代码块中) json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: json_str = matches[0] else: # 尝试直接解析整个响应 json_str = response.strip() - + # 修复可能的JSON错误 repaired_json = repair_json(json_str) - + # 解析JSON parsed = json.loads(repaired_json) - + # 只支持新格式:包含concepts和questions的对象 if not isinstance(parsed, dict): logger.warning(f"解析的JSON不是对象格式: {parsed}") return [], [] - + concepts_raw = parsed.get("concepts", []) questions_raw = parsed.get("questions", []) - + # 确保是列表 if not isinstance(concepts_raw, list): concepts_raw = [] if not isinstance(questions_raw, list): questions_raw = [] - + # 确保所有元素都是字符串 concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()] questions = [q for q in questions_raw if isinstance(q, str) and q.strip()] - + return concepts, questions - + except Exception as e: logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") return [], [] diff --git a/src/memory_system/memory_utils.py b/src/memory_system/memory_utils.py index af16456b..bff39f95 100644 --- a/src/memory_system/memory_utils.py +++ b/src/memory_system/memory_utils.py @@ -3,6 +3,7 @@ 记忆系统工具函数 包含模糊查找、相似度计算等工具函数 """ + import json import re from datetime import datetime @@ -14,6 +15,7 @@ from src.common.logger import get_logger logger = get_logger("memory_utils") + def parse_md_json(json_text: str) -> list[str]: """从Markdown格式的内容中提取JSON对象和推理内容""" json_objects = [] @@ -52,14 +54,15 @@ def parse_md_json(json_text: str) -> list[str]: return json_objects, reasoning_content + def calculate_similarity(text1: str, text2: str) -> float: """ 计算两个文本的相似度 - + Args: text1: 第一个文本 text2: 第二个文本 - + Returns: float: 相似度分数 (0-1) """ @@ -67,16 +70,16 @@ def calculate_similarity(text1: str, text2: str) -> float: # 预处理文本 text1 = preprocess_text(text1) text2 = preprocess_text(text2) - + # 使用SequenceMatcher计算相似度 similarity = SequenceMatcher(None, text1, text2).ratio() - + # 如果其中一个文本包含另一个,提高相似度 if text1 in text2 or text2 in text1: similarity = max(similarity, 0.8) - + return similarity - + except Exception as e: logger.error(f"计算相似度时出错: {e}") return 0.0 @@ -85,31 +88,30 @@ def calculate_similarity(text1: str, text2: str) -> float: def preprocess_text(text: str) -> str: """ 预处理文本,提高匹配准确性 - + Args: text: 原始文本 - + Returns: str: 预处理后的文本 """ try: # 转换为小写 text = text.lower() - + # 移除标点符号和特殊字符 - text = re.sub(r'[^\w\s]', '', text) - + text = re.sub(r"[^\w\s]", "", text) + # 移除多余空格 - text = re.sub(r'\s+', ' ', text).strip() - + text = re.sub(r"\s+", " ", text).strip() + return text - + except Exception as e: logger.error(f"预处理文本时出错: {e}") return text - def parse_datetime_to_timestamp(value: str) -> float: """ 接受多种常见格式并转换为时间戳(秒) @@ -143,25 +145,24 @@ def parse_datetime_to_timestamp(value: str) -> float: def parse_time_range(time_range: str) -> Tuple[float, float]: """ 解析时间范围字符串,返回开始和结束时间戳 - + Args: time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" - + Returns: Tuple[float, float]: (开始时间戳, 结束时间戳) """ if " - " not in time_range: raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}") - + parts = time_range.split(" - ", 1) if len(parts) != 2: raise ValueError(f"时间范围格式错误: {time_range}") - + start_str = parts[0].strip() end_str = parts[1].strip() - + start_timestamp = parse_datetime_to_timestamp(start_str) end_timestamp = parse_datetime_to_timestamp(end_str) - - return start_timestamp, end_timestamp + return start_timestamp, end_timestamp diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py index e02875be..4b5c40c0 100644 --- a/src/memory_system/retrieval_tools/__init__.py +++ b/src/memory_system/retrieval_tools/__init__.py @@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge from .query_person_info import register_tool as register_query_person_info from src.config.config import global_config + def init_all_tools(): """初始化并注册所有记忆检索工具""" register_query_jargon() diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 900426ee..407bba05 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -15,13 +15,10 @@ logger = get_logger("memory_retrieval_tools") async def query_chat_history( - chat_id: str, - keyword: Optional[str] = None, - time_range: Optional[str] = None, - fuzzy: bool = True + chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True ) -> str: """根据时间或关键词在chat_history表中查询聊天记录概述 - + Args: chat_id: 聊天ID keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔) @@ -31,7 +28,7 @@ async def query_chat_history( fuzzy: 是否使用模糊匹配模式(默认True) - True: 模糊匹配,只要包含任意一个关键词即匹配(OR关系) - False: 全匹配,必须包含所有关键词才匹配(AND关系) - + Returns: str: 查询结果 """ @@ -39,10 +36,10 @@ async def query_chat_history( # 检查参数 if not keyword and not time_range: return "未指定查询参数(需要提供keyword或time_range之一)" - + # 构建查询条件 query = ChatHistory.select().where(ChatHistory.chat_id == chat_id) - + # 时间过滤条件 if time_range: # 判断是时间点还是时间范围 @@ -50,79 +47,79 @@ async def query_chat_history( # 时间范围:查询与时间范围有交集的记录 start_timestamp, end_timestamp = parse_time_range(time_range) # 交集条件:start_time < end_timestamp AND end_time > start_timestamp - time_filter = ( - (ChatHistory.start_time < end_timestamp) & - (ChatHistory.end_time > start_timestamp) - ) + time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp) else: # 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time) target_timestamp = parse_datetime_to_timestamp(time_range) - time_filter = ( - (ChatHistory.start_time <= target_timestamp) & - (ChatHistory.end_time >= target_timestamp) - ) + time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp) query = query.where(time_filter) - + # 执行查询 records = list(query.order_by(ChatHistory.start_time.desc()).limit(50)) - + # 如果有关键词,进一步过滤 if keyword: # 解析多个关键词(支持空格、逗号等分隔符) keywords_list = parse_keywords_string(keyword) if not keywords_list: keywords_list = [keyword.strip()] if keyword.strip() else [] - + # 转换为小写以便匹配 keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()] - + if not keywords_lower: return "关键词为空" - + filtered_records = [] - + for record in records: # 在theme、keywords、summary、original_text中搜索 theme = (record.theme or "").lower() summary = (record.summary or "").lower() original_text = (record.original_text or "").lower() - + # 解析record中的keywords JSON record_keywords_list = [] if record.keywords: try: - keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords + keywords_data = ( + json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords + ) if isinstance(keywords_data, list): record_keywords_list = [str(k).lower() for k in keywords_data] except (json.JSONDecodeError, TypeError, ValueError): pass - + # 根据匹配模式检查关键词 matched = False if fuzzy: # 模糊匹配:只要包含任意一个关键词即匹配(OR关系) for kw in keywords_lower: - if (kw in theme or - kw in summary or - kw in original_text or - any(kw in k for k in record_keywords_list)): + if ( + kw in theme + or kw in summary + or kw in original_text + or any(kw in k for k in record_keywords_list) + ): matched = True break else: # 全匹配:必须包含所有关键词才匹配(AND关系) matched = True for kw in keywords_lower: - kw_matched = (kw in theme or - kw in summary or - kw in original_text or - any(kw in k for k in record_keywords_list)) + kw_matched = ( + kw in theme + or kw in summary + or kw in original_text + or any(kw in k for k in record_keywords_list) + ) if not kw_matched: matched = False break - + if matched: filtered_records.append(record) - + if not filtered_records: keywords_str = "、".join(keywords_list) match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词" @@ -130,9 +127,9 @@ async def query_chat_history( return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述" else: return f"未找到{match_mode}'{keywords_str}'的聊天记录概述" - + records = filtered_records - + # 如果没有记录(可能是时间范围查询但没有匹配的记录) if not records: if time_range: @@ -148,22 +145,23 @@ async def query_chat_history( record.count = (record.count or 0) + 1 except Exception as update_error: logger.error(f"更新聊天记录概述计数失败: {update_error}") - + # 构建结果文本 results = [] for record in records_to_use: # 最多返回3条记录 result_parts = [] - + # 添加主题 if record.theme: result_parts.append(f"主题:{record.theme}") - + # 添加时间范围 from datetime import datetime + start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S") end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S") result_parts.append(f"时间:{start_str} - {end_str}") - + # 添加概括(优先使用summary,如果没有则使用original_text的前200字符) if record.summary: result_parts.append(f"概括:{record.summary}") @@ -172,18 +170,18 @@ async def query_chat_history( if len(record.original_text) > 200: text_preview += "..." result_parts.append(f"内容:{text_preview}") - + results.append("\n".join(result_parts)) - + if not results: return "未找到相关聊天记录概述" - + response_text = "\n\n---\n\n".join(results) if len(records) > len(records_to_use): omitted_count = len(records) - len(records_to_use) response_text += f"\n\n(还有{omitted_count}条历史记录已省略)" return response_text - + except Exception as e: logger.error(f"查询聊天历史概述失败: {e}") return f"查询失败: {str(e)}" @@ -199,20 +197,20 @@ def register_tool(): "name": "keyword", "type": "string", "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)", - "required": False + "required": False, }, { "name": "time_range", "type": "string", "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)", - "required": False + "required": False, }, { "name": "fuzzy", "type": "boolean", "description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)", - "required": False - } + "required": False, + }, ], - execute_func=query_chat_history + execute_func=query_chat_history, ) diff --git a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py b/src/memory_system/retrieval_tools/query_lpmm_knowledge.py index 20664eea..c1f39270 100644 --- a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py +++ b/src/memory_system/retrieval_tools/query_lpmm_knowledge.py @@ -73,5 +73,3 @@ def register_tool(): ], execute_func=query_lpmm_knowledge, ) - - diff --git a/src/memory_system/retrieval_tools/query_person_info.py b/src/memory_system/retrieval_tools/query_person_info.py index 878daf4d..bc192722 100644 --- a/src/memory_system/retrieval_tools/query_person_info.py +++ b/src/memory_system/retrieval_tools/query_person_info.py @@ -14,23 +14,25 @@ logger = get_logger("memory_retrieval_tools") def _format_group_nick_names(group_nick_name_field) -> str: """格式化群昵称信息 - + Args: group_nick_name_field: 群昵称字段(可能是字符串JSON或None) - + Returns: str: 格式化后的群昵称信息字符串 """ if not group_nick_name_field: return "" - + try: # 解析JSON格式的群昵称列表 - group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field - + group_nick_names_data = ( + json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field + ) + if not isinstance(group_nick_names_data, list) or not group_nick_names_data: return "" - + # 格式化群昵称列表 group_nick_list = [] for item in group_nick_names_data: @@ -41,7 +43,7 @@ def _format_group_nick_names(group_nick_name_field) -> str: elif isinstance(item, str): # 兼容旧格式(如果存在) group_nick_list.append(f" - {item}") - + if group_nick_list: return "群昵称:\n" + "\n".join(group_nick_list) return "" @@ -58,10 +60,10 @@ def _format_group_nick_names(group_nick_name_field) -> str: async def query_person_info(person_name: str) -> str: """根据person_name查询用户信息,使用模糊查询 - + Args: person_name: 用户名称(person_name字段) - + Returns: str: 查询结果,包含用户的所有信息 """ @@ -69,37 +71,35 @@ async def query_person_info(person_name: str) -> str: person_name = str(person_name).strip() if not person_name: return "用户名称为空" - + # 构建查询条件(使用模糊查询) - query = PersonInfo.select().where( - PersonInfo.person_name.contains(person_name) - ) - + query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name)) + # 执行查询 records = list(query.limit(20)) # 最多返回20条记录 - + if not records: return f"未找到模糊匹配'{person_name}'的用户信息" - + # 区分精确匹配和模糊匹配的结果 exact_matches = [] fuzzy_matches = [] - + for record in records: # 检查是否是精确匹配 if record.person_name and record.person_name.strip() == person_name: exact_matches.append(record) else: fuzzy_matches.append(record) - + # 构建结果文本 results = [] - + # 先处理精确匹配的结果 for record in exact_matches: result_parts = [] result_parts.append("【精确匹配】") # 标注为精确匹配 - + # 基本信息 if record.person_name: result_parts.append(f"用户名称:{record.person_name}") @@ -111,19 +111,19 @@ async def query_person_info(person_name: str) -> str: result_parts.append(f"平台:{record.platform}") if record.user_id: result_parts.append(f"平台用户ID:{record.user_id}") - + # 群昵称信息 group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None)) if group_nick_name_str: result_parts.append(group_nick_name_str) - + # 名称设定原因 if record.name_reason: result_parts.append(f"名称设定原因:{record.name_reason}") - + # 认识状态 result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}") - + # 时间信息 if record.know_since: know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S") @@ -133,11 +133,15 @@ async def query_person_info(person_name: str) -> str: result_parts.append(f"最后认识时间:{last_know_str}") if record.know_times: result_parts.append(f"认识次数:{int(record.know_times)}") - + # 记忆点(memory_points) if record.memory_points: try: - memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points + memory_points_data = ( + json.loads(record.memory_points) + if isinstance(record.memory_points, str) + else record.memory_points + ) if isinstance(memory_points_data, list) and memory_points_data: # 解析记忆点格式:category:content:weight memory_list = [] @@ -151,7 +155,7 @@ async def query_person_info(person_name: str) -> str: memory_list.append(f" - [{category}] {content} (权重: {weight})") else: memory_list.append(f" - {memory_point}") - + if memory_list: result_parts.append("记忆点:\n" + "\n".join(memory_list)) except (json.JSONDecodeError, TypeError, ValueError) as e: @@ -161,14 +165,14 @@ async def query_person_info(person_name: str) -> str: if len(str(record.memory_points)) > 200: memory_preview += "..." result_parts.append(f"记忆点(原始数据):{memory_preview}") - + results.append("\n".join(result_parts)) - + # 再处理模糊匹配的结果 for record in fuzzy_matches: result_parts = [] result_parts.append("【模糊匹配】") # 标注为模糊匹配 - + # 基本信息 if record.person_name: result_parts.append(f"用户名称:{record.person_name}") @@ -180,19 +184,19 @@ async def query_person_info(person_name: str) -> str: result_parts.append(f"平台:{record.platform}") if record.user_id: result_parts.append(f"平台用户ID:{record.user_id}") - + # 群昵称信息 group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None)) if group_nick_name_str: result_parts.append(group_nick_name_str) - + # 名称设定原因 if record.name_reason: result_parts.append(f"名称设定原因:{record.name_reason}") - + # 认识状态 result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}") - + # 时间信息 if record.know_since: know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S") @@ -202,11 +206,15 @@ async def query_person_info(person_name: str) -> str: result_parts.append(f"最后认识时间:{last_know_str}") if record.know_times: result_parts.append(f"认识次数:{int(record.know_times)}") - + # 记忆点(memory_points) if record.memory_points: try: - memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points + memory_points_data = ( + json.loads(record.memory_points) + if isinstance(record.memory_points, str) + else record.memory_points + ) if isinstance(memory_points_data, list) and memory_points_data: # 解析记忆点格式:category:content:weight memory_list = [] @@ -220,7 +228,7 @@ async def query_person_info(person_name: str) -> str: memory_list.append(f" - [{category}] {content} (权重: {weight})") else: memory_list.append(f" - {memory_point}") - + if memory_list: result_parts.append("记忆点:\n" + "\n".join(memory_list)) except (json.JSONDecodeError, TypeError, ValueError) as e: @@ -230,20 +238,20 @@ async def query_person_info(person_name: str) -> str: if len(str(record.memory_points)) > 200: memory_preview += "..." result_parts.append(f"记忆点(原始数据):{memory_preview}") - + results.append("\n".join(result_parts)) - + # 组合所有结果 if not results: return f"未找到匹配'{person_name}'的用户信息" - + response_text = "\n\n---\n\n".join(results) - + # 添加统计信息 total_count = len(records) exact_count = len(exact_matches) fuzzy_count = len(fuzzy_matches) - + # 显示精确匹配和模糊匹配的统计 if exact_count > 0 or fuzzy_count > 0: stats_parts = [] @@ -257,13 +265,13 @@ async def query_person_info(person_name: str) -> str: response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}" else: response_text = f"找到用户信息:\n\n{response_text}" - + # 如果结果数量达到限制,添加提示 if total_count >= 20: response_text += "\n\n(已显示前20条结果,可能还有更多匹配记录)" - + return response_text - + except Exception as e: logger.error(f"查询用户信息失败: {e}") return f"查询失败: {str(e)}" @@ -275,13 +283,7 @@ def register_tool(): name="query_person_info", description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等", parameters=[ - { - "name": "person_name", - "type": "string", - "description": "用户名称,用于查询用户信息", - "required": True - } + {"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True} ], - execute_func=query_person_info + execute_func=query_person_info, ) - diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py index 143666ab..1e1fa62b 100644 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -47,10 +47,10 @@ class MemoryRetrievalTool: async def execute(self, **kwargs) -> str: """执行工具""" return await self.execute_func(**kwargs) - + def get_tool_definition(self) -> Dict[str, Any]: """获取工具定义,用于LLM function calling - + Returns: Dict[str, Any]: 工具定义字典,格式与BaseTool一致 格式: {"name": str, "description": str, "parameters": List[Tuple]} @@ -58,14 +58,14 @@ class MemoryRetrievalTool: # 转换参数格式为元组列表,格式与BaseTool一致 # 格式: [("param_name", ToolParamType, "description", required, enum_values)] param_tuples = [] - + for param in self.parameters: param_name = param.get("name", "") param_type_str = param.get("type", "string").lower() param_desc = param.get("description", "") is_required = param.get("required", False) enum_values = param.get("enum", None) - + # 转换类型字符串到ToolParamType type_mapping = { "string": ToolParamType.STRING, @@ -76,18 +76,14 @@ class MemoryRetrievalTool: "bool": ToolParamType.BOOLEAN, } param_type = type_mapping.get(param_type_str, ToolParamType.STRING) - + # 构建参数元组 param_tuple = (param_name, param_type, param_desc, is_required, enum_values) param_tuples.append(param_tuple) - + # 构建工具定义,格式与BaseTool.get_tool_definition()一致 - tool_def = { - "name": self.name, - "description": self.description, - "parameters": param_tuples - } - + tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples} + return tool_def @@ -126,10 +122,10 @@ class MemoryRetrievalToolRegistry: action_types.append("final_answer") action_types.append("no_answer") return " 或 ".join([f'"{at}"' for at in action_types]) - + def get_tool_definitions(self) -> List[Dict[str, Any]]: """获取所有工具的定义列表,用于LLM function calling - + Returns: List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典 """ diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index bbc3fb99..f5cfdcd0 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int: class Person: @classmethod def register_person( - cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None + cls, + platform: str, + user_id: str, + nickname: str, + group_id: Optional[str] = None, + group_nick_name: Optional[str] = None, ): """ 注册新用户的类方法 @@ -727,7 +732,7 @@ person_info_manager = PersonInfoManager() async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None: """将人物信息存入person_info的memory_points - + Args: person_name: 人物名称 memory_content: 记忆内容 @@ -739,13 +744,13 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, if not chat_stream: logger.warning(f"无法获取chat_stream for chat_id: {chat_id}") return - + platform = chat_stream.platform - + # 尝试从person_name查找person_id # 首先尝试通过person_name查找 person_id = get_person_id_by_person_name(person_name) - + if not person_id: # 如果通过person_name找不到,尝试从chat_stream获取user_info if chat_stream.user_info: @@ -754,25 +759,25 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, else: logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") return - + # 创建或获取Person对象 person = Person(person_id=person_id) - + if not person.is_known: logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆") return - + # 确定记忆分类(可以根据memory_content判断,这里使用通用分类) category = "其他" # 默认分类,可以根据需要调整 - + # 记忆点格式:category:content:weight weight = "1.0" # 默认权重 memory_point = f"{category}:{memory_content}:{weight}" - + # 添加到memory_points if not person.memory_points: person.memory_points = [] - + # 检查是否已存在相似的记忆点(避免重复) is_duplicate = False for existing_point in person.memory_points: @@ -781,16 +786,20 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, if len(parts) >= 2: existing_content = parts[1].strip() # 简单相似度检查(如果内容相同或非常相似,则跳过) - if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content: + if ( + existing_content == memory_content + or memory_content in existing_content + or existing_content in memory_content + ): is_duplicate = True break - + if not is_duplicate: person.memory_points.append(memory_point) person.sync_to_database() logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}") else: logger.debug(f"记忆点已存在,跳过: {memory_point}") - + except Exception as e: logger.error(f"存储人物记忆失败: {e}") diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 12c11795..915ed7aa 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -124,7 +124,6 @@ class ToolExecutor: response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( prompt=prompt, tools=tools, raise_when_empty=False ) - # 执行工具调用 tool_results, used_tools = await self.execute_tool_calls(tool_calls) diff --git a/src/webui/config_routes.py b/src/webui/config_routes.py index 03a4643f..c4a4d417 100644 --- a/src/webui/config_routes.py +++ b/src/webui/config_routes.py @@ -51,7 +51,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None: """ 递归合并字典,保留 target 中的注释和格式 将 source 的值更新到 target 中(仅更新已存在的键) - + Args: target: 目标字典(tomlkit 对象,包含注释) source: 源字典(普通 dict 或 list) @@ -59,7 +59,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None: # 如果 source 是列表,直接替换(数组表没有注释保留的意义) if isinstance(source, list): return # 调用者需要直接赋值 - + # 如果都是字典,递归合并 if isinstance(source, dict) and isinstance(target, dict): for key, value in source.items(): diff --git a/src/webui/emoji_routes.py b/src/webui/emoji_routes.py index 18603258..96899bf3 100644 --- a/src/webui/emoji_routes.py +++ b/src/webui/emoji_routes.py @@ -1,4 +1,5 @@ """表情包管理 API 路由""" + from fastapi import APIRouter, HTTPException, Header, Query from fastapi.responses import FileResponse from pydantic import BaseModel @@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"]) class EmojiResponse(BaseModel): """表情包响应""" + id: int full_path: str format: str @@ -35,6 +37,7 @@ class EmojiResponse(BaseModel): class EmojiListResponse(BaseModel): """表情包列表响应""" + success: bool total: int page: int @@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel): class EmojiDetailResponse(BaseModel): """表情包详情响应""" + success: bool data: EmojiResponse class EmojiUpdateRequest(BaseModel): """表情包更新请求""" + description: Optional[str] = None is_registered: Optional[bool] = None is_banned: Optional[bool] = None @@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel): class EmojiUpdateResponse(BaseModel): """表情包更新响应""" + success: bool message: str data: Optional[EmojiResponse] = None @@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel): class EmojiDeleteResponse(BaseModel): """表情包删除响应""" + success: bool message: str @@ -73,13 +80,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool: """验证认证 Token""" if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="Token 无效或已过期") - + return True @@ -120,11 +127,11 @@ async def get_emoji_list( is_registered: Optional[bool] = Query(None, description="是否已注册筛选"), is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"), format: Optional[str] = Query(None, description="格式筛选"), - authorization: Optional[str] = Header(None) + authorization: Optional[str] = Header(None), ): """ 获取表情包列表 - + Args: page: 页码 (从 1 开始) page_size: 每页数量 (1-100) @@ -133,61 +140,51 @@ async def get_emoji_list( is_banned: 是否被禁用筛选 format: 格式筛选 authorization: Authorization header - + Returns: 表情包列表 """ try: verify_auth_token(authorization) - + # 构建查询 query = Emoji.select() - + # 搜索过滤 if search: - query = query.where( - (Emoji.description.contains(search)) | - (Emoji.emoji_hash.contains(search)) - ) - + query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search))) + # 注册状态过滤 if is_registered is not None: query = query.where(Emoji.is_registered == is_registered) - + # 禁用状态过滤 if is_banned is not None: query = query.where(Emoji.is_banned == is_banned) - + # 格式过滤 if format: query = query.where(Emoji.format == format) - + # 排序:使用次数倒序,然后按记录时间倒序 from peewee import Case + query = query.order_by( - Emoji.usage_count.desc(), - Case(None, [(Emoji.record_time.is_null(), 1)], 0), - Emoji.record_time.desc() + Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc() ) - + # 获取总数 total = query.count() - + # 分页 offset = (page - 1) * page_size emojis = query.offset(offset).limit(page_size) - + # 转换为响应对象 data = [emoji_to_response(emoji) for emoji in emojis] - - return EmojiListResponse( - success=True, - total=total, - page=page, - page_size=page_size, - data=data - ) - + + return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data) + except HTTPException: raise except Exception as e: @@ -196,33 +193,27 @@ async def get_emoji_list( @router.get("/{emoji_id}", response_model=EmojiDetailResponse) -async def get_emoji_detail( - emoji_id: int, - authorization: Optional[str] = Header(None) -): +async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)): """ 获取表情包详细信息 - + Args: emoji_id: 表情包ID authorization: Authorization header - + Returns: 表情包详细信息 """ try: verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - - return EmojiDetailResponse( - success=True, - data=emoji_to_response(emoji) - ) - + + return EmojiDetailResponse(success=True, data=emoji_to_response(emoji)) + except HTTPException: raise except Exception as e: @@ -231,61 +222,55 @@ async def get_emoji_detail( @router.patch("/{emoji_id}", response_model=EmojiUpdateResponse) -async def update_emoji( - emoji_id: int, - request: EmojiUpdateRequest, - authorization: Optional[str] = Header(None) -): +async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)): """ 增量更新表情包(只更新提供的字段) - + Args: emoji_id: 表情包ID request: 更新请求(只包含需要更新的字段) authorization: Authorization header - + Returns: 更新结果 """ try: verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - + # 只更新提供的字段 update_data = request.model_dump(exclude_unset=True) - + if not update_data: raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") - + # 处理情感标签(转换为 JSON) - if 'emotion' in update_data: - if update_data['emotion'] is None: - update_data['emotion'] = None + if "emotion" in update_data: + if update_data["emotion"] is None: + update_data["emotion"] = None else: - update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False) - + update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False) + # 如果注册状态从 False 变为 True,记录注册时间 - if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered: - update_data['register_time'] = time.time() - + if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered: + update_data["register_time"] = time.time() + # 执行更新 for field, value in update_data.items(): setattr(emoji, field, value) - + emoji.save() - + logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}") - + return EmojiUpdateResponse( - success=True, - message=f"成功更新 {len(update_data)} 个字段", - data=emoji_to_response(emoji) + success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji) ) - + except HTTPException: raise except Exception as e: @@ -294,41 +279,35 @@ async def update_emoji( @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) -async def delete_emoji( - emoji_id: int, - authorization: Optional[str] = Header(None) -): +async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): """ 删除表情包 - + Args: emoji_id: 表情包ID authorization: Authorization header - + Returns: 删除结果 """ try: verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - + # 记录删除信息 emoji_hash = emoji.emoji_hash - + # 执行删除 emoji.delete_instance() - + logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}") - - return EmojiDeleteResponse( - success=True, - message=f"成功删除表情包: {emoji_hash}" - ) - + + return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}") + except HTTPException: raise except Exception as e: @@ -337,31 +316,29 @@ async def delete_emoji( @router.get("/stats/summary") -async def get_emoji_stats( - authorization: Optional[str] = Header(None) -): +async def get_emoji_stats(authorization: Optional[str] = Header(None)): """ 获取表情包统计数据 - + Args: authorization: Authorization header - + Returns: 统计数据 """ try: verify_auth_token(authorization) - + total = Emoji.select().count() registered = Emoji.select().where(Emoji.is_registered).count() banned = Emoji.select().where(Emoji.is_banned).count() - + # 按格式统计 formats = {} for emoji in Emoji.select(Emoji.format): fmt = emoji.format formats[fmt] = formats.get(fmt, 0) + 1 - + # 获取最常用的表情包(前10) top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10) top_used_list = [ @@ -369,11 +346,11 @@ async def get_emoji_stats( "id": emoji.id, "emoji_hash": emoji.emoji_hash, "description": emoji.description, - "usage_count": emoji.usage_count + "usage_count": emoji.usage_count, } for emoji in top_used ] - + return { "success": True, "data": { @@ -382,10 +359,10 @@ async def get_emoji_stats( "banned": banned, "unregistered": total - registered, "formats": formats, - "top_used": top_used_list - } + "top_used": top_used_list, + }, } - + except HTTPException: raise except Exception as e: @@ -394,47 +371,40 @@ async def get_emoji_stats( @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) -async def register_emoji( - emoji_id: int, - authorization: Optional[str] = Header(None) -): +async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): """ 注册表情包(快捷操作) - + Args: emoji_id: 表情包ID authorization: Authorization header - + Returns: 更新结果 """ try: verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - + if emoji.is_registered: raise HTTPException(status_code=400, detail="该表情包已经注册") - + if emoji.is_banned: raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册") - + # 注册表情包 emoji.is_registered = True emoji.register_time = time.time() emoji.save() - + logger.info(f"表情包已注册: ID={emoji_id}") - - return EmojiUpdateResponse( - success=True, - message="表情包注册成功", - data=emoji_to_response(emoji) - ) - + + return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji)) + except HTTPException: raise except Exception as e: @@ -443,41 +413,34 @@ async def register_emoji( @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) -async def ban_emoji( - emoji_id: int, - authorization: Optional[str] = Header(None) -): +async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)): """ 禁用表情包(快捷操作) - + Args: emoji_id: 表情包ID authorization: Authorization header - + Returns: 更新结果 """ try: verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - + # 禁用表情包(同时取消注册) emoji.is_banned = True emoji.is_registered = False emoji.save() - + logger.info(f"表情包已禁用: ID={emoji_id}") - - return EmojiUpdateResponse( - success=True, - message="表情包禁用成功", - data=emoji_to_response(emoji) - ) - + + return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji)) + except HTTPException: raise except Exception as e: @@ -489,16 +452,16 @@ async def ban_emoji( async def get_emoji_thumbnail( emoji_id: int, token: Optional[str] = Query(None, description="访问令牌"), - authorization: Optional[str] = Header(None) + authorization: Optional[str] = Header(None), ): """ 获取表情包缩略图 - + Args: emoji_id: 表情包ID token: 访问令牌(通过 query parameter) authorization: Authorization header - + Returns: 表情包图片文件 """ @@ -511,37 +474,32 @@ async def get_emoji_thumbnail( else: # 如果没有 query token,则验证 Authorization header verify_auth_token(authorization) - + emoji = Emoji.get_or_none(Emoji.id == emoji_id) - + if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - + # 检查文件是否存在 if not os.path.exists(emoji.full_path): raise HTTPException(status_code=404, detail="表情包文件不存在") - + # 根据格式设置 MIME 类型 mime_types = { - 'png': 'image/png', - 'jpg': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'gif': 'image/gif', - 'webp': 'image/webp', - 'bmp': 'image/bmp' + "png": "image/png", + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "gif": "image/gif", + "webp": "image/webp", + "bmp": "image/bmp", } - - media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream') - - return FileResponse( - path=emoji.full_path, - media_type=media_type, - filename=f"{emoji.emoji_hash}.{emoji.format}" - ) - + + media_type = mime_types.get(emoji.format.lower(), "application/octet-stream") + + return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}") + except HTTPException: raise except Exception as e: logger.exception(f"获取表情包缩略图失败: {e}") raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e - diff --git a/src/webui/expression_routes.py b/src/webui/expression_routes.py index de2594ee..aa9261d2 100644 --- a/src/webui/expression_routes.py +++ b/src/webui/expression_routes.py @@ -1,4 +1,5 @@ """表达方式管理 API 路由""" + from fastapi import APIRouter, HTTPException, Header, Query from pydantic import BaseModel from typing import Optional, List @@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"]) class ExpressionResponse(BaseModel): """表达方式响应""" + id: int situation: str style: str @@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel): class ExpressionListResponse(BaseModel): """表达方式列表响应""" + success: bool total: int page: int @@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel): class ExpressionDetailResponse(BaseModel): """表达方式详情响应""" + success: bool data: ExpressionResponse class ExpressionCreateRequest(BaseModel): """表达方式创建请求""" + situation: str style: str context: Optional[str] = None @@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel): class ExpressionUpdateRequest(BaseModel): """表达方式更新请求""" + situation: Optional[str] = None style: Optional[str] = None context: Optional[str] = None @@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel): class ExpressionUpdateResponse(BaseModel): """表达方式更新响应""" + success: bool message: str data: Optional[ExpressionResponse] = None @@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel): class ExpressionDeleteResponse(BaseModel): """表达方式删除响应""" + success: bool message: str class ExpressionCreateResponse(BaseModel): """表达方式创建响应""" + success: bool message: str data: ExpressionResponse @@ -82,13 +91,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool: """验证认证 Token""" if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="Token 无效或已过期") - + return True @@ -112,64 +121,58 @@ async def get_expression_list( page_size: int = Query(20, ge=1, le=100, description="每页数量"), search: Optional[str] = Query(None, description="搜索关键词"), chat_id: Optional[str] = Query(None, description="聊天ID筛选"), - authorization: Optional[str] = Header(None) + authorization: Optional[str] = Header(None), ): """ 获取表达方式列表 - + Args: page: 页码 (从 1 开始) page_size: 每页数量 (1-100) search: 搜索关键词 (匹配 situation, style, context) chat_id: 聊天ID筛选 authorization: Authorization header - + Returns: 表达方式列表 """ try: verify_auth_token(authorization) - + # 构建查询 query = Expression.select() - + # 搜索过滤 if search: query = query.where( - (Expression.situation.contains(search)) | - (Expression.style.contains(search)) | - (Expression.context.contains(search)) + (Expression.situation.contains(search)) + | (Expression.style.contains(search)) + | (Expression.context.contains(search)) ) - + # 聊天ID过滤 if chat_id: query = query.where(Expression.chat_id == chat_id) - + # 排序:最后活跃时间倒序(NULL 值放在最后) from peewee import Case + query = query.order_by( - Case(None, [(Expression.last_active_time.is_null(), 1)], 0), - Expression.last_active_time.desc() + Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc() ) - + # 获取总数 total = query.count() - + # 分页 offset = (page - 1) * page_size expressions = query.offset(offset).limit(page_size) - + # 转换为响应对象 data = [expression_to_response(expr) for expr in expressions] - - return ExpressionListResponse( - success=True, - total=total, - page=page, - page_size=page_size, - data=data - ) - + + return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data) + except HTTPException: raise except Exception as e: @@ -178,33 +181,27 @@ async def get_expression_list( @router.get("/{expression_id}", response_model=ExpressionDetailResponse) -async def get_expression_detail( - expression_id: int, - authorization: Optional[str] = Header(None) -): +async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)): """ 获取表达方式详细信息 - + Args: expression_id: 表达方式ID authorization: Authorization header - + Returns: 表达方式详细信息 """ try: verify_auth_token(authorization) - + expression = Expression.get_or_none(Expression.id == expression_id) - + if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") - - return ExpressionDetailResponse( - success=True, - data=expression_to_response(expression) - ) - + + return ExpressionDetailResponse(success=True, data=expression_to_response(expression)) + except HTTPException: raise except Exception as e: @@ -213,25 +210,22 @@ async def get_expression_detail( @router.post("/", response_model=ExpressionCreateResponse) -async def create_expression( - request: ExpressionCreateRequest, - authorization: Optional[str] = Header(None) -): +async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)): """ 创建新的表达方式 - + Args: request: 创建请求 authorization: Authorization header - + Returns: 创建结果 """ try: verify_auth_token(authorization) - + current_time = time.time() - + # 创建表达方式 expression = Expression.create( situation=request.situation, @@ -242,15 +236,13 @@ async def create_expression( last_active_time=current_time, create_date=current_time, ) - + logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}") - + return ExpressionCreateResponse( - success=True, - message="表达方式创建成功", - data=expression_to_response(expression) + success=True, message="表达方式创建成功", data=expression_to_response(expression) ) - + except HTTPException: raise except Exception as e: @@ -260,52 +252,48 @@ async def create_expression( @router.patch("/{expression_id}", response_model=ExpressionUpdateResponse) async def update_expression( - expression_id: int, - request: ExpressionUpdateRequest, - authorization: Optional[str] = Header(None) + expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None) ): """ 增量更新表达方式(只更新提供的字段) - + Args: expression_id: 表达方式ID request: 更新请求(只包含需要更新的字段) authorization: Authorization header - + Returns: 更新结果 """ try: verify_auth_token(authorization) - + expression = Expression.get_or_none(Expression.id == expression_id) - + if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") - + # 只更新提供的字段 update_data = request.model_dump(exclude_unset=True) - + if not update_data: raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") - + # 更新最后活跃时间 - update_data['last_active_time'] = time.time() - + update_data["last_active_time"] = time.time() + # 执行更新 for field, value in update_data.items(): setattr(expression, field, value) - + expression.save() - + logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}") - + return ExpressionUpdateResponse( - success=True, - message=f"成功更新 {len(update_data)} 个字段", - data=expression_to_response(expression) + success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression) ) - + except HTTPException: raise except Exception as e: @@ -314,41 +302,35 @@ async def update_expression( @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) -async def delete_expression( - expression_id: int, - authorization: Optional[str] = Header(None) -): +async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)): """ 删除表达方式 - + Args: expression_id: 表达方式ID authorization: Authorization header - + Returns: 删除结果 """ try: verify_auth_token(authorization) - + expression = Expression.get_or_none(Expression.id == expression_id) - + if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") - + # 记录删除信息 situation = expression.situation - + # 执行删除 expression.delete_instance() - + logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}") - - return ExpressionDeleteResponse( - success=True, - message=f"成功删除表达方式: {situation}" - ) - + + return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}") + except HTTPException: raise except Exception as e: @@ -357,46 +339,45 @@ async def delete_expression( @router.get("/stats/summary") -async def get_expression_stats( - authorization: Optional[str] = Header(None) -): +async def get_expression_stats(authorization: Optional[str] = Header(None)): """ 获取表达方式统计数据 - + Args: authorization: Authorization header - + Returns: 统计数据 """ try: verify_auth_token(authorization) - + total = Expression.select().count() - + # 按 chat_id 统计 chat_stats = {} for expr in Expression.select(Expression.chat_id): chat_id = expr.chat_id chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1 - + # 获取最近创建的记录数(7天内) seven_days_ago = time.time() - (7 * 24 * 60 * 60) - recent = Expression.select().where( - (Expression.create_date.is_null(False)) & - (Expression.create_date >= seven_days_ago) - ).count() - + recent = ( + Expression.select() + .where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago)) + .count() + ) + return { "success": True, "data": { "total": total, "recent_7days": recent, "chat_count": len(chat_stats), - "top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]) - } + "top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]), + }, } - + except HTTPException: raise except Exception as e: diff --git a/src/webui/git_mirror_service.py b/src/webui/git_mirror_service.py index 02645f70..df00cde9 100644 --- a/src/webui/git_mirror_service.py +++ b/src/webui/git_mirror_service.py @@ -1,4 +1,5 @@ """Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取""" + from typing import Optional, List, Dict, Any from enum import Enum import httpx @@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror") # 导入进度更新函数(避免循环导入) _update_progress = None + def set_update_progress_callback(callback): """设置进度更新回调函数""" global _update_progress @@ -23,6 +25,7 @@ def set_update_progress_callback(callback): class MirrorType(str, Enum): """镜像源类型""" + GH_PROXY = "gh-proxy" # gh-proxy 主节点 HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点 CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点 @@ -34,10 +37,10 @@ class MirrorType(str, Enum): class GitMirrorConfig: """Git 镜像源配置管理""" - + # 配置文件路径 CONFIG_FILE = Path("data/webui.json") - + # 默认镜像源配置 DEFAULT_MIRRORS = [ { @@ -47,7 +50,7 @@ class GitMirrorConfig: "clone_prefix": "https://gh-proxy.org/https://github.com", "enabled": True, "priority": 1, - "created_at": None + "created_at": None, }, { "id": "hk-gh-proxy", @@ -56,7 +59,7 @@ class GitMirrorConfig: "clone_prefix": "https://hk.gh-proxy.org/https://github.com", "enabled": True, "priority": 2, - "created_at": None + "created_at": None, }, { "id": "cdn-gh-proxy", @@ -65,7 +68,7 @@ class GitMirrorConfig: "clone_prefix": "https://cdn.gh-proxy.org/https://github.com", "enabled": True, "priority": 3, - "created_at": None + "created_at": None, }, { "id": "edgeone-gh-proxy", @@ -74,7 +77,7 @@ class GitMirrorConfig: "clone_prefix": "https://edgeone.gh-proxy.org/https://github.com", "enabled": True, "priority": 4, - "created_at": None + "created_at": None, }, { "id": "meyzh-github", @@ -83,7 +86,7 @@ class GitMirrorConfig: "clone_prefix": "https://meyzh.github.io/https://github.com", "enabled": True, "priority": 5, - "created_at": None + "created_at": None, }, { "id": "github", @@ -92,23 +95,23 @@ class GitMirrorConfig: "clone_prefix": "https://github.com", "enabled": True, "priority": 999, - "created_at": None - } + "created_at": None, + }, ] - + def __init__(self): """初始化配置管理器""" self.config_file = self.CONFIG_FILE self.mirrors: List[Dict[str, Any]] = [] self._load_config() - + def _load_config(self) -> None: """加载配置文件""" try: if self.config_file.exists(): - with open(self.config_file, 'r', encoding='utf-8') as f: + with open(self.config_file, "r", encoding="utf-8") as f: data = json.load(f) - + # 检查是否有镜像源配置 if "git_mirrors" not in data or not data["git_mirrors"]: logger.info("配置文件中未找到镜像源配置,使用默认配置") @@ -122,59 +125,59 @@ class GitMirrorConfig: except Exception as e: logger.error(f"加载配置文件失败: {e}") self._init_default_mirrors() - + def _init_default_mirrors(self) -> None: """初始化默认镜像源""" current_time = datetime.now().isoformat() self.mirrors = [] - + for mirror in self.DEFAULT_MIRRORS: mirror_copy = mirror.copy() mirror_copy["created_at"] = current_time self.mirrors.append(mirror_copy) - + self._save_config() logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源") - + def _save_config(self) -> None: """保存配置到文件""" try: # 确保目录存在 self.config_file.parent.mkdir(parents=True, exist_ok=True) - + # 读取现有配置 existing_data = {} if self.config_file.exists(): - with open(self.config_file, 'r', encoding='utf-8') as f: + with open(self.config_file, "r", encoding="utf-8") as f: existing_data = json.load(f) - + # 更新镜像源配置 existing_data["git_mirrors"] = self.mirrors - + # 写入文件 - with open(self.config_file, 'w', encoding='utf-8') as f: + with open(self.config_file, "w", encoding="utf-8") as f: json.dump(existing_data, f, indent=2, ensure_ascii=False) - + logger.debug(f"配置已保存到 {self.config_file}") except Exception as e: logger.error(f"保存配置文件失败: {e}") - + def get_all_mirrors(self) -> List[Dict[str, Any]]: """获取所有镜像源""" return self.mirrors.copy() - + def get_enabled_mirrors(self) -> List[Dict[str, Any]]: """获取所有启用的镜像源,按优先级排序""" enabled = [m for m in self.mirrors if m.get("enabled", False)] return sorted(enabled, key=lambda x: x.get("priority", 999)) - + def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]: """根据 ID 获取镜像源""" for mirror in self.mirrors: if mirror.get("id") == mirror_id: return mirror.copy() return None - + def add_mirror( self, mirror_id: str, @@ -182,26 +185,26 @@ class GitMirrorConfig: raw_prefix: str, clone_prefix: str, enabled: bool = True, - priority: Optional[int] = None + priority: Optional[int] = None, ) -> Dict[str, Any]: """ 添加新的镜像源 - + Returns: 添加的镜像源配置 - + Raises: ValueError: 如果镜像源 ID 已存在 """ # 检查 ID 是否已存在 if self.get_mirror_by_id(mirror_id): raise ValueError(f"镜像源 ID 已存在: {mirror_id}") - + # 如果未指定优先级,使用最大优先级 + 1 if priority is None: max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0) priority = max_priority + 1 - + new_mirror = { "id": mirror_id, "name": name, @@ -209,15 +212,15 @@ class GitMirrorConfig: "clone_prefix": clone_prefix, "enabled": enabled, "priority": priority, - "created_at": datetime.now().isoformat() + "created_at": datetime.now().isoformat(), } - + self.mirrors.append(new_mirror) self._save_config() - + logger.info(f"已添加镜像源: {mirror_id} - {name}") return new_mirror.copy() - + def update_mirror( self, mirror_id: str, @@ -225,11 +228,11 @@ class GitMirrorConfig: raw_prefix: Optional[str] = None, clone_prefix: Optional[str] = None, enabled: Optional[bool] = None, - priority: Optional[int] = None + priority: Optional[int] = None, ) -> Optional[Dict[str, Any]]: """ 更新镜像源配置 - + Returns: 更新后的镜像源配置,如果不存在则返回 None """ @@ -245,19 +248,19 @@ class GitMirrorConfig: mirror["enabled"] = enabled if priority is not None: mirror["priority"] = priority - + mirror["updated_at"] = datetime.now().isoformat() self._save_config() - + logger.info(f"已更新镜像源: {mirror_id}") return mirror.copy() - + return None - + def delete_mirror(self, mirror_id: str) -> bool: """ 删除镜像源 - + Returns: True 如果删除成功,False 如果镜像源不存在 """ @@ -267,9 +270,9 @@ class GitMirrorConfig: self._save_config() logger.info(f"已删除镜像源: {mirror_id}") return True - + return False - + def get_default_priority_list(self) -> List[str]: """获取默认优先级列表(仅启用的镜像源 ID)""" enabled = self.get_enabled_mirrors() @@ -278,16 +281,11 @@ class GitMirrorConfig: class GitMirrorService: """Git 镜像源服务""" - - def __init__( - self, - max_retries: int = 3, - timeout: int = 30, - config: Optional[GitMirrorConfig] = None - ): + + def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None): """ 初始化 Git 镜像源服务 - + Args: max_retries: 最大重试次数 timeout: 请求超时时间(秒) @@ -297,16 +295,16 @@ class GitMirrorService: self.timeout = timeout self.config = config or GitMirrorConfig() logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源") - + def get_mirror_config(self) -> GitMirrorConfig: """获取镜像源配置管理器""" return self.config - + @staticmethod def check_git_installed() -> Dict[str, Any]: """ 检查本机是否安装了 Git - + Returns: Dict 包含: - installed: bool - 是否已安装 Git @@ -316,54 +314,33 @@ class GitMirrorService: """ import subprocess import shutil - + try: # 查找 git 可执行文件路径 git_path = shutil.which("git") - + if not git_path: logger.warning("未找到 Git 可执行文件") - return { - "installed": False, - "error": "系统中未找到 Git,请先安装 Git" - } - + return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"} + # 获取 Git 版本 - result = subprocess.run( - ["git", "--version"], - capture_output=True, - text=True, - timeout=5 - ) - + result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5) + if result.returncode == 0: version = result.stdout.strip() logger.info(f"检测到 Git: {version} at {git_path}") - return { - "installed": True, - "version": version, - "path": git_path - } + return {"installed": True, "version": version, "path": git_path} else: logger.warning(f"Git 命令执行失败: {result.stderr}") - return { - "installed": False, - "error": f"Git 命令执行失败: {result.stderr}" - } - + return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"} + except subprocess.TimeoutExpired: logger.error("Git 版本检测超时") - return { - "installed": False, - "error": "Git 版本检测超时" - } + return {"installed": False, "error": "Git 版本检测超时"} except Exception as e: logger.error(f"检测 Git 时发生错误: {e}") - return { - "installed": False, - "error": f"检测 Git 时发生错误: {str(e)}" - } - + return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"} + async def fetch_raw_file( self, owner: str, @@ -371,11 +348,11 @@ class GitMirrorService: branch: str, file_path: str, mirror_id: Optional[str] = None, - custom_url: Optional[str] = None + custom_url: Optional[str] = None, ) -> Dict[str, Any]: """ 获取 GitHub 仓库的 Raw 文件内容 - + Args: owner: 仓库所有者 repo: 仓库名称 @@ -383,7 +360,7 @@ class GitMirrorService: file_path: 文件路径 mirror_id: 指定的镜像源 ID custom_url: 自定义完整 URL(如果提供,将忽略其他参数) - + Returns: Dict 包含: - success: bool - 是否成功 @@ -393,29 +370,24 @@ class GitMirrorService: - attempts: int - 尝试次数 """ logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}") - + if custom_url: # 使用自定义 URL return await self._fetch_with_url(custom_url, "custom") - + # 确定要使用的镜像源列表 if mirror_id: # 使用指定的镜像源 mirror = self.config.get_mirror_by_id(mirror_id) if not mirror: - return { - "success": False, - "error": f"未找到镜像源: {mirror_id}", - "mirror_used": None, - "attempts": 0 - } + return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0} mirrors_to_try = [mirror] else: # 使用所有启用的镜像源 mirrors_to_try = self.config.get_enabled_mirrors() - + total_mirrors = len(mirrors_to_try) - + # 依次尝试每个镜像源 for index, mirror in enumerate(mirrors_to_try, 1): # 推送进度:正在尝试第 N 个镜像源 @@ -427,15 +399,13 @@ class GitMirrorService: progress=progress, message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}", total_plugins=0, - loaded_plugins=0 + loaded_plugins=0, ) except Exception as e: logger.warning(f"推送进度失败: {e}") - - result = await self._fetch_raw_from_mirror( - owner, repo, branch, file_path, mirror - ) - + + result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror) + if result["success"]: # 成功,推送进度 if _update_progress: @@ -445,15 +415,15 @@ class GitMirrorService: progress=70, message=f"成功从 {mirror['name']} 获取数据", total_plugins=0, - loaded_plugins=0 + loaded_plugins=0, ) except Exception as e: logger.warning(f"推送进度失败: {e}") return result - + # 失败,记录日志并推送失败信息 logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}") - + if _update_progress and index < total_mirrors: try: await _update_progress( @@ -461,39 +431,29 @@ class GitMirrorService: progress=30 + int(index / total_mirrors * 40), message=f"镜像源 {mirror['name']} 失败,尝试下一个...", total_plugins=0, - loaded_plugins=0 + loaded_plugins=0, ) except Exception as e: logger.warning(f"推送进度失败: {e}") - + # 所有镜像源都失败 - return { - "success": False, - "error": "所有镜像源均失败", - "mirror_used": None, - "attempts": len(mirrors_to_try) - } - + return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)} + async def _fetch_raw_from_mirror( - self, - owner: str, - repo: str, - branch: str, - file_path: str, - mirror: Dict[str, Any] + self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any] ) -> Dict[str, Any]: """从指定镜像源获取文件""" # 构建 URL raw_prefix = mirror["raw_prefix"] url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}" - + return await self._fetch_with_url(url, mirror["id"]) - + async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]: """使用指定 URL 获取文件,支持重试""" attempts = 0 last_error = None - + for attempt in range(self.max_retries): attempts += 1 try: @@ -501,14 +461,14 @@ class GitMirrorService: async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.get(url) response.raise_for_status() - + logger.info(f"成功获取文件: {url}") return { "success": True, "data": response.text, "mirror_used": mirror_type, "attempts": attempts, - "url": url + "url": url, } except httpx.HTTPStatusError as e: last_error = f"HTTP {e.response.status_code}: {e}" @@ -519,15 +479,9 @@ class GitMirrorService: except Exception as e: last_error = f"未知错误: {e}" logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}") - - return { - "success": False, - "error": last_error, - "mirror_used": mirror_type, - "attempts": attempts, - "url": url - } - + + return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url} + async def clone_repository( self, owner: str, @@ -536,11 +490,11 @@ class GitMirrorService: branch: Optional[str] = None, mirror_id: Optional[str] = None, custom_url: Optional[str] = None, - depth: Optional[int] = None + depth: Optional[int] = None, ) -> Dict[str, Any]: """ 克隆 GitHub 仓库 - + Args: owner: 仓库所有者 repo: 仓库名称 @@ -549,7 +503,7 @@ class GitMirrorService: mirror_id: 指定的镜像源 ID custom_url: 自定义克隆 URL depth: 克隆深度(浅克隆) - + Returns: Dict 包含: - success: bool - 是否成功 @@ -559,44 +513,32 @@ class GitMirrorService: - attempts: int - 尝试次数 """ logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}") - + if custom_url: # 使用自定义 URL return await self._clone_with_url(custom_url, target_path, branch, depth, "custom") - + # 确定要使用的镜像源列表 if mirror_id: # 使用指定的镜像源 mirror = self.config.get_mirror_by_id(mirror_id) if not mirror: - return { - "success": False, - "error": f"未找到镜像源: {mirror_id}", - "mirror_used": None, - "attempts": 0 - } + return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0} mirrors_to_try = [mirror] else: # 使用所有启用的镜像源 mirrors_to_try = self.config.get_enabled_mirrors() - + # 依次尝试每个镜像源 for mirror in mirrors_to_try: - result = await self._clone_from_mirror( - owner, repo, target_path, branch, depth, mirror - ) + result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror) if result["success"]: return result logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}") - + # 所有镜像源都失败 - return { - "success": False, - "error": "所有镜像源克隆均失败", - "mirror_used": None, - "attempts": len(mirrors_to_try) - } - + return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)} + async def _clone_from_mirror( self, owner: str, @@ -604,52 +546,47 @@ class GitMirrorService: target_path: Path, branch: Optional[str], depth: Optional[int], - mirror: Dict[str, Any] + mirror: Dict[str, Any], ) -> Dict[str, Any]: """从指定镜像源克隆仓库""" # 构建克隆 URL clone_prefix = mirror["clone_prefix"] url = f"{clone_prefix}/{owner}/{repo}.git" - + return await self._clone_with_url(url, target_path, branch, depth, mirror["id"]) - + async def _clone_with_url( - self, - url: str, - target_path: Path, - branch: Optional[str], - depth: Optional[int], - mirror_type: str + self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str ) -> Dict[str, Any]: """使用指定 URL 克隆仓库,支持重试""" attempts = 0 last_error = None - + for attempt in range(self.max_retries): attempts += 1 - + try: # 确保目标路径不存在 if target_path.exists(): logger.warning(f"目标路径已存在,删除: {target_path}") shutil.rmtree(target_path, ignore_errors=True) - + # 构建 git clone 命令 cmd = ["git", "clone"] - + # 添加分支参数 if branch: cmd.extend(["-b", branch]) - + # 添加深度参数(浅克隆) if depth: cmd.extend(["--depth", str(depth)]) - + # 添加 URL 和目标路径 cmd.extend([url, str(target_path)]) - + logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}") - + # 推送进度 if _update_progress: try: @@ -657,24 +594,24 @@ class GitMirrorService: stage="loading", progress=20 + attempt * 10, message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...", - operation="install" + operation="install", ) except Exception as e: logger.warning(f"推送进度失败: {e}") - + # 执行 git clone(在线程池中运行以避免阻塞) loop = asyncio.get_event_loop() - + def run_git_clone(): return subprocess.run( cmd, capture_output=True, text=True, - timeout=300 # 5分钟超时 + timeout=300, # 5分钟超时 ) - + process = await loop.run_in_executor(None, run_git_clone) - + if process.returncode == 0: logger.info(f"成功克隆仓库: {url} -> {target_path}") return { @@ -683,40 +620,34 @@ class GitMirrorService: "mirror_used": mirror_type, "attempts": attempts, "url": url, - "branch": branch or "default" + "branch": branch or "default", } else: last_error = f"Git 克隆失败: {process.stderr}" logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}") - + except subprocess.TimeoutExpired: last_error = "克隆超时(超过 5 分钟)" logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})") - + # 清理可能的部分克隆 if target_path.exists(): shutil.rmtree(target_path, ignore_errors=True) - + except FileNotFoundError: last_error = "Git 未安装或不在 PATH 中" logger.error(f"Git 未找到: {last_error}") break # Git 不存在,不需要重试 - + except Exception as e: last_error = f"未知错误: {e}" logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}") - + # 清理可能的部分克隆 if target_path.exists(): shutil.rmtree(target_path, ignore_errors=True) - - return { - "success": False, - "error": last_error, - "mirror_used": mirror_type, - "attempts": attempts, - "url": url - } + + return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url} # 全局服务实例 diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index d8ef65aa..e0e0a9a1 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -1,4 +1,5 @@ """WebSocket 日志推送模块""" + from fastapi import APIRouter, WebSocket, WebSocketDisconnect from typing import Set import json @@ -14,30 +15,30 @@ active_connections: Set[WebSocket] = set() def load_recent_logs(limit: int = 100) -> list[dict]: """从日志文件中加载最近的日志 - + Args: limit: 返回的最大日志条数 - + Returns: 日志列表 """ logs = [] log_dir = Path("logs") - + if not log_dir.exists(): return logs - + # 获取所有日志文件,按修改时间排序 log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True) - + # 用于生成唯一 ID 的计数器 log_counter = 0 - + # 从最新的文件开始读取 for log_file in log_files: if len(logs) >= limit: break - + try: with open(log_file, "r", encoding="utf-8") as f: lines = f.readlines() @@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]: log_entry = json.loads(line.strip()) # 转换为前端期望的格式 # 使用时间戳 + 计数器生成唯一 ID - timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "") + timestamp_id = ( + log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "") + ) formatted_log = { "id": f"{timestamp_id}_{log_counter}", "timestamp": log_entry.get("timestamp", ""), @@ -64,7 +67,7 @@ def load_recent_logs(limit: int = 100) -> list[dict]: except Exception as e: logger.error(f"读取日志文件失败 {log_file}: {e}") continue - + # 反转列表,使其按时间顺序排列(旧到新) return list(reversed(logs)) @@ -72,35 +75,35 @@ def load_recent_logs(limit: int = 100) -> list[dict]: @router.websocket("/ws/logs") async def websocket_logs(websocket: WebSocket): """WebSocket 日志推送端点 - + 客户端连接后会持续接收服务器端的日志消息 """ await websocket.accept() active_connections.add(websocket) logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}") - + # 连接建立后,立即发送历史日志 try: recent_logs = load_recent_logs(limit=100) logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端") - + for log_entry in recent_logs: await websocket.send_text(json.dumps(log_entry, ensure_ascii=False)) except Exception as e: logger.error(f"发送历史日志失败: {e}") - + try: # 保持连接,等待客户端消息或断开 while True: # 接收客户端消息(用于心跳或控制指令) data = await websocket.receive_text() - + # 可以处理客户端的控制消息,例如: # - "ping" -> 心跳检测 # - {"filter": "ERROR"} -> 设置日志级别过滤 if data == "ping": await websocket.send_text("pong") - + except WebSocketDisconnect: active_connections.discard(websocket) logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}") @@ -111,19 +114,19 @@ async def websocket_logs(websocket: WebSocket): async def broadcast_log(log_data: dict): """广播日志到所有连接的 WebSocket 客户端 - + Args: log_data: 日志数据字典 """ if not active_connections: return - + # 格式化为 JSON message = json.dumps(log_data, ensure_ascii=False) - + # 记录需要断开的连接 disconnected = set() - + # 广播到所有客户端 for connection in active_connections: try: @@ -131,7 +134,7 @@ async def broadcast_log(log_data: dict): except Exception: # 发送失败,标记为断开 disconnected.add(connection) - + # 清理断开的连接 if disconnected: active_connections.difference_update(disconnected) diff --git a/src/webui/manager.py b/src/webui/manager.py index 3919df17..4dc472e2 100644 --- a/src/webui/manager.py +++ b/src/webui/manager.py @@ -1,4 +1,5 @@ """WebUI 管理器 - 处理开发/生产环境的 WebUI 启动""" + import os from pathlib import Path from src.common.logger import get_logger @@ -10,10 +11,10 @@ logger = get_logger("webui") def setup_webui(mode: str = "production") -> bool: """ 设置 WebUI - + Args: mode: 运行模式,"development" 或 "production" - + Returns: bool: 是否成功设置 """ @@ -22,7 +23,7 @@ def setup_webui(mode: str = "production") -> bool: current_token = token_manager.get_token() logger.info(f"🔑 WebUI Access Token: {current_token}") logger.info("💡 请使用此 Token 登录 WebUI") - + if mode == "development": return setup_dev_mode() else: @@ -33,12 +34,12 @@ def setup_dev_mode() -> bool: """设置开发模式 - 仅启用 CORS,前端自行启动""" from src.common.server import get_global_server from .logs_ws import router as logs_router - + # 注册 WebSocket 日志路由(开发模式也需要) server = get_global_server() server.register_router(logs_router) logger.info("✅ WebSocket 日志推送路由已注册") - + logger.info("📝 WebUI 开发模式已启用") logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev") logger.info("💡 前端将运行在 http://localhost:7999") @@ -52,33 +53,33 @@ def setup_production_mode() -> bool: from starlette.responses import FileResponse from .logs_ws import router as logs_router import mimetypes - + # 确保正确的 MIME 类型映射 mimetypes.init() - mimetypes.add_type('application/javascript', '.js') - mimetypes.add_type('application/javascript', '.mjs') - mimetypes.add_type('text/css', '.css') - mimetypes.add_type('application/json', '.json') - + mimetypes.add_type("application/javascript", ".js") + mimetypes.add_type("application/javascript", ".mjs") + mimetypes.add_type("text/css", ".css") + mimetypes.add_type("application/json", ".json") + server = get_global_server() - + # 注册 WebSocket 日志路由 server.register_router(logs_router) logger.info("✅ WebSocket 日志推送路由已注册") - + base_dir = Path(__file__).parent.parent.parent static_path = base_dir / "webui" / "dist" - + if not static_path.exists(): logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}") logger.warning("💡 请先构建前端: cd webui && npm run build") return False - + if not (static_path / "index.html").exists(): logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}") logger.warning("💡 请确认前端已正确构建") return False - + # 处理 SPA 路由 @server.app.get("/{full_path:path}") async def serve_spa(full_path: str): @@ -86,23 +87,23 @@ def setup_production_mode() -> bool: # API 路由不处理 if full_path.startswith("api/"): return None - + # 检查文件是否存在 file_path = static_path / full_path if file_path.is_file(): # 自动检测 MIME 类型 media_type = mimetypes.guess_type(str(file_path))[0] return FileResponse(file_path, media_type=media_type) - + # 返回 index.html(SPA 路由) return FileResponse(static_path / "index.html", media_type="text/html") - + host = os.getenv("HOST", "127.0.0.1") port = os.getenv("PORT", "8000") logger.info("✅ WebUI 生产模式已挂载") logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI") return True - + except Exception as e: logger.error(f"挂载 WebUI 静态文件失败: {e}") return False diff --git a/src/webui/person_routes.py b/src/webui/person_routes.py index a5488d49..24855aba 100644 --- a/src/webui/person_routes.py +++ b/src/webui/person_routes.py @@ -1,4 +1,5 @@ """人物信息管理 API 路由""" + from fastapi import APIRouter, HTTPException, Header, Query from pydantic import BaseModel from typing import Optional, List, Dict @@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"]) class PersonInfoResponse(BaseModel): """人物信息响应""" + id: int is_known: bool person_id: str @@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel): class PersonListResponse(BaseModel): """人物列表响应""" + success: bool total: int page: int @@ -42,12 +45,14 @@ class PersonListResponse(BaseModel): class PersonDetailResponse(BaseModel): """人物详情响应""" + success: bool data: PersonInfoResponse class PersonUpdateRequest(BaseModel): """人物信息更新请求""" + person_name: Optional[str] = None name_reason: Optional[str] = None nickname: Optional[str] = None @@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel): class PersonUpdateResponse(BaseModel): """人物信息更新响应""" + success: bool message: str data: Optional[PersonInfoResponse] = None @@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel): class PersonDeleteResponse(BaseModel): """人物删除响应""" + success: bool message: str @@ -72,13 +79,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool: """验证认证 Token""" if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="Token 无效或已过期") - + return True @@ -118,11 +125,11 @@ async def get_person_list( search: Optional[str] = Query(None, description="搜索关键词"), is_known: Optional[bool] = Query(None, description="是否已认识筛选"), platform: Optional[str] = Query(None, description="平台筛选"), - authorization: Optional[str] = Header(None) + authorization: Optional[str] = Header(None), ): """ 获取人物信息列表 - + Args: page: 页码 (从 1 开始) page_size: 每页数量 (1-100) @@ -130,58 +137,50 @@ async def get_person_list( is_known: 是否已认识筛选 platform: 平台筛选 authorization: Authorization header - + Returns: 人物信息列表 """ try: verify_auth_token(authorization) - + # 构建查询 query = PersonInfo.select() - + # 搜索过滤 if search: query = query.where( - (PersonInfo.person_name.contains(search)) | - (PersonInfo.nickname.contains(search)) | - (PersonInfo.user_id.contains(search)) + (PersonInfo.person_name.contains(search)) + | (PersonInfo.nickname.contains(search)) + | (PersonInfo.user_id.contains(search)) ) - + # 已认识状态过滤 if is_known is not None: query = query.where(PersonInfo.is_known == is_known) - + # 平台过滤 if platform: query = query.where(PersonInfo.platform == platform) - + # 排序:最后更新时间倒序(NULL 值放在最后) # Peewee 不支持 nulls_last,使用 CASE WHEN 来实现 from peewee import Case - query = query.order_by( - Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), - PersonInfo.last_know.desc() - ) - + + query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc()) + # 获取总数 total = query.count() - + # 分页 offset = (page - 1) * page_size persons = query.offset(offset).limit(page_size) - + # 转换为响应对象 data = [person_to_response(person) for person in persons] - - return PersonListResponse( - success=True, - total=total, - page=page, - page_size=page_size, - data=data - ) - + + return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data) + except HTTPException: raise except Exception as e: @@ -190,33 +189,27 @@ async def get_person_list( @router.get("/{person_id}", response_model=PersonDetailResponse) -async def get_person_detail( - person_id: str, - authorization: Optional[str] = Header(None) -): +async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)): """ 获取人物详细信息 - + Args: person_id: 人物唯一 ID authorization: Authorization header - + Returns: 人物详细信息 """ try: verify_auth_token(authorization) - + person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - + if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") - - return PersonDetailResponse( - success=True, - data=person_to_response(person) - ) - + + return PersonDetailResponse(success=True, data=person_to_response(person)) + except HTTPException: raise except Exception as e: @@ -225,53 +218,47 @@ async def get_person_detail( @router.patch("/{person_id}", response_model=PersonUpdateResponse) -async def update_person( - person_id: str, - request: PersonUpdateRequest, - authorization: Optional[str] = Header(None) -): +async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)): """ 增量更新人物信息(只更新提供的字段) - + Args: person_id: 人物唯一 ID request: 更新请求(只包含需要更新的字段) authorization: Authorization header - + Returns: 更新结果 """ try: verify_auth_token(authorization) - + person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - + if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") - + # 只更新提供的字段 update_data = request.model_dump(exclude_unset=True) - + if not update_data: raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") - + # 更新最后修改时间 - update_data['last_know'] = time.time() - + update_data["last_know"] = time.time() + # 执行更新 for field, value in update_data.items(): setattr(person, field, value) - + person.save() - + logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}") - + return PersonUpdateResponse( - success=True, - message=f"成功更新 {len(update_data)} 个字段", - data=person_to_response(person) + success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person) ) - + except HTTPException: raise except Exception as e: @@ -280,41 +267,35 @@ async def update_person( @router.delete("/{person_id}", response_model=PersonDeleteResponse) -async def delete_person( - person_id: str, - authorization: Optional[str] = Header(None) -): +async def delete_person(person_id: str, authorization: Optional[str] = Header(None)): """ 删除人物信息 - + Args: person_id: 人物唯一 ID authorization: Authorization header - + Returns: 删除结果 """ try: verify_auth_token(authorization) - + person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - + if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") - + # 记录删除信息 person_name = person.person_name or person.nickname or person.user_id - + # 执行删除 person.delete_instance() - + logger.info(f"人物信息已删除: {person_id} ({person_name})") - - return PersonDeleteResponse( - success=True, - message=f"成功删除人物信息: {person_name}" - ) - + + return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}") + except HTTPException: raise except Exception as e: @@ -323,41 +304,31 @@ async def delete_person( @router.get("/stats/summary") -async def get_person_stats( - authorization: Optional[str] = Header(None) -): +async def get_person_stats(authorization: Optional[str] = Header(None)): """ 获取人物信息统计数据 - + Args: authorization: Authorization header - + Returns: 统计数据 """ try: verify_auth_token(authorization) - + total = PersonInfo.select().count() known = PersonInfo.select().where(PersonInfo.is_known).count() unknown = total - known - + # 按平台统计 platforms = {} for person in PersonInfo.select(PersonInfo.platform): platform = person.platform platforms[platform] = platforms.get(platform, 0) + 1 - - return { - "success": True, - "data": { - "total": total, - "known": known, - "unknown": unknown, - "platforms": platforms - } - } - + + return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}} + except HTTPException: raise except Exception as e: diff --git a/src/webui/plugin_progress_ws.py b/src/webui/plugin_progress_ws.py index 927dbb13..7e0fb647 100644 --- a/src/webui/plugin_progress_ws.py +++ b/src/webui/plugin_progress_ws.py @@ -1,4 +1,5 @@ """WebSocket 插件加载进度推送模块""" + from fastapi import APIRouter, WebSocket, WebSocketDisconnect from typing import Set, Dict, Any import json @@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = { "error": None, "plugin_id": None, # 当前操作的插件 ID "total_plugins": 0, - "loaded_plugins": 0 + "loaded_plugins": 0, } @@ -30,20 +31,20 @@ async def broadcast_progress(progress_data: Dict[str, Any]): """广播进度更新到所有连接的客户端""" global current_progress current_progress = progress_data.copy() - + if not active_connections: return - + message = json.dumps(progress_data, ensure_ascii=False) disconnected = set() - + for websocket in active_connections: try: await websocket.send_text(message) except Exception as e: logger.error(f"发送进度更新失败: {e}") disconnected.add(websocket) - + # 移除断开的连接 for websocket in disconnected: active_connections.discard(websocket) @@ -57,10 +58,10 @@ async def update_progress( error: str = None, plugin_id: str = None, total_plugins: int = 0, - loaded_plugins: int = 0 + loaded_plugins: int = 0, ): """更新并广播进度 - + Args: stage: 阶段 (idle, loading, success, error) progress: 进度百分比 (0-100) @@ -80,9 +81,9 @@ async def update_progress( "plugin_id": plugin_id, "total_plugins": total_plugins, "loaded_plugins": loaded_plugins, - "timestamp": asyncio.get_event_loop().time() + "timestamp": asyncio.get_event_loop().time(), } - + await broadcast_progress(progress_data) logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}") @@ -90,30 +91,30 @@ async def update_progress( @router.websocket("/ws/plugin-progress") async def websocket_plugin_progress(websocket: WebSocket): """WebSocket 插件加载进度推送端点 - + 客户端连接后会立即收到当前进度状态 """ await websocket.accept() active_connections.add(websocket) logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}") - + try: # 发送当前进度状态 await websocket.send_text(json.dumps(current_progress, ensure_ascii=False)) - + # 保持连接并处理客户端消息 while True: try: data = await websocket.receive_text() - + # 处理客户端心跳 if data == "ping": await websocket.send_text("pong") - + except Exception as e: logger.error(f"处理客户端消息时出错: {e}") break - + except WebSocketDisconnect: active_connections.discard(websocket) logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}") diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index 5054a391..cb559fb7 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -21,22 +21,22 @@ set_update_progress_callback(update_progress) def parse_version(version_str: str) -> tuple[int, int, int]: """ 解析版本号字符串 - + 支持格式: - 0.11.2 -> (0, 11, 2) - 0.11.2.snapshot.2 -> (0, 11, 2) - + Returns: (major, minor, patch) 三元组 """ # 移除 snapshot 等后缀 - base_version = version_str.split('.snapshot')[0].split('.dev')[0].split('.alpha')[0].split('.beta')[0] - - parts = base_version.split('.') + base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0] + + parts = base_version.split(".") if len(parts) < 3: # 补齐到 3 位 - parts.extend(['0'] * (3 - len(parts))) - + parts.extend(["0"] * (3 - len(parts))) + try: major = int(parts[0]) minor = int(parts[1]) @@ -49,8 +49,10 @@ def parse_version(version_str: str) -> tuple[int, int, int]: # ============ 请求/响应模型 ============ + class FetchRawFileRequest(BaseModel): """获取 Raw 文件请求""" + owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") repo: str = Field(..., description="仓库名称", example="plugin-repo") branch: str = Field(..., description="分支名称", example="main") @@ -61,6 +63,7 @@ class FetchRawFileRequest(BaseModel): class FetchRawFileResponse(BaseModel): """获取 Raw 文件响应""" + success: bool = Field(..., description="是否成功") data: Optional[str] = Field(None, description="文件内容") error: Optional[str] = Field(None, description="错误信息") @@ -71,6 +74,7 @@ class FetchRawFileResponse(BaseModel): class CloneRepositoryRequest(BaseModel): """克隆仓库请求""" + owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") repo: str = Field(..., description="仓库名称", example="plugin-repo") target_path: str = Field(..., description="目标路径(相对于插件目录)") @@ -82,6 +86,7 @@ class CloneRepositoryRequest(BaseModel): class CloneRepositoryResponse(BaseModel): """克隆仓库响应""" + success: bool = Field(..., description="是否成功") path: Optional[str] = Field(None, description="克隆路径") error: Optional[str] = Field(None, description="错误信息") @@ -93,6 +98,7 @@ class CloneRepositoryResponse(BaseModel): class MirrorConfigResponse(BaseModel): """镜像源配置响应""" + id: str = Field(..., description="镜像源 ID") name: str = Field(..., description="镜像源名称") raw_prefix: str = Field(..., description="Raw 文件前缀") @@ -103,12 +109,14 @@ class MirrorConfigResponse(BaseModel): class AvailableMirrorsResponse(BaseModel): """可用镜像源列表响应""" + mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表") default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)") class AddMirrorRequest(BaseModel): """添加镜像源请求""" + id: str = Field(..., description="镜像源 ID", example="custom-mirror") name: str = Field(..., description="镜像源名称", example="自定义镜像源") raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw") @@ -119,6 +127,7 @@ class AddMirrorRequest(BaseModel): class UpdateMirrorRequest(BaseModel): """更新镜像源请求""" + name: Optional[str] = Field(None, description="镜像源名称") raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀") clone_prefix: Optional[str] = Field(None, description="克隆前缀") @@ -128,6 +137,7 @@ class UpdateMirrorRequest(BaseModel): class GitStatusResponse(BaseModel): """Git 安装状态响应""" + installed: bool = Field(..., description="是否已安装 Git") version: Optional[str] = Field(None, description="Git 版本号") path: Optional[str] = Field(None, description="Git 可执行文件路径") @@ -136,6 +146,7 @@ class GitStatusResponse(BaseModel): class InstallPluginRequest(BaseModel): """安装插件请求""" + plugin_id: str = Field(..., description="插件 ID") repository_url: str = Field(..., description="插件仓库 URL") branch: Optional[str] = Field("main", description="分支名称") @@ -144,6 +155,7 @@ class InstallPluginRequest(BaseModel): class VersionResponse(BaseModel): """麦麦版本响应""" + version: str = Field(..., description="麦麦版本号") version_major: int = Field(..., description="主版本号") version_minor: int = Field(..., description="次版本号") @@ -152,11 +164,13 @@ class VersionResponse(BaseModel): class UninstallPluginRequest(BaseModel): """卸载插件请求""" + plugin_id: str = Field(..., description="插件 ID") class UpdatePluginRequest(BaseModel): """更新插件请求""" + plugin_id: str = Field(..., description="插件 ID") repository_url: str = Field(..., description="插件仓库 URL") branch: Optional[str] = Field("main", description="分支名称") @@ -165,40 +179,34 @@ class UpdatePluginRequest(BaseModel): # ============ API 路由 ============ + @router.get("/version", response_model=VersionResponse) async def get_maimai_version() -> VersionResponse: """ 获取麦麦版本信息 - + 此接口无需认证,用于前端检查插件兼容性 """ major, minor, patch = parse_version(MMC_VERSION) - - return VersionResponse( - version=MMC_VERSION, - version_major=major, - version_minor=minor, - version_patch=patch - ) + + return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch) @router.get("/git-status", response_model=GitStatusResponse) async def check_git_status() -> GitStatusResponse: """ 检查本机 Git 安装状态 - + 此接口无需认证,用于前端快速检测是否可以使用插件安装功能 """ service = get_git_mirror_service() result = service.check_git_installed() - + return GitStatusResponse(**result) @router.get("/mirrors", response_model=AvailableMirrorsResponse) -async def get_available_mirrors( - authorization: Optional[str] = Header(None) -) -> AvailableMirrorsResponse: +async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: """ 获取所有可用的镜像源配置 """ @@ -207,10 +215,10 @@ async def get_available_mirrors( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + service = get_git_mirror_service() config = service.get_mirror_config() - + all_mirrors = config.get_all_mirrors() mirrors = [ MirrorConfigResponse( @@ -219,22 +227,16 @@ async def get_available_mirrors( raw_prefix=m["raw_prefix"], clone_prefix=m["clone_prefix"], enabled=m["enabled"], - priority=m["priority"] + priority=m["priority"], ) for m in all_mirrors ] - - return AvailableMirrorsResponse( - mirrors=mirrors, - default_priority=config.get_default_priority_list() - ) + + return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list()) @router.post("/mirrors", response_model=MirrorConfigResponse) -async def add_mirror( - request: AddMirrorRequest, - authorization: Optional[str] = Header(None) -) -> MirrorConfigResponse: +async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: """ 添加新的镜像源 """ @@ -243,27 +245,27 @@ async def add_mirror( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + try: service = get_git_mirror_service() config = service.get_mirror_config() - + mirror = config.add_mirror( mirror_id=request.id, name=request.name, raw_prefix=request.raw_prefix, clone_prefix=request.clone_prefix, enabled=request.enabled, - priority=request.priority + priority=request.priority, ) - + return MirrorConfigResponse( id=mirror["id"], name=mirror["name"], raw_prefix=mirror["raw_prefix"], clone_prefix=mirror["clone_prefix"], enabled=mirror["enabled"], - priority=mirror["priority"] + priority=mirror["priority"], ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e @@ -274,9 +276,7 @@ async def add_mirror( @router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) async def update_mirror( - mirror_id: str, - request: UpdateMirrorRequest, - authorization: Optional[str] = Header(None) + mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None) ) -> MirrorConfigResponse: """ 更新镜像源配置 @@ -286,30 +286,30 @@ async def update_mirror( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + try: service = get_git_mirror_service() config = service.get_mirror_config() - + mirror = config.update_mirror( mirror_id=mirror_id, name=request.name, raw_prefix=request.raw_prefix, clone_prefix=request.clone_prefix, enabled=request.enabled, - priority=request.priority + priority=request.priority, ) - + if not mirror: raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") - + return MirrorConfigResponse( id=mirror["id"], name=mirror["name"], raw_prefix=mirror["raw_prefix"], clone_prefix=mirror["clone_prefix"], enabled=mirror["enabled"], - priority=mirror["priority"] + priority=mirror["priority"], ) except HTTPException: raise @@ -319,10 +319,7 @@ async def update_mirror( @router.delete("/mirrors/{mirror_id}") -async def delete_mirror( - mirror_id: str, - authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 删除镜像源 """ @@ -331,57 +328,53 @@ async def delete_mirror( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + service = get_git_mirror_service() config = service.get_mirror_config() - + success = config.delete_mirror(mirror_id) - + if not success: raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") - - return { - "success": True, - "message": f"已删除镜像源: {mirror_id}" - } + + return {"success": True, "message": f"已删除镜像源: {mirror_id}"} @router.post("/fetch-raw", response_model=FetchRawFileResponse) async def fetch_raw_file( - request: FetchRawFileRequest, - authorization: Optional[str] = Header(None) + request: FetchRawFileRequest, authorization: Optional[str] = Header(None) ) -> FetchRawFileResponse: """ 获取 GitHub 仓库的 Raw 文件内容 - + 支持多镜像源自动切换和错误重试 - + 注意:此接口可公开访问,用于获取插件仓库等公开资源 """ # Token 验证(可选,用于日志记录) token = authorization.replace("Bearer ", "") if authorization else None token_manager = get_token_manager() is_authenticated = token and token_manager.verify_token(token) - + # 对于公开仓库的访问,不强制要求认证 # 只在日志中记录是否认证 logger.info( f"收到获取 Raw 文件请求 (认证: {is_authenticated}): " f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}" ) - + # 发送开始加载进度 await update_progress( stage="loading", progress=10, message=f"正在获取插件列表: {request.file_path}", total_plugins=0, - loaded_plugins=0 + loaded_plugins=0, ) - + try: service = get_git_mirror_service() - + # git_mirror_service 会自动推送 30%-70% 的详细镜像源尝试进度 result = await service.fetch_raw_file( owner=request.owner, @@ -389,69 +382,56 @@ async def fetch_raw_file( branch=request.branch, file_path=request.file_path, mirror_id=request.mirror_id, - custom_url=request.custom_url + custom_url=request.custom_url, ) - + if result.get("success"): # 更新进度:成功获取 await update_progress( - stage="loading", - progress=70, - message="正在解析插件数据...", - total_plugins=0, - loaded_plugins=0 + stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0 ) - + # 尝试解析插件数量 try: import json + data = json.loads(result.get("data", "[]")) total = len(data) if isinstance(data, list) else 0 - + # 发送成功状态 await update_progress( stage="success", progress=100, message=f"成功加载 {total} 个插件", total_plugins=total, - loaded_plugins=total + loaded_plugins=total, ) except Exception: # 如果解析失败,仍然发送成功状态 await update_progress( - stage="success", - progress=100, - message="加载完成", - total_plugins=0, - loaded_plugins=0 + stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0 ) - + return FetchRawFileResponse(**result) - + except Exception as e: logger.error(f"获取 Raw 文件失败: {e}") - + # 发送错误进度 await update_progress( - stage="error", - progress=0, - message="加载失败", - error=str(e), - total_plugins=0, - loaded_plugins=0 + stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0 ) - + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/clone", response_model=CloneRepositoryResponse) async def clone_repository( - request: CloneRepositoryRequest, - authorization: Optional[str] = Header(None) + request: CloneRepositoryRequest, authorization: Optional[str] = Header(None) ) -> CloneRepositoryResponse: """ 克隆 GitHub 仓库到本地 - + 支持多镜像源自动切换和错误重试 """ # Token 验证 @@ -459,17 +439,15 @@ async def clone_repository( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info( - f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}" - ) - + + logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}") + try: # TODO: 验证 target_path 的安全性,防止路径遍历攻击 # TODO: 确定实际的插件目录基路径 base_plugin_path = Path("./plugins") # 临时路径 target_path = base_plugin_path / request.target_path - + service = get_git_mirror_service() result = await service.clone_repository( owner=request.owner, @@ -478,24 +456,21 @@ async def clone_repository( branch=request.branch, mirror_id=request.mirror_id, custom_url=request.custom_url, - depth=request.depth + depth=request.depth, ) - + return CloneRepositoryResponse(**result) - + except Exception as e: logger.error(f"克隆仓库失败: {e}") raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/install") -async def install_plugin( - request: InstallPluginRequest, - authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 安装插件 - + 从 Git 仓库克隆插件到本地插件目录 """ # Token 验证 @@ -503,9 +478,9 @@ async def install_plugin( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + logger.info(f"收到安装插件请求: {request.plugin_id}") - + try: # 推送进度:开始安装 await update_progress( @@ -513,80 +488,75 @@ async def install_plugin( progress=5, message=f"开始安装插件: {request.plugin_id}", operation="install", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 1. 解析仓库 URL # repository_url 格式: https://github.com/owner/repo - repo_url = request.repository_url.rstrip('/') - if repo_url.endswith('.git'): + repo_url = request.repository_url.rstrip("/") + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.split('/') + + parts = repo_url.split("/") if len(parts) < 2: raise HTTPException(status_code=400, detail="无效的仓库 URL") - + owner = parts[-2] repo = parts[-1] - + await update_progress( stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 2. 确定插件安装路径 plugins_dir = Path("plugins") plugins_dir.mkdir(exist_ok=True) - + target_path = plugins_dir / request.plugin_id - + # 检查插件是否已安装 if target_path.exists(): await update_progress( stage="error", progress=0, - message=f"插件已存在", + message="插件已存在", operation="install", plugin_id=request.plugin_id, - error="插件已安装,请先卸载" + error="插件已安装,请先卸载", ) raise HTTPException(status_code=400, detail="插件已安装") - + await update_progress( stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 3. 克隆仓库(这里会自动推送 20%-80% 的进度) service = get_git_mirror_service() - + # 如果是 GitHub 仓库,使用镜像源 - if 'github.com' in repo_url: + if "github.com" in repo_url: result = await service.clone_repository( owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, - depth=1 # 浅克隆,节省时间和空间 + depth=1, # 浅克隆,节省时间和空间 ) else: # 自定义仓库,直接使用 URL result = await service.clone_repository( - owner=owner, - repo=repo, - target_path=target_path, - branch=request.branch, - custom_url=repo_url, - depth=1 + owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1 ) - + if not result.get("success"): error_msg = result.get("error", "克隆失败") await update_progress( @@ -595,113 +565,107 @@ async def install_plugin( message="克隆仓库失败", operation="install", plugin_id=request.plugin_id, - error=error_msg + error=error_msg, ) raise HTTPException(status_code=500, detail=error_msg) - + # 4. 验证插件完整性 await update_progress( - stage="loading", - progress=85, - message="验证插件文件...", - operation="install", - plugin_id=request.plugin_id + stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id ) - + manifest_path = target_path / "_manifest.json" if not manifest_path.exists(): # 清理失败的安装 import shutil + shutil.rmtree(target_path, ignore_errors=True) - + await update_progress( stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=request.plugin_id, - error="无效的插件格式" + error="无效的插件格式", ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") - + # 5. 读取并验证 manifest await update_progress( - stage="loading", - progress=90, - message="读取插件配置...", - operation="install", - plugin_id=request.plugin_id + stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id ) - + try: import json as json_module - with open(manifest_path, 'r', encoding='utf-8') as f: + + with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) - + # 基本验证 - required_fields = ['manifest_version', 'name', 'version', 'author'] + required_fields = ["manifest_version", "name", "version", "author"] for field in required_fields: if field not in manifest: raise ValueError(f"缺少必需字段: {field}") - + except Exception as e: # 清理失败的安装 import shutil + shutil.rmtree(target_path, ignore_errors=True) - + await update_progress( stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=request.plugin_id, - error=str(e) + error=str(e), ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e - + # 6. 安装成功 await update_progress( stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + return { "success": True, "message": "插件安装成功", "plugin_id": request.plugin_id, - "plugin_name": manifest['name'], - "version": manifest['version'], - "path": str(target_path) + "plugin_name": manifest["name"], + "version": manifest["version"], + "path": str(target_path), } - + except HTTPException: raise except Exception as e: logger.error(f"安装插件失败: {e}", exc_info=True) - + await update_progress( stage="error", progress=0, message="安装失败", operation="install", plugin_id=request.plugin_id, - error=str(e) + error=str(e), ) - + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/uninstall") async def uninstall_plugin( - request: UninstallPluginRequest, - authorization: Optional[str] = Header(None) + request: UninstallPluginRequest, authorization: Optional[str] = Header(None) ) -> Dict[str, Any]: """ 卸载插件 - + 删除插件目录及其所有文件 """ # Token 验证 @@ -709,9 +673,9 @@ async def uninstall_plugin( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + logger.info(f"收到卸载插件请求: {request.plugin_id}") - + try: # 推送进度:开始卸载 await update_progress( @@ -719,13 +683,13 @@ async def uninstall_plugin( progress=10, message=f"开始卸载插件: {request.plugin_id}", operation="uninstall", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 1. 检查插件是否存在 plugins_dir = Path("plugins") plugin_path = plugins_dir / request.plugin_id - + if not plugin_path.exists(): await update_progress( stage="error", @@ -733,107 +697,101 @@ async def uninstall_plugin( message="插件不存在", operation="uninstall", plugin_id=request.plugin_id, - error="插件未安装或已被删除" + error="插件未安装或已被删除", ) raise HTTPException(status_code=404, detail="插件未安装") - + await update_progress( stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 2. 读取插件信息(用于日志) manifest_path = plugin_path / "_manifest.json" plugin_name = request.plugin_id - + if manifest_path.exists(): try: import json as json_module - with open(manifest_path, 'r', encoding='utf-8') as f: + + with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) plugin_name = manifest.get("name", request.plugin_id) except Exception: pass # 如果读取失败,使用插件 ID 作为名称 - + await update_progress( stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 3. 删除插件目录 import shutil import stat - + def remove_readonly(func, path, _): """清除只读属性并删除文件""" import os + os.chmod(path, stat.S_IWRITE) func(path) - + shutil.rmtree(plugin_path, onerror=remove_readonly) - + logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})") - + # 4. 推送成功状态 await update_progress( stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - - return { - "success": True, - "message": "插件卸载成功", - "plugin_id": request.plugin_id, - "plugin_name": plugin_name - } - + + return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name} + except HTTPException: raise except PermissionError as e: logger.error(f"卸载插件失败(权限错误): {e}") - + await update_progress( stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=request.plugin_id, - error="权限不足,无法删除插件文件" + error="权限不足,无法删除插件文件", ) - + raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e except Exception as e: logger.error(f"卸载插件失败: {e}", exc_info=True) - + await update_progress( stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=request.plugin_id, - error=str(e) + error=str(e), ) - + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/update") -async def update_plugin( - request: UpdatePluginRequest, - authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 更新插件 - + 删除旧版本,重新克隆新版本 """ # Token 验证 @@ -841,9 +799,9 @@ async def update_plugin( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + logger.info(f"收到更新插件请求: {request.plugin_id}") - + try: # 推送进度:开始更新 await update_progress( @@ -851,13 +809,13 @@ async def update_plugin( progress=5, message=f"开始更新插件: {request.plugin_id}", operation="update", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 1. 检查插件是否已安装 plugins_dir = Path("plugins") plugin_path = plugins_dir / request.plugin_id - + if not plugin_path.exists(): await update_progress( stage="error", @@ -865,97 +823,90 @@ async def update_plugin( message="插件不存在", operation="update", plugin_id=request.plugin_id, - error="插件未安装,请先安装" + error="插件未安装,请先安装", ) raise HTTPException(status_code=404, detail="插件未安装") - + # 2. 读取旧版本信息 manifest_path = plugin_path / "_manifest.json" old_version = "unknown" plugin_name = request.plugin_id - + if manifest_path.exists(): try: import json as json_module - with open(manifest_path, 'r', encoding='utf-8') as f: + + with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) old_version = manifest.get("version", "unknown") - plugin_name = manifest.get("name", request.plugin_id) + _plugin_name = manifest.get("name", request.plugin_id) except Exception: pass - + await update_progress( stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + # 3. 删除旧版本 await update_progress( - stage="loading", - progress=20, - message="正在删除旧版本...", - operation="update", - plugin_id=request.plugin_id + stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id ) - + import shutil import stat - + def remove_readonly(func, path, _): """清除只读属性并删除文件""" import os + os.chmod(path, stat.S_IWRITE) func(path) - + shutil.rmtree(plugin_path, onerror=remove_readonly) - + logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}") - + # 4. 解析仓库 URL await update_progress( stage="loading", progress=30, message="正在准备下载新版本...", operation="update", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - - repo_url = request.repository_url.rstrip('/') - if repo_url.endswith('.git'): + + repo_url = request.repository_url.rstrip("/") + if repo_url.endswith(".git"): repo_url = repo_url[:-4] - - parts = repo_url.split('/') + + parts = repo_url.split("/") if len(parts) < 2: raise HTTPException(status_code=400, detail="无效的仓库 URL") - + owner = parts[-2] repo = parts[-1] - + # 5. 克隆新版本(这里会推送 35%-85% 的进度) service = get_git_mirror_service() - - if 'github.com' in repo_url: + + if "github.com" in repo_url: result = await service.clone_repository( owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, - depth=1 + depth=1, ) else: result = await service.clone_repository( - owner=owner, - repo=repo, - target_path=plugin_path, - branch=request.branch, - custom_url=repo_url, - depth=1 + owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1 ) - + if not result.get("success"): error_msg = result.get("error", "克隆失败") await update_progress( @@ -964,106 +915,96 @@ async def update_plugin( message="下载新版本失败", operation="update", plugin_id=request.plugin_id, - error=error_msg + error=error_msg, ) raise HTTPException(status_code=500, detail=error_msg) - + # 6. 验证新版本 await update_progress( - stage="loading", - progress=90, - message="验证新版本...", - operation="update", - plugin_id=request.plugin_id + stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id ) - + new_manifest_path = plugin_path / "_manifest.json" if not new_manifest_path.exists(): # 清理失败的更新 def remove_readonly(func, path, _): """清除只读属性并删除文件""" import os + os.chmod(path, stat.S_IWRITE) func(path) - + shutil.rmtree(plugin_path, onerror=remove_readonly) - + await update_progress( stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=request.plugin_id, - error="无效的插件格式" + error="无效的插件格式", ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") - + # 7. 读取新版本信息 try: - with open(new_manifest_path, 'r', encoding='utf-8') as f: + with open(new_manifest_path, "r", encoding="utf-8") as f: new_manifest = json_module.load(f) - + new_version = new_manifest.get("version", "unknown") new_name = new_manifest.get("name", request.plugin_id) - + logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}") - + # 8. 推送成功状态 await update_progress( stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", - plugin_id=request.plugin_id + plugin_id=request.plugin_id, ) - + return { "success": True, "message": "插件更新成功", "plugin_id": request.plugin_id, "plugin_name": new_name, "old_version": old_version, - "new_version": new_version + "new_version": new_version, } - + except Exception as e: # 清理失败的更新 shutil.rmtree(plugin_path, ignore_errors=True) - + await update_progress( stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=request.plugin_id, - error=str(e) + error=str(e), ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e - + except HTTPException: raise except Exception as e: logger.error(f"更新插件失败: {e}", exc_info=True) - + await update_progress( - stage="error", - progress=0, - message="更新失败", - operation="update", - plugin_id=request.plugin_id, - error=str(e) + stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e) ) - + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.get("/installed") -async def get_installed_plugins( - authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: +async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]: """ 获取已安装的插件列表 - + 扫描 plugins 目录,返回所有已安装插件的 ID 和基本信息 """ # Token 验证 @@ -1071,75 +1012,71 @@ async def get_installed_plugins( token_manager = get_token_manager() if not token or not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - + logger.info("收到获取已安装插件列表请求") - + try: plugins_dir = Path("plugins") - + # 如果插件目录不存在,返回空列表 if not plugins_dir.exists(): logger.info("插件目录不存在,创建目录") plugins_dir.mkdir(exist_ok=True) - return { - "success": True, - "plugins": [] - } - + return {"success": True, "plugins": []} + installed_plugins = [] - + # 遍历插件目录 for plugin_path in plugins_dir.iterdir(): # 只处理目录 if not plugin_path.is_dir(): continue - + # 目录名即为插件 ID plugin_id = plugin_path.name - + # 跳过隐藏目录和特殊目录 - if plugin_id.startswith('.') or plugin_id.startswith('__'): + if plugin_id.startswith(".") or plugin_id.startswith("__"): continue - + # 读取 _manifest.json manifest_path = plugin_path / "_manifest.json" - + if not manifest_path.exists(): logger.warning(f"插件 {plugin_id} 缺少 _manifest.json,跳过") continue - + try: import json as json_module - with open(manifest_path, 'r', encoding='utf-8') as f: + + with open(manifest_path, "r", encoding="utf-8") as f: manifest = json_module.load(f) - + # 基本验证 - if 'name' not in manifest or 'version' not in manifest: + if "name" not in manifest or "version" not in manifest: logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过") continue - + # 添加到已安装列表(返回完整的 manifest 信息) - installed_plugins.append({ - "id": plugin_id, - "manifest": manifest, # 返回完整的 manifest 对象 - "path": str(plugin_path.absolute()) - }) - + installed_plugins.append( + { + "id": plugin_id, + "manifest": manifest, # 返回完整的 manifest 对象 + "path": str(plugin_path.absolute()), + } + ) + except json.JSONDecodeError as e: logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}") continue except Exception as e: logger.error(f"读取插件 {plugin_id} 信息时出错: {e}") continue - + logger.info(f"找到 {len(installed_plugins)} 个已安装插件") - - return { - "success": True, - "plugins": installed_plugins, - "total": len(installed_plugins) - } - + + return {"success": True, "plugins": installed_plugins, "total": len(installed_plugins)} + except Exception as e: logger.error(f"获取已安装插件列表失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e diff --git a/src/webui/routers/system.py b/src/webui/routers/system.py index 743da5b9..fb203f79 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -3,6 +3,7 @@ 提供系统重启、状态查询等功能 """ + import os import sys import time @@ -19,12 +20,14 @@ _start_time = time.time() class RestartResponse(BaseModel): """重启响应""" + success: bool message: str class StatusResponse(BaseModel): """状态响应""" + running: bool uptime: float version: str @@ -35,74 +38,60 @@ class StatusResponse(BaseModel): async def restart_maibot(): """ 重启麦麦主程序 - + 使用 os.execv 重启当前进程,配置更改将在重启后生效。 注意:此操作会使麦麦暂时离线。 """ try: # 记录重启操作 print(f"[{datetime.now()}] WebUI 触发重启操作") - + # 使用 os.execv 重启当前进程 # 这会替换当前进程,保持相同的 PID python = sys.executable args = [python] + sys.argv - + # 返回成功响应(实际上这个响应可能不会发送,因为进程会立即重启) # 但我们仍然返回它以保持 API 一致性 os.execv(python, args) - - return RestartResponse( - success=True, - message="麦麦正在重启中..." - ) + + return RestartResponse(success=True, message="麦麦正在重启中...") except Exception as e: - raise HTTPException( - status_code=500, - detail=f"重启失败: {str(e)}" - ) from e + raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e @router.get("/status", response_model=StatusResponse) async def get_maibot_status(): """ 获取麦麦运行状态 - + 返回麦麦的运行状态、运行时长和版本信息。 """ try: uptime = time.time() - _start_time - + # 尝试获取版本信息(需要根据实际情况调整) version = MMC_VERSION # 可以从配置或常量中读取 - + return StatusResponse( - running=True, - uptime=uptime, - version=version, - start_time=datetime.fromtimestamp(_start_time).isoformat() + running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat() ) except Exception as e: - raise HTTPException( - status_code=500, - detail=f"获取状态失败: {str(e)}" - ) from e + raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e # 可选:添加更多系统控制功能 + @router.post("/reload-config") async def reload_config(): """ 热重载配置(不重启进程) - + 仅重新加载配置文件,某些配置可能需要重启才能生效。 此功能需要在主程序中实现配置热重载逻辑。 """ # 这里需要调用主程序的配置重载函数 # 示例:await app_instance.reload_config() - - return { - "success": True, - "message": "配置重载功能待实现" - } + + return {"success": True, "message": "配置重载功能待实现"} diff --git a/src/webui/routes.py b/src/webui/routes.py index b71619ed..3eb7e673 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -1,4 +1,5 @@ """WebUI API 路由""" + from fastapi import APIRouter, HTTPException, Header from pydantic import BaseModel, Field from typing import Optional @@ -38,28 +39,33 @@ router.include_router(system_router) class TokenVerifyRequest(BaseModel): """Token 验证请求""" + token: str = Field(..., description="访问令牌") class TokenVerifyResponse(BaseModel): """Token 验证响应""" + valid: bool = Field(..., description="Token 是否有效") message: str = Field(..., description="验证结果消息") class TokenUpdateRequest(BaseModel): """Token 更新请求""" + new_token: str = Field(..., description="新的访问令牌", min_length=10) class TokenUpdateResponse(BaseModel): """Token 更新响应""" + success: bool = Field(..., description="是否更新成功") message: str = Field(..., description="更新结果消息") class TokenRegenerateResponse(BaseModel): """Token 重新生成响应""" + success: bool = Field(..., description="是否生成成功") token: str = Field(..., description="新生成的令牌") message: str = Field(..., description="生成结果消息") @@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel): class FirstSetupStatusResponse(BaseModel): """首次配置状态响应""" + is_first_setup: bool = Field(..., description="是否为首次配置") message: str = Field(..., description="状态消息") class CompleteSetupResponse(BaseModel): """完成配置响应""" + success: bool = Field(..., description="是否成功") message: str = Field(..., description="结果消息") class ResetSetupResponse(BaseModel): """重置配置响应""" + success: bool = Field(..., description="是否成功") message: str = Field(..., description="结果消息") @@ -93,44 +102,35 @@ async def health_check(): async def verify_token(request: TokenVerifyRequest): """ 验证访问令牌 - + Args: request: 包含 token 的验证请求 - + Returns: 验证结果 """ try: token_manager = get_token_manager() is_valid = token_manager.verify_token(request.token) - + if is_valid: - return TokenVerifyResponse( - valid=True, - message="Token 验证成功" - ) + return TokenVerifyResponse(valid=True, message="Token 验证成功") else: - return TokenVerifyResponse( - valid=False, - message="Token 无效或已过期" - ) + return TokenVerifyResponse(valid=False, message="Token 无效或已过期") except Exception as e: logger.error(f"Token 验证失败: {e}") raise HTTPException(status_code=500, detail="Token 验证失败") from e @router.post("/auth/update", response_model=TokenUpdateResponse) -async def update_token( - request: TokenUpdateRequest, - authorization: Optional[str] = Header(None) -): +async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)): """ 更新访问令牌(需要当前有效的 token) - + Args: request: 包含新 token 的更新请求 authorization: Authorization header (Bearer token) - + Returns: 更新结果 """ @@ -138,20 +138,17 @@ async def update_token( # 验证当前 token if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(current_token): raise HTTPException(status_code=401, detail="当前 Token 无效") - + # 更新 token success, message = token_manager.update_token(request.new_token) - - return TokenUpdateResponse( - success=success, - message=message - ) + + return TokenUpdateResponse(success=success, message=message) except HTTPException: raise except Exception as e: @@ -163,10 +160,10 @@ async def update_token( async def regenerate_token(authorization: Optional[str] = Header(None)): """ 重新生成访问令牌(需要当前有效的 token) - + Args: authorization: Authorization header (Bearer token) - + Returns: 新生成的 token """ @@ -174,21 +171,17 @@ async def regenerate_token(authorization: Optional[str] = Header(None)): # 验证当前 token if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(current_token): raise HTTPException(status_code=401, detail="当前 Token 无效") - + # 重新生成 token new_token = token_manager.regenerate_token() - - return TokenRegenerateResponse( - success=True, - token=new_token, - message="Token 已重新生成" - ) + + return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成") except HTTPException: raise except Exception as e: @@ -200,10 +193,10 @@ async def regenerate_token(authorization: Optional[str] = Header(None)): async def get_setup_status(authorization: Optional[str] = Header(None)): """ 获取首次配置状态 - + Args: authorization: Authorization header (Bearer token) - + Returns: 首次配置状态 """ @@ -211,20 +204,17 @@ async def get_setup_status(authorization: Optional[str] = Header(None)): # 验证 token if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(current_token): raise HTTPException(status_code=401, detail="Token 无效") - + # 检查是否为首次配置 is_first = token_manager.is_first_setup() - - return FirstSetupStatusResponse( - is_first_setup=is_first, - message="首次配置" if is_first else "已完成配置" - ) + + return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置") except HTTPException: raise except Exception as e: @@ -236,10 +226,10 @@ async def get_setup_status(authorization: Optional[str] = Header(None)): async def complete_setup(authorization: Optional[str] = Header(None)): """ 标记首次配置完成 - + Args: authorization: Authorization header (Bearer token) - + Returns: 完成结果 """ @@ -247,20 +237,17 @@ async def complete_setup(authorization: Optional[str] = Header(None)): # 验证 token if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(current_token): raise HTTPException(status_code=401, detail="Token 无效") - + # 标记配置完成 success = token_manager.mark_setup_completed() - - return CompleteSetupResponse( - success=success, - message="配置已完成" if success else "标记失败" - ) + + return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败") except HTTPException: raise except Exception as e: @@ -272,10 +259,10 @@ async def complete_setup(authorization: Optional[str] = Header(None)): async def reset_setup(authorization: Optional[str] = Header(None)): """ 重置首次配置状态,允许重新进入配置向导 - + Args: authorization: Authorization header (Bearer token) - + Returns: 重置结果 """ @@ -283,20 +270,17 @@ async def reset_setup(authorization: Optional[str] = Header(None)): # 验证 token if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + current_token = authorization.replace("Bearer ", "") token_manager = get_token_manager() - + if not token_manager.verify_token(current_token): raise HTTPException(status_code=401, detail="Token 无效") - + # 重置配置状态 success = token_manager.reset_setup_status() - - return ResetSetupResponse( - success=success, - message="配置状态已重置" if success else "重置失败" - ) + + return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败") except HTTPException: raise except Exception as e: diff --git a/src/webui/statistics_routes.py b/src/webui/statistics_routes.py index 01d5ea28..45855475 100644 --- a/src/webui/statistics_routes.py +++ b/src/webui/statistics_routes.py @@ -1,4 +1,5 @@ """统计数据 API 路由""" + from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field from typing import Dict, Any, List @@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"]) class StatisticsSummary(BaseModel): """统计数据摘要""" + total_requests: int = Field(0, description="总请求数") total_cost: float = Field(0.0, description="总花费") total_tokens: int = Field(0, description="总token数") @@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel): class ModelStatistics(BaseModel): """模型统计""" + model_name: str request_count: int total_cost: float @@ -37,6 +40,7 @@ class ModelStatistics(BaseModel): class TimeSeriesData(BaseModel): """时间序列数据""" + timestamp: str requests: int = 0 cost: float = 0.0 @@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel): class DashboardData(BaseModel): """仪表盘数据""" + summary: StatisticsSummary model_stats: List[ModelStatistics] hourly_data: List[TimeSeriesData] @@ -56,39 +61,39 @@ class DashboardData(BaseModel): async def get_dashboard_data(hours: int = 24): """ 获取仪表盘统计数据 - + Args: hours: 统计时间范围(小时),默认24小时 - + Returns: 仪表盘数据 """ try: now = datetime.now() start_time = now - timedelta(hours=hours) - + # 获取摘要数据 summary = await _get_summary_statistics(start_time, now) - + # 获取模型统计 model_stats = await _get_model_statistics(start_time) - + # 获取小时级时间序列数据 hourly_data = await _get_hourly_statistics(start_time, now) - + # 获取日级时间序列数据(最近7天) daily_start = now - timedelta(days=7) daily_data = await _get_daily_statistics(daily_start, now) - + # 获取最近活动 recent_activity = await _get_recent_activity(limit=10) - + return DashboardData( summary=summary, model_stats=model_stats, hourly_data=hourly_data, daily_data=daily_data, - recent_activity=recent_activity + recent_activity=recent_activity, ) except Exception as e: logger.error(f"获取仪表盘数据失败: {e}") @@ -98,100 +103,84 @@ async def get_dashboard_data(hours: int = 24): async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary: """获取摘要统计数据""" summary = StatisticsSummary() - + # 查询 LLM 使用记录 - llm_records = list( - LLMUsage.select() - .where(LLMUsage.timestamp >= start_time) - .where(LLMUsage.timestamp <= end_time) - ) - + llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time)) + total_time_cost = 0.0 time_cost_count = 0 - + for record in llm_records: summary.total_requests += 1 summary.total_cost += record.cost or 0.0 summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0) - + if record.time_cost and record.time_cost > 0: total_time_cost += record.time_cost time_cost_count += 1 - + # 计算平均响应时间 if time_cost_count > 0: summary.avg_response_time = total_time_cost / time_cost_count - + # 查询在线时间 online_records = list( - OnlineTime.select() - .where( - (OnlineTime.start_timestamp >= start_time) | - (OnlineTime.end_timestamp >= start_time) - ) + OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time)) ) - + for record in online_records: start = max(record.start_timestamp, start_time) end = min(record.end_timestamp, end_time) if end > start: summary.online_time += (end - start).total_seconds() - + # 查询消息数量 messages = list( - Messages.select() - .where(Messages.time >= start_time.timestamp()) - .where(Messages.time <= end_time.timestamp()) + Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp()) ) - + summary.total_messages = len(messages) # 简单统计:如果 reply_to 不为空,则认为是回复 summary.total_replies = len([m for m in messages if m.reply_to]) - + # 计算派生指标 if summary.online_time > 0: online_hours = summary.online_time / 3600.0 summary.cost_per_hour = summary.total_cost / online_hours summary.tokens_per_hour = summary.total_tokens / online_hours - + return summary async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]: """获取模型统计数据""" - model_data = defaultdict(lambda: { - 'request_count': 0, - 'total_cost': 0.0, - 'total_tokens': 0, - 'time_costs': [] - }) - - records = list( - LLMUsage.select() - .where(LLMUsage.timestamp >= start_time) - ) - + model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []}) + + records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time)) + for record in records: model_name = record.model_assign_name or record.model_name or "unknown" - model_data[model_name]['request_count'] += 1 - model_data[model_name]['total_cost'] += record.cost or 0.0 - model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) - + model_data[model_name]["request_count"] += 1 + model_data[model_name]["total_cost"] += record.cost or 0.0 + model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) + if record.time_cost and record.time_cost > 0: - model_data[model_name]['time_costs'].append(record.time_cost) - + model_data[model_name]["time_costs"].append(record.time_cost) + # 转换为列表并排序 result = [] for model_name, data in model_data.items(): - avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0 - result.append(ModelStatistics( - model_name=model_name, - request_count=data['request_count'], - total_cost=data['total_cost'], - total_tokens=data['total_tokens'], - avg_response_time=avg_time - )) - + avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0 + result.append( + ModelStatistics( + model_name=model_name, + request_count=data["request_count"], + total_cost=data["total_cost"], + total_tokens=data["total_tokens"], + avg_response_time=avg_time, + ) + ) + # 按请求数排序 result.sort(key=lambda x: x.request_count, reverse=True) return result[:10] # 返回前10个 @@ -200,96 +189,80 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]: async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]: """获取小时级统计数据""" # 创建小时桶 - hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0}) - - records = list( - LLMUsage.select() - .where(LLMUsage.timestamp >= start_time) - .where(LLMUsage.timestamp <= end_time) - ) - + hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0}) + + records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time)) + for record in records: # 获取小时键(去掉分钟和秒) hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0) hour_str = hour_key.isoformat() - - hourly_buckets[hour_str]['requests'] += 1 - hourly_buckets[hour_str]['cost'] += record.cost or 0.0 - hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) - + + hourly_buckets[hour_str]["requests"] += 1 + hourly_buckets[hour_str]["cost"] += record.cost or 0.0 + hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) + # 填充所有小时(包括没有数据的) result = [] current = start_time.replace(minute=0, second=0, microsecond=0) while current <= end_time: hour_str = current.isoformat() - data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0}) - result.append(TimeSeriesData( - timestamp=hour_str, - requests=data['requests'], - cost=data['cost'], - tokens=data['tokens'] - )) + data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0}) + result.append( + TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"]) + ) current += timedelta(hours=1) - + return result async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]: """获取日级统计数据""" - daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0}) - - records = list( - LLMUsage.select() - .where(LLMUsage.timestamp >= start_time) - .where(LLMUsage.timestamp <= end_time) - ) - + daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0}) + + records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time)) + for record in records: # 获取日期键 day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0) day_str = day_key.isoformat() - - daily_buckets[day_str]['requests'] += 1 - daily_buckets[day_str]['cost'] += record.cost or 0.0 - daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) - + + daily_buckets[day_str]["requests"] += 1 + daily_buckets[day_str]["cost"] += record.cost or 0.0 + daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0) + # 填充所有天 result = [] current = start_time.replace(hour=0, minute=0, second=0, microsecond=0) while current <= end_time: day_str = current.isoformat() - data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0}) - result.append(TimeSeriesData( - timestamp=day_str, - requests=data['requests'], - cost=data['cost'], - tokens=data['tokens'] - )) + data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0}) + result.append( + TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"]) + ) current += timedelta(days=1) - + return result async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: """获取最近活动""" - records = list( - LLMUsage.select() - .order_by(LLMUsage.timestamp.desc()) - .limit(limit) - ) - + records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit)) + activities = [] for record in records: - activities.append({ - 'timestamp': record.timestamp.isoformat(), - 'model': record.model_assign_name or record.model_name, - 'request_type': record.request_type, - 'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0), - 'cost': record.cost or 0.0, - 'time_cost': record.time_cost or 0.0, - 'status': record.status - }) - + activities.append( + { + "timestamp": record.timestamp.isoformat(), + "model": record.model_assign_name or record.model_name, + "request_type": record.request_type, + "tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0), + "cost": record.cost or 0.0, + "time_cost": record.time_cost or 0.0, + "status": record.status, + } + ) + return activities @@ -297,7 +270,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: async def get_summary(hours: int = 24): """ 获取统计摘要 - + Args: hours: 统计时间范围(小时) """ @@ -315,7 +288,7 @@ async def get_summary(hours: int = 24): async def get_model_stats(hours: int = 24): """ 获取模型统计 - + Args: hours: 统计时间范围(小时) """ diff --git a/src/webui/token_manager.py b/src/webui/token_manager.py index 7ab16d75..69abf1d8 100644 --- a/src/webui/token_manager.py +++ b/src/webui/token_manager.py @@ -19,7 +19,7 @@ class TokenManager: def __init__(self, config_path: Optional[Path] = None): """ 初始化 Token 管理器 - + Args: config_path: 配置文件路径,默认为项目根目录的 data/webui.json """ @@ -27,10 +27,10 @@ class TokenManager: # 获取项目根目录 (src/webui -> src -> 根目录) project_root = Path(__file__).parent.parent.parent config_path = project_root / "data" / "webui.json" - + self.config_path = config_path self.config_path.parent.mkdir(parents=True, exist_ok=True) - + # 确保配置文件存在并包含有效的 token self._ensure_config() @@ -75,22 +75,23 @@ class TokenManager: """生成新的 64 位随机 token""" # 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符) token = secrets.token_hex(32) - + config = { "access_token": token, "created_at": self._get_current_timestamp(), "updated_at": self._get_current_timestamp(), - "first_setup_completed": False # 标记首次配置未完成 + "first_setup_completed": False, # 标记首次配置未完成 } - + self._save_config(config) logger.info(f"新的 WebUI Token 已生成: {token[:8]}...") - + return token def _get_current_timestamp(self) -> str: """获取当前时间戳字符串""" from datetime import datetime + return datetime.now().isoformat() def get_token(self) -> str: @@ -101,38 +102,38 @@ class TokenManager: def verify_token(self, token: str) -> bool: """ 验证 token 是否有效 - + Args: token: 待验证的 token - + Returns: bool: token 是否有效 """ if not token: return False - + current_token = self.get_token() if not current_token: logger.error("系统中没有有效的 token") return False - + # 使用 secrets.compare_digest 防止时序攻击 is_valid = secrets.compare_digest(token, current_token) - + if is_valid: logger.debug("Token 验证成功") else: logger.warning("Token 验证失败") - + return is_valid def update_token(self, new_token: str) -> tuple[bool, str]: """ 更新 token - + Args: new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号) - + Returns: tuple[bool, str]: (是否更新成功, 错误消息) """ @@ -141,17 +142,17 @@ class TokenManager: if not is_valid: logger.error(f"Token 格式无效: {error_msg}") return False, error_msg - + try: config = self._load_config() old_token = config.get("access_token", "")[:8] - + config["access_token"] = new_token config["updated_at"] = self._get_current_timestamp() - + self._save_config(config) logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...") - + return True, "Token 更新成功" except Exception as e: logger.error(f"更新 Token 失败: {e}") @@ -160,7 +161,7 @@ class TokenManager: def regenerate_token(self) -> str: """ 重新生成 token - + Returns: str: 新生成的 token """ @@ -170,20 +171,20 @@ class TokenManager: def _validate_token_format(self, token: str) -> bool: """ 验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token) - + Args: token: 待验证的 token - + Returns: bool: 格式是否正确 """ if not token or not isinstance(token, str): return False - + # 必须是 64 位十六进制字符串 if len(token) != 64: return False - + # 验证是否为有效的十六进制字符串 try: int(token, 16) @@ -194,48 +195,48 @@ class TokenManager: def _validate_custom_token(self, token: str) -> tuple[bool, str]: """ 验证自定义 token 格式 - + 要求: - 最少 10 位 - 包含大写字母 - 包含小写字母 - 包含特殊符号 - + Args: token: 待验证的 token - + Returns: tuple[bool, str]: (是否有效, 错误消息) """ if not token or not isinstance(token, str): return False, "Token 不能为空" - + # 检查长度 if len(token) < 10: return False, "Token 长度至少为 10 位" - + # 检查是否包含大写字母 has_upper = any(c.isupper() for c in token) if not has_upper: return False, "Token 必须包含大写字母" - + # 检查是否包含小写字母 has_lower = any(c.islower() for c in token) if not has_lower: return False, "Token 必须包含小写字母" - + # 检查是否包含特殊符号 special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/" has_special = any(c in special_chars for c in token) if not has_special: return False, f"Token 必须包含特殊符号 ({special_chars})" - + return True, "Token 格式正确" def is_first_setup(self) -> bool: """ 检查是否为首次配置 - + Returns: bool: 是否为首次配置 """ @@ -245,7 +246,7 @@ class TokenManager: def mark_setup_completed(self) -> bool: """ 标记首次配置已完成 - + Returns: bool: 是否标记成功 """ @@ -263,7 +264,7 @@ class TokenManager: def reset_setup_status(self) -> bool: """ 重置首次配置状态,允许重新进入配置向导 - + Returns: bool: 是否重置成功 """