import asyncio import json import random # 添加 random 模块导入 import re from datetime import datetime from typing import Tuple, Union, Dict, Any, Set # 引入 Set import aiohttp from aiohttp.client import ClientResponse # 相对路径导入,根据你的项目结构调整 # 例如,如果 utils_model.py 在 src/utils/ 下,而 logger 在 src/common/ 下 # from ..common.logger import get_module_logger # from ..common.database import db # from ..config.config import global_config # 假设它们在期望的路径 from src.common.logger import get_module_logger from ...common.database import db from ...config.config import global_config import base64 from PIL import Image import io import os from rich.traceback import install install(extra_lines=3) # 尝试加载 .env 文件中的环境变量 (如果项目结构需要) # load_dotenv() # 如果你的 .env 文件不在标准位置,可能需要指定路径 load_dotenv(dotenv_path='path/to/.env') logger = get_module_logger("model_utils") class PayLoadTooLargeError(Exception): """自定义异常类,用于处理请求体过大错误""" # (代码不变) def __init__(self, message: str): super().__init__(message) self.message = message def __str__(self): return "请求体过大,请尝试压缩图片或减少输入内容。" class RequestAbortException(Exception): """自定义异常类,用于处理请求中断异常""" # (代码不变) def __init__(self, message: str, response: ClientResponse): super().__init__(message) self.message = message self.response = response def __str__(self): return self.message class PermissionDeniedException(Exception): """自定义异常类,用于处理访问拒绝的异常""" # (代码不变) def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符 super().__init__(message) self.message = message self.key_identifier = key_identifier # 存储导致 403 的 key def __str__(self): return self.message # 新增:用于内部标记需要切换 Key 的异常 class _SwitchKeyException(Exception): """内部异常,用于标记需要切换Key并且跳过标准等待时间.""" # (代码不变) pass # 常见Error Code Mapping error_code_mapping = { # (代码不变) 400: "参数不正确", 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效 402: "账号余额不足", 403: "需要实名,或余额不足,或Key无权限", # 扩展 403 的含义 404: "Not Found", 429: "请求过于频繁,请稍后再试", 500: "服务器内部故障", 503: "服务器负载过高", } async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]): """安全地记录请求内容,隐藏敏感信息""" # (代码不变) image_base64: str = request_content.get("image_base64") image_format: str = request_content.get("image_format") is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload safe_payload = json.loads(json.dumps(payload)) if payload else {} if image_base64 and safe_payload and isinstance(safe_payload, dict): if "messages" in safe_payload and len(safe_payload["messages"]) > 0: if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]: content = safe_payload["messages"][0]["content"] if ( isinstance(content, list) and len(content) > 1 and isinstance(content[1], dict) and "image_url" in content[1] ): safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0: if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]: parts = safe_payload["contents"][0]["parts"] for i, part in enumerate(parts): if isinstance(part, dict) and "inlineData" in part: safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = ( f"{image_base64[:10]}...{image_base64[-10:]}" ) break return safe_payload class LLMRequest: # (代码不变) MODELS_NEEDING_TRANSFORMATION = [ "o1", "o1-2024-12-17", "o1-mini", "o1-mini-2024-09-12", "o1-preview", "o1-preview-2024-09-12", 
"o1-pro", "o1-pro-2025-03-19", "o3", "o3-2025-04-16", "o3-mini", "o3-mini-2025-01-31o4-mini", "o4-mini-2025-04-16", ] _abandoned_keys_runtime: Set[str] = set() def __init__(self, model: dict, **kwargs): """初始化 LLMRequest 实例""" # (代码不变) self.model_key_name = model["key"] self.model_name: str = model["name"] self.params = kwargs self.stream = model.get("stream", False) self.pri_in = model.get("pri_in", 0) self.pri_out = model.get("pri_out", 0) self.request_type = model.get("request_type", "default") try: raw_api_key_config = os.environ[self.model_key_name] self.base_url = os.environ[model["base_url"]] self.is_gemini = "googleapis.com" in self.base_url.lower() if self.is_gemini: logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})") if self.stream: logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。") self.stream = False # 解析和过滤 API Keys (代码不变) parsed_keys = [] is_list_config = False try: loaded_keys = json.loads(raw_api_key_config) if isinstance(loaded_keys, list): parsed_keys = [str(key) for key in loaded_keys if key] is_list_config = True elif isinstance(loaded_keys, str) and loaded_keys: parsed_keys = [loaded_keys] else: raise ValueError(f"Parsed API key for {self.model_key_name} is not a valid list or string.") except (json.JSONDecodeError, TypeError) as e: if isinstance(raw_api_key_config, list): parsed_keys = [str(key) for key in raw_api_key_config if key] is_list_config = True elif isinstance(raw_api_key_config, str) and raw_api_key_config: parsed_keys = [raw_api_key_config] else: raise ValueError( f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}" ) from e if not parsed_keys: raise ValueError(f"No valid API keys found for {self.model_key_name}.") abandoned_key_name = f"abandon_{self.model_key_name}" abandoned_keys_set = set() raw_abandoned_keys = os.environ.get(abandoned_key_name) if raw_abandoned_keys: try: loaded_abandoned = json.loads(raw_abandoned_keys) if isinstance(loaded_abandoned, list): abandoned_keys_set.update(str(key) for key in loaded_abandoned if key) elif isinstance(loaded_abandoned, str) and loaded_abandoned: abandoned_keys_set.add(loaded_abandoned) logger.info( f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。" ) except (json.JSONDecodeError, TypeError): if isinstance(raw_abandoned_keys, list): abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key) logger.info( f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。" ) elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys: abandoned_keys_set.add(raw_abandoned_keys) logger.info( f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。" ) else: logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}") all_abandoned_keys = abandoned_keys_set.union(LLMRequest._abandoned_keys_runtime) active_keys = [key for key in parsed_keys if key not in all_abandoned_keys] if not active_keys: logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。") raise ValueError( f"No active API keys available for {self.model_key_name} after filtering abandoned keys." 
) if is_list_config and len(active_keys) > 1: self._api_key_config = active_keys logger.info( f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。" ) elif active_keys: self._api_key_config = active_keys[0] logger.info( f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。" ) else: raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.") # 加载代理配置 (代码不变) self.proxy_url = None self.proxy_models_set = set() proxy_host = os.environ.get("PROXY_HOST") proxy_port = os.environ.get("PROXY_PORT") proxy_models_str = os.environ.get("PROXY_MODELS", "") if proxy_host and proxy_port: try: int(proxy_port) self.proxy_url = f"http://{proxy_host}:{proxy_port}" logger.debug(f"代理已配置: {self.proxy_url}") if proxy_models_str: try: cleaned_str = proxy_models_str.strip("'\"") self.proxy_models_set = { model_name.strip() for model_name in cleaned_str.split(",") if model_name.strip() } logger.debug(f"以下模型将使用代理: {self.proxy_models_set}") except Exception as e: logger.error( f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。" ) self.proxy_models_set = set() except ValueError: logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。") self.proxy_url = None self.proxy_models_set = set() except Exception as e: logger.error(f"加载代理配置时发生错误: {e}") self.proxy_url = None self.proxy_models_set = set() else: logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。") except KeyError as e: # (代码不变) missing_key = str(e).strip("'") if missing_key == self.model_key_name: logger.error(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") raise ValueError(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") from e elif missing_key == model["base_url"]: logger.error(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") raise ValueError(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") from e else: logger.error(f"配置错误:找不到环境变量 - {str(e)}") raise ValueError(f"配置错误:找不到环境变量 - {str(e)}") from e except AttributeError as e: # (代码不变) logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e except ValueError as e: # (代码不变) logger.error(f"API Key 或配置初始化错误 for {self.model_key_name}: {str(e)}") raise e self._init_database() @staticmethod def _init_database(): """初始化数据库集合""" # (代码不变) try: db.llm_usage.create_index([("timestamp", 1)]) db.llm_usage.create_index([("model_name", 1)]) db.llm_usage.create_index([("user_id", 1)]) db.llm_usage.create_index([("request_type", 1)]) except Exception as e: logger.error(f"创建数据库索引失败: {str(e)}") def _record_usage( self, prompt_tokens: int, completion_tokens: int, total_tokens: int, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions", ): """记录模型使用情况到数据库 Args: prompt_tokens: 输入token数 completion_tokens: 输出token数 total_tokens: 总token数 user_id: 用户ID,默认为system request_type: 请求类型 endpoint: API端点 """ # 如果 request_type 为 None,则使用实例变量中的值 if request_type is None: request_type = self.request_type actual_endpoint = endpoint if self.is_gemini: if endpoint == "/embeddings": actual_endpoint = ":embedContent" else: actual_endpoint = ":generateContent" try: usage_data = { "model_name": self.model_name, "user_id": user_id, "request_type": request_type or self.request_type, "endpoint": actual_endpoint, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens, "cost": self._calculate_cost(prompt_tokens, completion_tokens), "status": "success", "timestamp": datetime.now(), } 
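            # 一个假设的落库文档示例(字段与上面的 usage_data 一致,数值仅为演示):
            # {"model_name": "deepseek-chat", "user_id": "system", "request_type": "chat",
            #  "endpoint": "/chat/completions", "prompt_tokens": 1200, "completion_tokens": 300,
            #  "total_tokens": 1500, "cost": 0.0015, "status": "success", "timestamp": ...}
            # 其中 cost 由 _calculate_cost 按每百万 token 单价折算:
            #   1200 / 1e6 * pri_in + 300 / 1e6 * pri_out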
db.llm_usage.insert_one(usage_data) logger.trace( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type or self.request_type}, " f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " f"总计: {total_tokens}" ) except Exception as e: logger.error(f"记录token使用情况失败: {str(e)}") def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: """计算API调用成本""" # (代码不变) input_cost = (prompt_tokens / 1000000) * self.pri_in output_cost = (completion_tokens / 1000000) * self.pri_out return round(input_cost + output_cost, 6) async def _prepare_request( self, endpoint: str, prompt: str = None, image_base64: str = None, image_format: str = None, payload: dict = None, retry_policy: dict = None, **kwargs: Any, ) -> Dict[str, Any]: """配置请求参数,合并实例参数和调用时参数""" default_retry = { "max_retries": global_config.api_polling_max_retries, "base_wait": 10, "retry_codes": [429, 413, 500, 503], "abort_codes": [400, 401, 402, 403], } policy = {**default_retry, **(retry_policy or {})} _actual_endpoint = endpoint if self.is_gemini: action = endpoint.lstrip("/") api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}" stream_mode = False else: api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" stream_mode = self.stream call_params = {k: v for k, v in kwargs.items() if k != "request_type"} merged_params = {**self.params, **call_params} if payload is None: payload = await self._build_payload(prompt, image_base64, image_format, merged_params) else: logger.debug("使用外部提供的 payload,忽略单次调用参数合并。") if not self.is_gemini and stream_mode: payload["stream"] = merged_params.get("stream", stream_mode) return { "policy": policy, "payload": payload, "api_url": api_url, "stream_mode": payload.get("stream", False), "image_base64": image_base64, "image_format": image_format, "prompt": prompt, } async def _execute_request( self, endpoint: str, prompt: str = None, image_base64: str = None, image_format: str = None, payload: dict = None, retry_policy: dict = None, response_handler: callable = None, user_id: str = "system", request_type: str = None, **kwargs: Any, ): """统一请求执行入口, 支持列表 key 切换、代理和单次调用参数覆盖""" final_request_type = request_type or kwargs.get("request_type") or self.request_type api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"} request_content = await self._prepare_request( endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs ) policy = request_content["policy"] api_url = request_content["api_url"] actual_payload = request_content["payload"] stream_mode = request_content["stream_mode"] use_proxy = False current_proxy_url = None if self.proxy_url and self.model_name in self.proxy_models_set: use_proxy = True current_proxy_url = self.proxy_url logger.debug(f"模型 {self.model_name}: 将通过代理 {current_proxy_url} 发送请求。") elif self.proxy_url: logger.debug(f"模型 {self.model_name}: 配置了代理,但此模型不在 PROXY_MODELS 列表中,将不使用代理。") else: logger.debug(f"模型 {self.model_name}: 未配置或不为此模型使用代理。") current_key = None keys_failed_429 = set() keys_abandoned_runtime = set() key_switch_limit_429 = global_config.api_polling_max_retries key_switch_limit_403 = global_config.api_polling_max_retries available_keys_pool = [] is_key_list = isinstance(self._api_key_config, list) if is_key_list: available_keys_pool = list(self._api_key_config) if not available_keys_pool: logger.error(f"模型 {self.model_name}: 初始化后无可用活动 Keys。") raise ValueError(f"模型 {self.model_name}: 无可用活动 Keys。") random.shuffle(available_keys_pool) key_switch_limit_429 = min(key_switch_limit_429, 
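            # 说明: 429/403 的 Key 切换上限取全局轮询上限与当前可用 Key 数的较小值。
            # retry_policy 亦可在单次调用时局部覆盖,例如本文件 get_embedding 中传入的
            # {"max_retries": 2, "base_wait": 6};未覆盖的字段沿用 default_retry。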
len(available_keys_pool)) key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool)) logger.info( f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。" ) elif isinstance(self._api_key_config, str): available_keys_pool = [self._api_key_config] key_switch_limit_429 = 1 key_switch_limit_403 = 1 else: logger.error(f"模型 {self.model_name}: 无效的 API Key 配置类型在执行时遇到: {type(self._api_key_config)}") raise TypeError(f"模型 {self.model_name}: 无效的 API Key 配置类型") last_exception = None for attempt in range(policy["max_retries"]): if available_keys_pool: current_key = available_keys_pool.pop(0) elif current_key: logger.debug( f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)" ) else: if ( not self._api_key_config or all( k in LLMRequest._abandoned_keys_runtime for k in self._api_key_config if isinstance(self._api_key_config, list) ) or ( isinstance(self._api_key_config, str) and self._api_key_config in LLMRequest._abandoned_keys_runtime ) ): final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。" logger.critical(final_error_msg) raise PermissionDeniedException(final_error_msg) else: raise RuntimeError(f"模型 {self.model_name}: 无法选择 API key (第 {attempt + 1} 次尝试)") logger.debug(f"模型 {self.model_name}: 尝试使用 Key: ...{current_key[-4:]} (总第 {attempt + 1} 次尝试)") try: headers = await self._build_headers(current_key) if not self.is_gemini and stream_mode: headers["Accept"] = "text/event-stream" async with aiohttp.ClientSession() as session: post_kwargs = {"headers": headers, "json": actual_payload, "timeout": 60} if use_proxy: post_kwargs["proxy"] = current_proxy_url async with session.post(api_url, **post_kwargs) as response: if response.status == 429 and is_key_list: logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。") response_text = await response.text() logger.debug( f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}" ) if current_key not in keys_failed_429: keys_failed_429.add(current_key) logger.info( f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)" ) if available_keys_pool and len(keys_failed_429) < key_switch_limit_429: logger.info(" 尝试因 429 切换到下一个可用 Key...") raise _SwitchKeyException() else: logger.warning(" 无更多 Key 可因 429 切换或已达上限。") else: logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429,按标准重试流程。") elif response.status == 403 and is_key_list: logger.error( f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。" ) if current_key not in keys_abandoned_runtime: keys_abandoned_runtime.add(current_key) LLMRequest._abandoned_keys_runtime.add(current_key) logger.critical( f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!" 
) if current_key in available_keys_pool: available_keys_pool.remove(current_key) if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403: logger.info(" 尝试因 403 切换到下一个可用 Key...") raise _SwitchKeyException() else: logger.error(" 无更多 Key 可因 403 切换或已达上限。将中止请求。") await response.read() raise PermissionDeniedException( f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。", key_identifier=current_key, ) else: logger.error(f" Key ...{current_key[-4:]} 再次遇到 403,这不应发生。中止请求。") await response.read() raise PermissionDeniedException( f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key ) elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]: await self._handle_error_response(response, attempt, policy, current_key) if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1: if response.status not in [429, 403]: wait_time = policy["base_wait"] * (2**attempt) logger.warning( f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试..." ) await asyncio.sleep(wait_time) last_exception = RuntimeError(f"重试错误 {response.status}") continue if response.status in policy["abort_codes"] or ( response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1 ): if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]: logger.error( f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。" ) # await self._handle_error_response(response, attempt, policy, current_key) # await response.read() # final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}" # logger.error(final_error_msg) # raise RequestAbortException(final_error_msg, response) response.raise_for_status() result = {} if not self.is_gemini and stream_mode: result = await self._handle_stream_output(response) else: result = await response.json() return ( response_handler(result) if response_handler else self._default_response_handler(result, user_id, final_request_type, endpoint) ) except _SwitchKeyException: last_exception = _SwitchKeyException() logger.debug("捕获到 _SwitchKeyException,立即进行下一次尝试。") continue except PermissionDeniedException as e: logger.error(f"模型 {self.model_name}: 因权限拒绝 (403) 中止请求: {e}") if is_key_list and not available_keys_pool and e.key_identifier: logger.critical(f" 中止原因是 Key ...{e.key_identifier[-4:]} 触发 403 后已无其他 Key 可用。") raise e except aiohttp.ClientProxyConnectionError as e: logger.error(f"代理连接错误: {e} (代理地址: {current_proxy_url})") last_exception = e if attempt >= policy["max_retries"] - 1: raise RuntimeError(f"代理连接失败达到最大重试次数: {e}") from e wait_time = policy["base_wait"] * (2**attempt) logger.warning(f"模型 {self.model_name}: 代理连接错误,等待 {wait_time} 秒后重试...") await asyncio.sleep(wait_time) continue except aiohttp.ClientConnectorError as e: logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})") last_exception = e if attempt >= policy["max_retries"] - 1: raise RuntimeError(f"网络连接失败达到最大重试次数: {e}") from e wait_time = policy["base_wait"] * (2**attempt) logger.warning(f"模型 {self.model_name}: 网络连接错误,等待 {wait_time} 秒后重试...") await asyncio.sleep(wait_time) continue except (PayLoadTooLargeError, RequestAbortException) as e: # (代码不变) logger.error(f"模型 {self.model_name}: 请求处理中遇到关键错误,将中止: {e}") raise e except Exception as e: # (代码不变) last_exception = e logger.warning( f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}" ) if attempt >= policy["max_retries"] - 1: logger.error( f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 
错误失败。" ) else: try: temp_request_content = { "policy": policy, "payload": actual_payload, "api_url": api_url, "stream_mode": stream_mode, "image_base64": image_base64, "image_format": image_format, "prompt": prompt, } handled_payload, count_delta = await self._handle_exception( e, attempt, temp_request_content, merged_params=api_kwargs ) if handled_payload: actual_payload = handled_payload logger.info(f"模型 {self.model_name}: 异常处理更新了 payload,将使用当前 Key 重试。") wait_time = policy["base_wait"] * (2**attempt) logger.warning(f"模型 {self.model_name}: 等待 {wait_time} 秒后重试...") await asyncio.sleep(wait_time) continue except (RequestAbortException, PermissionDeniedException) as abort_exception: logger.error(f"模型 {self.model_name}: 异常处理判断需要中止请求: {abort_exception}") raise abort_exception except RuntimeError as rt_error: logger.error(f"模型 {self.model_name}: 异常处理遇到运行时错误: {rt_error}") raise rt_error # --- 循环结束 --- logger.error(f"模型 {self.model_name}: 所有重试尝试 ({policy['max_retries']} 次) 均失败。") if last_exception: if isinstance(last_exception, PermissionDeniedException): logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}") raise last_exception logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}") raise RuntimeError( f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}" ) from last_exception else: if not available_keys_pool and keys_abandoned_runtime: final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。" logger.critical(final_error_msg) raise PermissionDeniedException(final_error_msg) else: raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败,原因未知。") async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]: """处理 OpenAI 兼容的流式输出""" # (代码不变) flag_delta_content_finished = False accumulated_content = "" usage = None reasoning_content = "" content = "" tool_calls = None async for line_bytes in response.content: try: line = line_bytes.decode("utf-8").strip() if not line: continue if line.startswith("data:"): data_str = line[5:].strip() if data_str == "[DONE]": break try: chunk = json.loads(data_str) if flag_delta_content_finished: chunk_usage = chunk.get("usage", None) if chunk_usage: usage = chunk_usage else: delta = chunk["choices"][0]["delta"] delta_content = delta.get("content") if delta_content is None: delta_content = "" accumulated_content += delta_content if "tool_calls" in delta: if tool_calls is None: tool_calls = [] for tc in delta["tool_calls"]: new_tc = dict(tc) if "function" in new_tc and "arguments" not in new_tc["function"]: new_tc["function"]["arguments"] = "" tool_calls.append(new_tc) else: for i, tc_delta in enumerate(delta["tool_calls"]): if ( i < len(tool_calls) and "function" in tc_delta and "arguments" in tc_delta["function"] ): if "arguments" in tool_calls[i]["function"]: tool_calls[i]["function"]["arguments"] += tc_delta["function"][ "arguments" ] else: tool_calls[i]["function"]["arguments"] = tc_delta["function"][ "arguments" ] elif i >= len(tool_calls): new_tc = dict(tc_delta) if "function" in new_tc and "arguments" not in new_tc["function"]: new_tc["function"]["arguments"] = "" tool_calls.append(new_tc) finish_reason = chunk["choices"][0].get("finish_reason") if delta.get("reasoning_content", None): reasoning_content += delta["reasoning_content"] if finish_reason == "stop" or finish_reason == "tool_calls": chunk_usage = chunk.get("usage", None) if chunk_usage: usage = chunk_usage break flag_delta_content_finished = True except json.JSONDecodeError as e: logger.error(f"模型 {self.model_name} 
解析流式 JSON 错误: {e} - data: '{data_str}'")
                    except Exception as e:
                        logger.exception(f"模型 {self.model_name} 解析流式输出块错误: {str(e)}")
            except UnicodeDecodeError as e:
                logger.warning(f"模型 {self.model_name} 流式输出解码错误: {e} - bytes: {line_bytes[:50]}...")
            except (Exception, GeneratorExit) as e:
                # GeneratorExit 继承自 BaseException 而非 Exception,需显式捕获,
                # 否则下面的 isinstance 分支永远不会命中
                if isinstance(e, GeneratorExit):
                    log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
                else:
                    log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
                logger.warning(log_content)
                try:
                    await response.release()
                except Exception as cleanup_error:
                    logger.error(f"清理资源时发生错误: {cleanup_error}")
                content = accumulated_content
                break

        if not content and accumulated_content:
            content = accumulated_content

        # 从累计文本中提取 <think>...</think> 思维链内容
        think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        if think_match:
            reasoning_content = think_match.group(1).strip()
        content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

        message = {
            "content": content,
            "reasoning_content": reasoning_content,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        result = {
            "choices": [{"message": message}],
            "usage": usage,
        }
        return result

    async def _handle_error_response(
        self, response: ClientResponse, retry_count: int, policy: Dict[str, Any], current_key: str = None
    ) -> None:
        """处理 HTTP 错误响应 (区分 403 和其他错误)"""
        # (代码不变)
        status = response.status
        try:
            error_text = await response.text()
        except Exception as e:
            error_text = f"(无法读取响应体: {e})"

        if status == 403:
            logger.error(
                f"模型 {self.model_name}: 遇到 403 (权限拒绝) 错误。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
                f"响应: {error_text[:200]}"
            )
            raise PermissionDeniedException(f"模型禁止访问 ({status})", key_identifier=current_key)
        elif status in policy["retry_codes"] and status != 429:
            if status == 413:
                logger.warning(
                    f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩..."
                )
                raise PayLoadTooLargeError("请求体过大")
            elif status in [500, 503]:
                logger.error(
                    f"模型 {self.model_name}: 服务器内部错误或过载 ({status})。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
                    f"响应: {error_text[:200]}"
                )
                return
            else:
                logger.warning(
                    f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}"
                )
                return
        elif status in policy["abort_codes"]:
            logger.error(
                f"模型 {self.model_name}: 遇到需要中止的错误码: {status} - {error_code_mapping.get(status, '未知错误')}. "
                f"Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
            )
            raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
        else:
            logger.error(
                f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 
响应: {error_text[:200]}" ) try: response.raise_for_status() raise RequestAbortException(f"未处理的错误状态码 {status}", response) except aiohttp.ClientResponseError as e: raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e async def _handle_exception( self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]: """处理非 HTTP 错误,支持使用合并后的参数重建 payload""" policy = request_content["policy"] payload = request_content["payload"] _wait_time = policy["base_wait"] * (2**retry_count) keep_request = False if retry_count < policy["max_retries"] - 1: keep_request = True params_for_rebuild = merged_params if merged_params is not None else payload if isinstance(exception, PayLoadTooLargeError): if keep_request: logger.warning("请求体过大 (PayLoadTooLargeError),尝试压缩图片...") image_base64 = request_content.get("image_base64") if image_base64: compressed_image_base64 = compress_base64_image_by_scale(image_base64) if compressed_image_base64 != image_base64: new_payload = await self._build_payload( request_content["prompt"], compressed_image_base64, request_content["image_format"], params_for_rebuild, ) logger.info("图片压缩成功,将使用压缩后的图片重试。") return new_payload, 0 else: logger.warning("图片压缩未改变大小或失败。") else: logger.warning("请求体过大但请求中不包含图片,无法压缩。") return None, 0 else: logger.error("达到最大重试次数,请求体仍然过大。") raise RuntimeError("请求体过大,压缩或重试后仍然失败。") from exception elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)): if keep_request: logger.error(f"模型 {self.model_name} 网络错误: {str(exception)}") return None, 0 else: logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}") raise RuntimeError(f"网络请求失败: {str(exception)}") from exception elif isinstance(exception, aiohttp.ClientResponseError): if keep_request: logger.error( f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}" ) try: error_text = await exception.response.text() if hasattr(exception, "response") else str(exception) logger.error(f"服务器错误响应详情: {error_text[:500]}") except Exception as parse_err: logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}") return None, 0 else: logger.critical( f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}" ) current_key_placeholder = request_content.get("current_key", "******") handled_payload = await _safely_record(request_content, payload) logger.critical( f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}" ) raise RuntimeError( f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" ) from exception else: if keep_request: logger.error( f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}" ) return None, 0 else: logger.critical( f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}" ) current_key_placeholder = request_content.get("current_key", "******") handled_payload = await _safely_record(request_content, payload) logger.critical( f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}" ) raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception async def _transform_parameters(self, merged_params: dict) -> dict: """根据模型名称转换合并后的参数,并移除内部参数""" # (代码不变) new_params = dict(merged_params) new_params.pop("request_type", None) if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION: 
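            # 一个假设的转换示例(仅作演示): 对 o1/o3/o4 系列模型,
            #   {"temperature": 0.7, "max_tokens": 800} -> {"max_completion_tokens": 800}
            # 对 Gemini,这些参数会在下面的分支被折叠进 generationConfig:
            #   {"temperature": 0.7, "max_tokens": 800}
            #   -> {"generationConfig": {"temperature": 0.7, "maxOutputTokens": 800}}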
new_params.pop("temperature", None) if "max_tokens" in new_params: new_params["max_completion_tokens"] = new_params.pop("max_tokens") elif self.is_gemini: gen_config = new_params.get("generationConfig", {}) if "temperature" in new_params: gen_config["temperature"] = new_params.pop("temperature") if "max_tokens" in new_params: gen_config["maxOutputTokens"] = new_params.pop("max_tokens") if "top_p" in new_params: gen_config["topP"] = new_params.pop("top_p") if "top_k" in new_params: gen_config["topK"] = new_params.pop("top_k") if gen_config: new_params["generationConfig"] = gen_config new_params.pop("frequency_penalty", None) new_params.pop("presence_penalty", None) new_params.pop("max_completion_tokens", None) return new_params async def _build_payload( self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None ) -> dict: """构建请求体 (区分 Gemini 和 OpenAI),使用合并和转换后的参数""" # (代码不变) if merged_params is None: merged_params = self.params params_copy = await self._transform_parameters(merged_params) if self.is_gemini: parts = [] if prompt: parts.append({"text": prompt}) if image_base64: mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}" parts.append({"inlineData": {"mimeType": mime_type, "data": image_base64}}) payload = {"contents": [{"parts": parts}], **params_copy} payload.pop("model", None) # --- 添加 Gemini 安全设置 --- safety_settings = [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, ] payload["safetySettings"] = safety_settings logger.debug(f"模型 {self.model_name}: 已为 Gemini 函数调用请求添加 safetySettings (BLOCK_NONE)。") # --- 结束添加安全设置 --- else: if image_base64: messages = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, { "type": "image_url", "image_url": { "url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}" }, }, ], } ] else: messages = [{"role": "user", "content": prompt}] payload = { "model": self.model_name, "messages": messages, **params_copy, } if "max_tokens" not in payload and "max_completion_tokens" not in payload: if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy: payload["max_tokens"] = global_config.model_max_output_length if "max_completion_tokens" in payload: payload["max_tokens"] = payload.pop("max_completion_tokens") return payload def _default_response_handler( self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions" ) -> Tuple: """默认响应解析 (区分 Gemini 和 OpenAI),并处理函数/工具调用""" content = "没有返回结果" reasoning_content = "" tool_calls = None # OpenAI 格式 function_call = None # Gemini 格式 prompt_tokens = 0 completion_tokens = 0 total_tokens = 0 if self.is_gemini: # --- 解析 Gemini 响应 --- try: if "candidates" in result and result["candidates"]: candidate = result["candidates"][0] # 检查是否有 content 和 parts if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]: # 查找 functionCall 或 text 部分 final_text_parts = [] for part in candidate["content"]["parts"]: if "functionCall" in part: function_call = part["functionCall"] # 获取 Gemini 的 functionCall # Gemini functionCall 通常不与 text 一起返回,这里假设只处理 functionCall break # 找到 functionCall 就停止处理 parts elif "text" in part: 
final_text_parts.append(part.get("text", "")) if not function_call: # 如果没有 functionCall,处理 text raw_content = "".join(final_text_parts).strip() content, reasoning = self._extract_reasoning(raw_content) reasoning_content = reasoning # else: function_call 已获取,content 留空或设为特定值 else: content = "Gemini响应中缺少 content 或 parts" logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}") finish_reason = candidate.get("finishReason") if finish_reason == "SAFETY": logger.warning(f"模型 {self.model_name}: Gemini 响应因安全设置被阻止。") content = "响应内容因安全原因被过滤。" elif finish_reason == "RECITATION": logger.warning(f"模型 {self.model_name}: Gemini 响应因引用限制被阻止。") content = "响应内容因引用限制被过滤。" elif finish_reason == "OTHER": logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。") # finishReason == "TOOL_CODE" or "FUNCTION_CALL" 是正常情况 usage = result.get("usageMetadata", {}) if usage: prompt_tokens = usage.get("promptTokenCount", 0) completion_tokens = usage.get("candidatesTokenCount", 0) total_tokens = usage.get("totalTokenCount", 0) if completion_tokens == 0 and total_tokens > 0: completion_tokens = total_tokens - prompt_tokens else: logger.warning(f"模型 {self.model_name} (Gemini) 的响应中缺少 'usageMetadata' 信息。") except Exception as e: logger.error(f"解析 Gemini 响应出错: {e} - 响应: {result}") content = "解析 Gemini 响应时出错" else: # --- 解析 OpenAI 兼容响应 --- # (代码不变) if "choices" in result and result["choices"]: message = result["choices"][0].get("message", {}) raw_content = message.get("content", "") content, reasoning = self._extract_reasoning(raw_content if raw_content else "") explicit_reasoning = message.get("model_extra", {}).get("reasoning_content", "") if not explicit_reasoning: explicit_reasoning = message.get("reasoning_content", "") reasoning_content = explicit_reasoning if explicit_reasoning else reasoning tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls usage = result.get("usage", {}) if usage: prompt_tokens = usage.get("prompt_tokens", 0) completion_tokens = usage.get("completion_tokens", 0) total_tokens = usage.get("total_tokens", 0) else: logger.warning(f"模型 {self.model_name} (OpenAI) 的响应中缺少 'usage' 信息。") else: logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}") # --- 记录 Token 使用情况 --- # (代码不变) if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0: self._record_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, user_id=user_id, request_type=request_type, endpoint=endpoint, ) else: logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。") # --- 返回结果 (统一格式) --- final_tool_calls = None if tool_calls: # 来自 OpenAI final_tool_calls = tool_calls logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}") elif function_call: # 来自 Gemini logger.debug(f"检测到 Gemini 函数调用: {function_call}") # 将 Gemini functionCall 转换为 OpenAI tool_calls 格式 # 注意: Gemini 的 functionCall 没有显式的 id 和 type,需要模拟 final_tool_calls = [ { "id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID "type": "function", "function": { "name": function_call.get("name"), # Gemini 的参数在 'args' 中,OpenAI 在 'arguments' (通常是 JSON 字符串) # 需要将 Gemini 的 dict 参数转换为 JSON 字符串 "arguments": json.dumps(function_call.get("args", {})), }, } ] logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}") if final_tool_calls: # 如果有工具/函数调用,通常 content 为空或包含思考过程,这里返回转换后的调用信息 return content, reasoning_content, final_tool_calls else: # 没有工具/函数调用,返回普通文本响应 return content, reasoning_content @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: """CoT思维链提取""" # (代码不变) 
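        # 行为示例(假设输入,仅作演示):
        #   输入: "<think>先分析问题</think>你好!"
        #   输出: ("你好!", "先分析问题")
        # 输入不含 <think> 标签时,reasoning 返回空字符串。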
        if not content:
            return "", ""
        match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        cleaned_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
        if match:
            reasoning = match.group(1).strip()
        else:
            reasoning = ""
        return cleaned_content, reasoning

    async def _build_headers(self, api_key: str, no_key: bool = False) -> dict:
        """构建请求头 (区分 Gemini 和 OpenAI)"""
        # (代码不变)
        if no_key:
            if self.is_gemini:
                return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
            else:
                return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
        else:
            if not api_key:
                logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
                raise ValueError("无效的 API key 提供给 _build_headers。")
            if self.is_gemini:
                return {"x-goog-api-key": api_key, "Content-Type": "application/json"}
            else:
                return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
        """根据输入的提示生成模型的异步响应,支持覆盖参数"""
        endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
        response = await self._execute_request(
            endpoint=endpoint, prompt=prompt, user_id=user_id, request_type="chat", **kwargs
        )
        if len(response) == 3:
            content, reasoning_content, tool_calls = response
            return content, reasoning_content, self.model_name, tool_calls
        else:
            content, reasoning_content = response
            return content, reasoning_content, self.model_name

    async def generate_response_for_image(
        self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs
    ) -> Tuple:
        """根据输入的提示和图片生成模型的异步响应,支持覆盖参数"""
        endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
        response = await self._execute_request(
            endpoint=endpoint,
            prompt=prompt,
            image_base64=image_base64,
            image_format=image_format,
            user_id=user_id,
            request_type="vision",
            **kwargs,
        )
        # _default_response_handler 现在总是返回至少2个值
        if len(response) == 3:
            return response  # content, reasoning, tool_calls (tool_calls 可能为 None)
        elif len(response) == 2:
            content, reasoning = response
            return content, reasoning  # 对于 vision 请求,通常没有 tool_calls
        else:
            logger.error(f"来自 _default_response_handler 的意外响应格式: {response}")
            return "处理响应出错", ""

    async def generate_response_async(
        self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs
    ) -> Union[str, Tuple]:
        """异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
        # (代码不变)
        endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
        response = await self._execute_request(
            endpoint=endpoint,
            prompt=prompt,
            payload=None,
            retry_policy=None,
            response_handler=None,
            user_id=user_id,
            request_type=request_type,
            **kwargs,
        )
        return response

    # 修改:实现 Gemini Function Calling 的 Payload 构建
    async def generate_response_tool_async(
        self, prompt: str, tools: list, user_id: str = "system", **kwargs
    ) -> tuple[str, str, list | None]:
        """异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数和 Gemini 函数调用"""
        endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
        merged_params = {**self.params, **kwargs}
        transformed_params = await self._transform_parameters(merged_params)  # 清理 request_type 等
        payload = None

        if self.is_gemini:
            # --- 构建 Gemini Function Calling Payload ---
            logger.debug(f"为 Gemini ({self.model_name}) 构建函数调用请求。")
            # 1. 
转换工具定义 (OpenAI -> Gemini) # OpenAI tool format: [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}] # Gemini tool format: [{"functionDeclarations": [{"name": ..., "description": ..., "parameters": ...}]}] function_declarations = [] if tools: for tool in tools: if tool.get("type") == "function" and "function" in tool: func_def = tool["function"] # Gemini parameters 使用 OpenAPI Schema,与 OpenAI 基本兼容 function_declarations.append( { "name": func_def.get("name"), "description": func_def.get("description", ""), # Description is required for Gemini "parameters": func_def.get( "parameters", {"type": "object", "properties": {}} ), # Ensure parameters exist } ) else: logger.warning(f"跳过不支持的工具类型或格式: {tool}") if not function_declarations: logger.error("没有有效的函数声明可用于 Gemini 请求。") return "没有提供有效的函数定义", "", None gemini_tools = [{"functionDeclarations": function_declarations}] # 2. 构建 Gemini Payload # parts = [{"text": prompt}] # 初始 parts payload = { "contents": [{"parts": [{"text": prompt}]}], # 包含用户提示 "tools": gemini_tools, # toolConfig 默认是 AUTO,可以根据需要从 kwargs 获取或硬编码 # "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # 例如强制调用 **transformed_params, # 合并其他转换后的参数 (如 generationConfig) } payload.pop("model", None) # Gemini 不在顶层传 model payload.pop("messages", None) # 移除 OpenAI 特有的 messages payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}") else: # --- 构建 OpenAI Tool Calling Payload --- # (逻辑不变) logger.debug(f"为 OpenAI 兼容模型 ({self.model_name}) 构建工具调用请求。") payload = { "model": self.model_name, "messages": [{"role": "user", "content": prompt}], **transformed_params, "tools": tools, "tool_choice": transformed_params.get("tool_choice", "auto"), } if "max_completion_tokens" in payload: payload["max_tokens"] = payload.pop("max_completion_tokens") if "max_tokens" not in payload: payload["max_tokens"] = global_config.model_max_output_length # --- 执行请求 --- if payload is None: logger.error("未能构建有效的 API 请求 payload。") return "内部错误:无法构建请求", "", None response = await self._execute_request( endpoint=endpoint, payload=payload, prompt=prompt, # prompt 仍然需要,用于可能的重试 user_id=user_id, request_type="tool_call", **kwargs, # 传递原始 kwargs 以便在重试时重新合并 ) # _default_response_handler 现在会处理 Gemini functionCall 并统一格式 logger.debug(f"模型 {self.model_name} 工具/函数调用返回结果: {response}") if isinstance(response, tuple) and len(response) == 3: content, reasoning_content, final_tool_calls = response # final_tool_calls 已经是统一的 OpenAI 格式 return content, reasoning_content, final_tool_calls elif isinstance(response, tuple) and len(response) == 2: content, reasoning_content = response logger.debug("收到普通响应,无工具/函数调用") return content, reasoning_content, None else: logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}") return "处理响应时出错", "", None async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]: """异步方法:获取文本的embedding向量,支持覆盖参数 (Gemini Embedding 需注意模型名称)""" # (代码不变) if len(text) < 1: logger.debug("该消息没有长度,不再发送获取embedding向量的请求") return None api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"} if self.is_gemini: endpoint = ":embedContent" payload = {"model": f"models/{self.model_name}", "content": {"parts": [{"text": text}]}, **api_kwargs} payload.pop("encoding_format", None) payload.pop("input", None) else: endpoint = "/embeddings" payload = {"model": self.model_name, "input": text, "encoding_format": "float", **api_kwargs} payload.pop("content", 
None) payload.pop("taskType", None) def embedding_handler(result): # (代码不变) embedding_value = None prompt_tokens = 0 completion_tokens = 0 total_tokens = 0 if self.is_gemini: if "embedding" in result and "value" in result["embedding"]: embedding_value = result["embedding"]["value"] logger.warning(f"模型 {self.model_name} (Gemini Embedding): 响应中未找到明确的 token 使用信息。") else: if "data" in result and len(result["data"]) > 0: embedding_value = result["data"][0].get("embedding", None) usage = result.get("usage", {}) if usage: prompt_tokens = usage.get("prompt_tokens", 0) total_tokens = usage.get("total_tokens", 0) else: logger.warning(f"模型 {self.model_name} (OpenAI Embedding) 的响应中缺少 'usage' 信息。") self._record_usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, user_id=user_id, request_type="embedding", endpoint=endpoint, ) return embedding_value embedding = await self._execute_request( endpoint=endpoint, payload=payload, prompt=text, retry_policy={"max_retries": 2, "base_wait": 6}, response_handler=embedding_handler, user_id=user_id, request_type="embedding", **api_kwargs, ) return embedding def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: """压缩base64格式的图片到指定大小""" # (代码不变) try: image_data = base64.b64decode(base64_data) if len(image_data) <= target_size * 1.05: logger.info(f"图片大小 {len(image_data) / 1024:.1f}KB 已足够小,无需压缩。") return base64_data img = Image.open(io.BytesIO(image_data)) img_format = img.format original_width, original_height = img.size scale = max(0.2, min(1.0, (target_size / len(image_data)) ** 0.5)) new_width = max(1, int(original_width * scale)) new_height = max(1, int(original_height * scale)) output_buffer = io.BytesIO() save_format = img_format # Default to original format if getattr(img, "is_animated", False) and img.n_frames > 1: frames = [] durations = [] loop = img.info.get("loop", 0) disposal = img.info.get("disposal", 2) logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...") for frame_idx in range(img.n_frames): img.seek(frame_idx) current_duration = img.info.get("duration", 100) durations.append(current_duration) new_frame = img.convert("RGBA").copy() resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) frames.append(resized_frame) if frames: frames[0].save( output_buffer, format="GIF", save_all=True, append_images=frames[1:], optimize=False, duration=durations, loop=loop, disposal=disposal, transparency=img.info.get("transparency", None), background=img.info.get("background", None), ) save_format = "GIF" else: logger.warning("未能处理 GIF 帧。") return base64_data else: if img.mode in ("RGBA", "LA") or "transparency" in img.info: resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS) save_format = "PNG" save_params = {"optimize": True} else: resized_img = img.convert("RGB").resize((new_width, new_height), Image.Resampling.LANCZOS) if img_format and img_format.upper() == "JPEG": save_format = "JPEG" save_params = {"quality": 85, "optimize": True} else: save_format = "PNG" save_params = {"optimize": True} resized_img.save(output_buffer, format=save_format, **save_params) compressed_data = output_buffer.getvalue() logger.success( f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})" ) logger.info( f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)" ) if len(compressed_data) < len(image_data) * 0.95: 
return base64.b64encode(compressed_data).decode("utf-8") else: logger.info("压缩效果不明显或反而增大,返回原始图片。") return base64_data except Exception as e: logger.error(f"压缩图片失败: {str(e)}") import traceback logger.error(traceback.format_exc()) return base64_data
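

# --- 最小调用示意(基于假设的环境变量,仅供参考,非项目真实配置) ---
# 假设 .env 提供 DEMO_KEY(API Key,可为 JSON 列表或纯字符串)与 DEMO_BASE_URL 两个变量;
# model dict 中 "key"/"base_url" 填的是环境变量名而非取值,pri_in/pri_out 为每百万 token 单价。
async def _demo() -> None:
    model_config = {
        "name": "deepseek-chat",  # 假设的模型名
        "key": "DEMO_KEY",  # 存放 API Key 的环境变量名(假设)
        "base_url": "DEMO_BASE_URL",  # 存放 Base URL 的环境变量名(假设)
        "stream": False,
        "pri_in": 1.0,
        "pri_out": 2.0,
    }
    llm = LLMRequest(model_config, temperature=0.7, max_tokens=512)
    # 无工具调用时 generate_response 返回 (content, reasoning, model_name) 三元组
    content, reasoning, model_name = await llm.generate_response("你好,介绍一下你自己")
    print(model_name, content)


if __name__ == "__main__":
    asyncio.run(_demo())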