diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 8ee21956..7a834e67 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -1,23 +1,37 @@
import asyncio
import json
+import random  # used to shuffle the API key pool
import re
from datetime import datetime
-from typing import Tuple, Union, Dict, Any
+from typing import Tuple, Union, Dict, Any, Set  # Set backs the runtime abandoned-key cache
import aiohttp
from aiohttp.client import ClientResponse
+# Relative imports; adjust these to your project layout.
+# For example, if utils_model.py lives under src/utils/ and logger under src/common/:
+# from ..common.logger import get_module_logger
+# from ..common.database import db
+# from ..config.config import global_config
+# Assuming they are at the expected paths:
from src.common.logger import get_module_logger
+from ...common.database import db
+from ...config.config import global_config
+
+
import base64
from PIL import Image
import io
import os
-from ...common.database import db
-from ...config.config import global_config
+from dotenv import load_dotenv  # used to load .env files (if direct loading is needed)
+
from rich.traceback import install
install(extra_lines=3)
+# Try to load environment variables from a .env file (if the project layout needs it)
+# load_dotenv()  # if your .env is not in a standard location, pass dotenv_path='path/to/.env'
+
logger = get_module_logger("model_utils")
@@ -47,20 +61,27 @@ class RequestAbortException(Exception):
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
- def __init__(self, message: str):
+    def __init__(self, message: str, key_identifier: str = None):  # key identifier added
super().__init__(message)
self.message = message
+        self.key_identifier = key_identifier  # the key that triggered the 403
def __str__(self):
return self.message
+# New: internal exception used to signal a key switch
+class _SwitchKeyException(Exception):
+    """Internal exception: switch to the next key and skip the standard wait."""
+ pass
+
+
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
- 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
+    401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",  # 401 can also mean the key is invalid
402: "账号余额不足",
- 403: "需要实名,或余额不足",
+    403: "需要实名,或余额不足,或Key无权限",  # 403 meaning extended to cover key permissions
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
@@ -69,29 +90,45 @@ error_code_mapping = {
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
+ """安全地记录请求内容,隐藏敏感信息"""
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
+    # Check whether this is a Gemini payload
+ is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
+
+    # Modify a copy of the payload so the original object is untouched
+ safe_payload = json.loads(json.dumps(payload)) if payload else {}
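+    # Masked result looks like (hypothetical value): "data:image/jpeg;base64,iVBORw0KGg...AAElFTkSuQmCC"
+    # (only the first and last 10 characters of the base64 payload are kept)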
+
if (
image_base64
- and payload
- and isinstance(payload, dict)
- and "messages" in payload
- and len(payload["messages"]) > 0
+ and safe_payload
+ and isinstance(safe_payload, dict)
):
- if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
- content = payload["messages"][0]["content"]
- if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
- payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
- f"{image_base64[:10]}...{image_base64[-10:]}"
- )
- # if isinstance(content, str) and len(content) > 100:
- # payload["messages"][0]["content"] = content[:100]
- return payload
+        # OpenAI-format payloads
+ if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
+ if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
+ content = safe_payload["messages"][0]["content"]
+ if isinstance(content, list) and len(content) > 1 and isinstance(content[1], dict) and "image_url" in content[1]:
+ safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
+ f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
+ f"{image_base64[:10]}...{image_base64[-10:]}"
+ )
+        # Gemini-format payloads (assumes the image lives in "parts")
+ elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0:
+ if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]:
+ parts = safe_payload["contents"][0]["parts"]
+                # Find the image part (usually inlineData)
+ for i, part in enumerate(parts):
+ if isinstance(part, dict) and "inlineData" in part:
+                        # Assume inlineData carries base64 data and a mime type
+ safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = f"{image_base64[:10]}...{image_base64[-10:]}"
+                        break  # only mask the first image found
+
+ return safe_payload
class LLMRequest:
- # 定义需要转换的模型列表,作为类变量避免重复
+    # Models whose OpenAI-specific parameters need transformation; class-level to avoid duplication
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
@@ -108,33 +145,165 @@ class LLMRequest:
"o4-mini-2025-04-16",
]
+    # Class-level cache of keys found dead at runtime (avoids retrying them within one run)
+ _abandoned_keys_runtime: Set[str] = set()
+
def __init__(self, model: dict, **kwargs):
- # 将大写的配置键转换为小写并从config中获取实际值
- try:
- self.api_key = os.environ[model["key"]]
- self.base_url = os.environ[model["base_url"]]
- except AttributeError as e:
- logger.error(f"原始 model dict 信息:{model}")
- logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
- raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
+ """
+ 初始化 LLMRequest 实例。
+
+ Args:
+ model (dict): 包含模型配置的字典,应包含 'name', 'key', 'base_url' 等键。
+ **kwargs: 其他传递给模型 API 的参数 (如 temperature, max_tokens),作为默认参数。
+ """
+ self.model_key_name = model["key"]
self.model_name: str = model["name"]
self.params = kwargs
-
self.stream = model.get("stream", False)
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
+ self.request_type = model.get("request_type", "default")
+
+ try:
+            # --- Load the API key and base URL ---
+ raw_api_key_config = os.environ[self.model_key_name]
+ self.base_url = os.environ[model["base_url"]]
+
+            # --- Detect whether this is a Gemini model ---
+ self.is_gemini = "googleapis.com" in self.base_url.lower()
+ if self.is_gemini:
+ logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})")
+ if self.stream:
+ logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。")
+ self.stream = False
+
+            # --- Parse and filter API keys ---
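+            # Accepted formats for the key env var (hypothetical example values):
+            #   MY_KEY='["sk-aaa", "sk-bbb"]'  -> JSON list, enables key rotation
+            #   MY_KEY=sk-aaa                  -> plain string, single key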
+ parsed_keys = []
+ is_list_config = False
+ try:
+ loaded_keys = json.loads(raw_api_key_config)
+ if isinstance(loaded_keys, list):
+ parsed_keys = [str(key) for key in loaded_keys if key]
+ is_list_config = True
+ elif isinstance(loaded_keys, str) and loaded_keys:
+ parsed_keys = [loaded_keys]
+ else:
+ raise ValueError(f"Parsed API key for {self.model_key_name} is not a valid list or string.")
+ except (json.JSONDecodeError, TypeError):
+ if isinstance(raw_api_key_config, list):
+ parsed_keys = [str(key) for key in raw_api_key_config if key]
+ is_list_config = True
+ elif isinstance(raw_api_key_config, str) and raw_api_key_config:
+ parsed_keys = [raw_api_key_config]
+ else:
+ raise ValueError(f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}")
+
+ if not parsed_keys:
+ raise ValueError(f"No valid API keys found for {self.model_key_name}.")
+
+ abandoned_key_name = f"abandon_{self.model_key_name}"
+ abandoned_keys_set = set()
+ raw_abandoned_keys = os.environ.get(abandoned_key_name)
+
+ if raw_abandoned_keys:
+ try:
+ loaded_abandoned = json.loads(raw_abandoned_keys)
+ if isinstance(loaded_abandoned, list):
+ abandoned_keys_set.update(str(key) for key in loaded_abandoned if key)
+ elif isinstance(loaded_abandoned, str) and loaded_abandoned:
+ abandoned_keys_set.add(loaded_abandoned)
+ logger.info(f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。")
+ except (json.JSONDecodeError, TypeError):
+ if isinstance(raw_abandoned_keys, list):
+ abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key)
+ logger.info(f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。")
+ elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys:
+ abandoned_keys_set.add(raw_abandoned_keys)
+ logger.info(f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。")
+ else:
+ logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}")
+
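+            # Example (hypothetical): if the key env var is MY_KEY, abandoned keys are read
+            # from abandon_MY_KEY='["sk-dead1", "sk-dead2"]' (JSON list or plain string).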
+ all_abandoned_keys = abandoned_keys_set.union(LLMRequest._abandoned_keys_runtime)
+ active_keys = [key for key in parsed_keys if key not in all_abandoned_keys]
+
+ if not active_keys:
+ logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。")
+ raise ValueError(f"No active API keys available for {self.model_key_name} after filtering abandoned keys.")
+
+ if is_list_config and len(active_keys) > 1:
+ self._api_key_config = active_keys
+ logger.info(f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。")
+ elif active_keys:
+ self._api_key_config = active_keys[0]
+ logger.info(f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。")
+ else:
+ raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
+
+
+            # --- Load proxy configuration ---
+ self.proxy_url = None
+ self.proxy_models_set = set()
+ proxy_host = os.environ.get("PROXY_HOST")
+ proxy_port = os.environ.get("PROXY_PORT")
+ proxy_models_str = os.environ.get("PROXY_MODELS", "")
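+            # Example (hypothetical values): PROXY_HOST=127.0.0.1, PROXY_PORT=7890,
+            # PROXY_MODELS="gpt-4o,gemini-1.5-pro" (comma-separated model names)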
+
+ if proxy_host and proxy_port:
+ try:
+ int(proxy_port)
+ self.proxy_url = f"http://{proxy_host}:{proxy_port}"
+ logger.info(f"代理已配置: {self.proxy_url}")
+
+ if proxy_models_str:
+ try:
+ cleaned_str = proxy_models_str.strip('\'"')
+ self.proxy_models_set = {model_name.strip() for model_name in cleaned_str.split(',') if model_name.strip()}
+ logger.info(f"以下模型将使用代理: {self.proxy_models_set}")
+ except Exception as e:
+ logger.error(f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。")
+ self.proxy_models_set = set()
+ except ValueError:
+ logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。")
+ self.proxy_url = None
+ self.proxy_models_set = set()
+ except Exception as e:
+ logger.error(f"加载代理配置时发生错误: {e}")
+ self.proxy_url = None
+ self.proxy_models_set = set()
+ else:
+ logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。")
+
+
+ except KeyError as e:
+ missing_key = str(e).strip("'")
+ if missing_key == self.model_key_name:
+ logger.error(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'")
+ raise ValueError(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") from e
+ elif missing_key == model["base_url"]:
+ logger.error(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'")
+ raise ValueError(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") from e
+ else:
+ logger.error(f"配置错误:找不到环境变量 - {str(e)}")
+ raise ValueError(f"配置错误:找不到环境变量 - {str(e)}") from e
+ except AttributeError as e:
+ logger.error(f"原始 model dict 信息:{model}")
+ logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
+ raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
+ except ValueError as e:
+ logger.error(f"API Key 或配置初始化错误 for {self.model_key_name}: {str(e)}")
+ raise e
- # 获取数据库实例
self._init_database()
- # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
- self.request_type = kwargs.pop("request_type", "default")
-
@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
- # 创建llm_usage集合的索引
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
@@ -151,25 +320,21 @@ class LLMRequest:
request_type: str = None,
endpoint: str = "/chat/completions",
):
- """记录模型使用情况到数据库
- Args:
- prompt_tokens: 输入token数
- completion_tokens: 输出token数
- total_tokens: 总token数
- user_id: 用户ID,默认为system
- request_type: 请求类型(chat/embedding/image/topic/schedule)
- endpoint: API端点
- """
- # 如果 request_type 为 None,则使用实例变量中的值
- if request_type is None:
- request_type = self.request_type
+ """记录模型使用情况到数据库"""
+ # (代码不变)
+ actual_endpoint = endpoint
+ if self.is_gemini:
+ if endpoint == "/embeddings":
+ actual_endpoint = ":embedContent"
+ else:
+ actual_endpoint = ":generateContent"
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
- "request_type": request_type,
- "endpoint": endpoint,
+ "request_type": request_type or self.request_type,
+ "endpoint": actual_endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
@@ -180,7 +345,7 @@ class LLMRequest:
db.llm_usage.insert_one(usage_data)
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
- f"用户: {user_id}, 类型: {request_type}, "
+ f"用户: {user_id}, 类型: {request_type or self.request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
@@ -188,17 +353,8 @@ class LLMRequest:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
- """计算API调用成本
- 使用模型的pri_in和pri_out价格计算输入和输出的成本
-
- Args:
- prompt_tokens: 输入token数量
- completion_tokens: 输出token数量
-
- Returns:
- float: 总成本(元)
- """
- # 使用模型的pri_in和pri_out计算成本
+ """计算API调用成本"""
+ # (代码不变)
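+        # Worked example with hypothetical prices pri_in=2, pri_out=8 (CNY per 1M tokens):
+        # 1500 prompt + 500 completion tokens -> (1500/1e6)*2 + (500/1e6)*8 = 0.007 CNY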
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
@@ -211,19 +367,10 @@ class LLMRequest:
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
+ **kwargs: Any
) -> Dict[str, Any]:
- """配置请求参数
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- request_type: 请求类型
- """
-
- # 合并重试策略
+ """配置请求参数,合并实例参数和调用时参数"""
+ # (代码不变)
default_retry = {
"max_retries": 3,
"base_wait": 10,
@@ -232,25 +379,34 @@ class LLMRequest:
}
policy = {**default_retry, **(retry_policy or {})}
- api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
+ if self.is_gemini:
+ action = endpoint.lstrip('/')
+ api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}"
+ stream_mode = False
+ else:
+ api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
+ stream_mode = self.stream
- stream_mode = self.stream
+ call_params = {k: v for k, v in kwargs.items() if k != 'request_type'}
+ merged_params = {**self.params, **call_params}
- # 构建请求体
- if image_base64:
- payload = await self._build_payload(prompt, image_base64, image_format)
- elif payload is None:
- payload = await self._build_payload(prompt)
+ if payload is None:
+ payload = await self._build_payload(prompt, image_base64, image_format, merged_params)
+ else:
+ logger.debug("使用外部提供的 payload,忽略单次调用参数合并。")
+
+
+ if not self.is_gemini and stream_mode:
+ payload["stream"] = merged_params.get("stream", stream_mode)
- if stream_mode:
- payload["stream"] = stream_mode
return {
"policy": policy,
"payload": payload,
"api_url": api_url,
- "stream_mode": stream_mode,
- "image_base64": image_base64, # 保留必要的exception处理所需的原始数据
+ "stream_mode": payload.get("stream", False),
+ "image_base64": image_base64,
"image_format": image_format,
"prompt": prompt,
}
@@ -266,87 +422,255 @@ class LLMRequest:
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
+ **kwargs: Any
):
- """统一请求执行入口
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- response_handler: 自定义响应处理器
- user_id: 用户ID
- request_type: 请求类型
- """
- # 获取请求配置
+ """统一请求执行入口, 支持列表 key 切换、代理和单次调用参数覆盖"""
+ # (代码不变)
+ final_request_type = request_type or kwargs.get('request_type') or self.request_type
+ api_kwargs = {k: v for k, v in kwargs.items() if k != 'request_type'}
+
request_content = await self._prepare_request(
- endpoint, prompt, image_base64, image_format, payload, retry_policy
+ endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs
)
- if request_type is None:
- request_type = self.request_type
- for retry in range(request_content["policy"]["max_retries"]):
- try:
- # 使用上下文管理器处理会话
- headers = await self._build_headers()
- # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
- if request_content["stream_mode"]:
- headers["Accept"] = "text/event-stream"
- async with aiohttp.ClientSession() as session:
- async with session.post(
- request_content["api_url"], headers=headers, json=request_content["payload"]
- ) as response:
- handled_result = await self._handle_response(
- response, request_content, retry, response_handler, user_id, request_type, endpoint
- )
- return handled_result
- except Exception as e:
- handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
- retry += count_delta # 降级不计入重试次数
- if handled_payload:
- # 如果降级成功,重新构建请求体
- request_content["payload"] = handled_payload
- continue
-
- logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
- raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
-
- async def _handle_response(
- self,
- response: ClientResponse,
- request_content: Dict[str, Any],
- retry_count: int,
- response_handler: callable,
- user_id,
- request_type,
- endpoint,
- ) -> Union[Dict[str, Any], None]:
policy = request_content["policy"]
+ api_url = request_content["api_url"]
+ actual_payload = request_content["payload"]
stream_mode = request_content["stream_mode"]
- if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
- await self._handle_error_response(response, retry_count, policy)
- return None
- response.raise_for_status()
- result = {}
- if stream_mode:
- # 将流式输出转化为非流式输出
- result = await self._handle_stream_output(response)
+ use_proxy = False
+ current_proxy_url = None
+ if self.proxy_url and self.model_name in self.proxy_models_set:
+ use_proxy = True
+ current_proxy_url = self.proxy_url
+ logger.debug(f"模型 {self.model_name}: 将通过代理 {current_proxy_url} 发送请求。")
+ elif self.proxy_url:
+ logger.debug(f"模型 {self.model_name}: 配置了代理,但此模型不在 PROXY_MODELS 列表中,将不使用代理。")
else:
- result = await response.json()
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, request_type, endpoint)
- )
+ logger.debug(f"模型 {self.model_name}: 未配置或不为此模型使用代理。")
+
+ current_key = None
+ keys_failed_429 = set()
+ keys_abandoned_runtime = set()
+ key_switch_limit_429 = 3
+ key_switch_limit_403 = 3
+
+ available_keys_pool = []
+ is_key_list = isinstance(self._api_key_config, list)
+
+ if is_key_list:
+ available_keys_pool = list(self._api_key_config)
+ if not available_keys_pool:
+ logger.error(f"模型 {self.model_name}: 初始化后无可用活动 Keys。")
+ raise ValueError(f"模型 {self.model_name}: 无可用活动 Keys。")
+ random.shuffle(available_keys_pool)
+ key_switch_limit_429 = min(key_switch_limit_429, len(available_keys_pool))
+ key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool))
+ logger.info(f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。")
+ elif isinstance(self._api_key_config, str):
+ available_keys_pool = [self._api_key_config]
+ key_switch_limit_429 = 1
+ key_switch_limit_403 = 1
+ else:
+ logger.error(f"模型 {self.model_name}: 无效的 API Key 配置类型在执行时遇到: {type(self._api_key_config)}")
+ raise TypeError(f"模型 {self.model_name}: 无效的 API Key 配置类型")
+
+ last_exception = None
+
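+        # Each attempt pops a fresh key from the shuffled pool; once the pool is empty,
+        # the last key is reused for the remaining retries.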
+ for attempt in range(policy["max_retries"]):
+ if available_keys_pool:
+ current_key = available_keys_pool.pop(0)
+ elif current_key:
+ logger.debug(f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)")
+            else:
+                # Guard against both config shapes: a key list or a single key string.
+                all_keys = self._api_key_config if is_key_list else [self._api_key_config]
+                if not all_keys or all(k in LLMRequest._abandoned_keys_runtime for k in all_keys):
+                    final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
+                    logger.critical(final_error_msg)
+                    raise PermissionDeniedException(final_error_msg)
+                else:
+                    raise RuntimeError(f"模型 {self.model_name}: 无法选择 API key (第 {attempt + 1} 次尝试)")
+
+ logger.debug(f"模型 {self.model_name}: 尝试使用 Key: ...{current_key[-4:]} (总第 {attempt + 1} 次尝试)")
+
+ try:
+ headers = await self._build_headers(current_key)
+ if not self.is_gemini and stream_mode:
+ headers["Accept"] = "text/event-stream"
+
+ async with aiohttp.ClientSession() as session:
+ post_kwargs = {
+ "headers": headers,
+ "json": actual_payload,
+ "timeout": 60
+ }
+ if use_proxy:
+ post_kwargs["proxy"] = current_proxy_url
+
+ async with session.post(api_url, **post_kwargs) as response:
+
+ if response.status == 429 and is_key_list:
+ logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。")
+ if current_key not in keys_failed_429:
+ keys_failed_429.add(current_key)
+ logger.info(f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)")
+ if available_keys_pool and len(keys_failed_429) < key_switch_limit_429:
+ logger.info(f" 尝试因 429 切换到下一个可用 Key...")
+ raise _SwitchKeyException()
+ else:
+ logger.warning(f" 无更多 Key 可因 429 切换或已达上限。")
+ else:
+ logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429,按标准重试流程。")
+
+ elif response.status == 403 and is_key_list:
+ logger.error(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。")
+ if current_key not in keys_abandoned_runtime:
+ keys_abandoned_runtime.add(current_key)
+ LLMRequest._abandoned_keys_runtime.add(current_key)
+ logger.critical(f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!")
+ if current_key in available_keys_pool: available_keys_pool.remove(current_key)
+ if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403:
+ logger.info(f" 尝试因 403 切换到下一个可用 Key...")
+ raise _SwitchKeyException()
+ else:
+ logger.error(f" 无更多 Key 可因 403 切换或已达上限。将中止请求。")
+ await response.read()
+ raise PermissionDeniedException(f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。", key_identifier=current_key)
+ else:
+ logger.error(f" Key ...{current_key[-4:]} 再次遇到 403,这不应发生。中止请求。")
+ await response.read()
+ raise PermissionDeniedException(f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key)
+
+ elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
+ await self._handle_error_response(response, attempt, policy, current_key)
+
+ if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1:
+ if response.status not in [429, 403]:
+ wait_time = policy["base_wait"] * (2**attempt)
+ logger.warning(f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试...")
+ await asyncio.sleep(wait_time)
+ last_exception = RuntimeError(f"重试错误 {response.status}")
+ continue
+
+ if response.status in policy["abort_codes"] or (response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1):
+ if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]:
+ logger.error(f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。")
+ await self._handle_error_response(response, attempt, policy, current_key)
+ await response.read()
+ final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}"
+ logger.error(final_error_msg)
+ raise RequestAbortException(final_error_msg, response)
+
+ response.raise_for_status()
+ result = {}
+ if not self.is_gemini and stream_mode:
+ result = await self._handle_stream_output(response)
+ else:
+ result = await response.json()
+
+ return (
+ response_handler(result)
+ if response_handler
+ else self._default_response_handler(result, user_id, final_request_type, endpoint)
+ )
+
+ except _SwitchKeyException:
+ last_exception = _SwitchKeyException()
+ logger.debug("捕获到 _SwitchKeyException,立即进行下一次尝试。")
+ continue
+ except PermissionDeniedException as e:
+ logger.error(f"模型 {self.model_name}: 因权限拒绝 (403) 中止请求: {e}")
+ if is_key_list and not available_keys_pool and e.key_identifier:
+ logger.critical(f" 中止原因是 Key ...{e.key_identifier[-4:]} 触发 403 后已无其他 Key 可用。")
+ raise e
+ except aiohttp.ClientProxyConnectionError as e:
+ logger.error(f"代理连接错误: {e} (代理地址: {current_proxy_url})")
+ last_exception = e
+ if attempt >= policy["max_retries"] - 1:
+ raise RuntimeError(f"代理连接失败达到最大重试次数: {e}") from e
+ wait_time = policy["base_wait"] * (2**attempt)
+ logger.warning(f"模型 {self.model_name}: 代理连接错误,等待 {wait_time} 秒后重试...")
+ await asyncio.sleep(wait_time)
+ continue
+ except aiohttp.ClientConnectorError as e:
+ logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})")
+ last_exception = e
+ if attempt >= policy["max_retries"] - 1:
+ raise RuntimeError(f"网络连接失败达到最大重试次数: {e}") from e
+ wait_time = policy["base_wait"] * (2**attempt)
+ logger.warning(f"模型 {self.model_name}: 网络连接错误,等待 {wait_time} 秒后重试...")
+ await asyncio.sleep(wait_time)
+ continue
+            except RequestAbortException as e:
+                logger.error(f"模型 {self.model_name}: 请求处理中遇到关键错误,将中止: {e}")
+                raise e
+            # PayLoadTooLargeError deliberately falls through to the generic handler below,
+            # where _handle_exception compresses the image and retries.
+ except Exception as e:
+ last_exception = e
+ logger.warning(f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}")
+
+ if attempt >= policy["max_retries"] - 1:
+ logger.error(f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。")
+ else:
+ try:
+ temp_request_content = {
+ "policy": policy,
+ "payload": actual_payload,
+ "api_url": api_url,
+ "stream_mode": stream_mode,
+ "image_base64": image_base64,
+ "image_format": image_format,
+ "prompt": prompt,
+ }
+ handled_payload, count_delta = await self._handle_exception(
+                            e, attempt, temp_request_content, merged_params={**self.params, **api_kwargs}
+ )
+ if handled_payload:
+ actual_payload = handled_payload
+ logger.info(f"模型 {self.model_name}: 异常处理更新了 payload,将使用当前 Key 重试。")
+
+ wait_time = policy["base_wait"] * (2**attempt)
+ logger.warning(f"模型 {self.model_name}: 等待 {wait_time} 秒后重试...")
+ await asyncio.sleep(wait_time)
+ continue
+
+ except (RequestAbortException, PermissionDeniedException) as abort_exception:
+ logger.error(f"模型 {self.model_name}: 异常处理判断需要中止请求: {abort_exception}")
+ raise abort_exception
+ except RuntimeError as rt_error:
+ logger.error(f"模型 {self.model_name}: 异常处理遇到运行时错误: {rt_error}")
+ raise rt_error
+
+            # --- End of retry loop ---
+ logger.error(f"模型 {self.model_name}: 所有重试尝试 ({policy['max_retries']} 次) 均失败。")
+ if last_exception:
+ if isinstance(last_exception, PermissionDeniedException):
+ logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}")
+ raise last_exception
+ logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}")
+ raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}") from last_exception
+ else:
+ if not available_keys_pool and keys_abandoned_runtime:
+ final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
+ logger.critical(final_error_msg)
+ raise PermissionDeniedException(final_error_msg)
+ else:
+ raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败,原因未知。")
+
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
+ """处理 OpenAI 兼容的流式输出"""
+ # (代码不变)
flag_delta_content_finished = False
accumulated_content = ""
- usage = None # 初始化usage变量,避免未定义错误
+ usage = None
reasoning_content = ""
content = ""
- tool_calls = None # 初始化工具调用变量
+ tool_calls = None
async for line_bytes in response.content:
try:
@@ -362,7 +686,7 @@ class LLMRequest:
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
- usage = chunk_usage # 获取token用量
+ usage = chunk_usage
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
@@ -370,15 +694,27 @@ class LLMRequest:
delta_content = ""
accumulated_content += delta_content
- # 提取工具调用信息
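+                        # Streaming tool calls arrive as fragments: the first chunk carries the
+                        # call skeleton, later chunks append to function.arguments piece by piece
+                        # (e.g. '{"ci' then 'ty": "Paris"}'), so fragments are merged, not appended.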
if "tool_calls" in delta:
if tool_calls is None:
- tool_calls = delta["tool_calls"]
+ tool_calls = []
+ for tc in delta["tool_calls"]:
+ new_tc = dict(tc)
+ if 'function' in new_tc and 'arguments' not in new_tc['function']:
+ new_tc['function']['arguments'] = ""
+ tool_calls.append(new_tc)
else:
- # 合并工具调用信息
- tool_calls.extend(delta["tool_calls"])
+ for i, tc_delta in enumerate(delta["tool_calls"]):
+ if i < len(tool_calls) and 'function' in tc_delta and 'arguments' in tc_delta['function']:
+ if 'arguments' in tool_calls[i]['function']:
+ tool_calls[i]['function']['arguments'] += tc_delta['function']['arguments']
+ else:
+ tool_calls[i]['function']['arguments'] = tc_delta['function']['arguments']
+ elif i >= len(tool_calls):
+ new_tc = dict(tc_delta)
+ if 'function' in new_tc and 'arguments' not in new_tc['function']:
+ new_tc['function']['arguments'] = ""
+ tool_calls.append(new_tc)
- # 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if delta.get("reasoning_content", None):
reasoning_content += delta["reasoning_content"]
@@ -387,37 +723,37 @@ class LLMRequest:
if chunk_usage:
usage = chunk_usage
break
- # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
flag_delta_content_finished = True
+                except json.JSONDecodeError as e:
+                    logger.error(f"模型 {self.model_name} 解析流式 JSON 错误: {e} - data: '{data_str}'")
+                except UnicodeDecodeError as e:
+                    logger.warning(f"模型 {self.model_name} 流式输出解码错误: {e} - bytes: {line_bytes[:50]}...")
                 except Exception as e:
-                    logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
+                    logger.exception(f"模型 {self.model_name} 解析流式输出块错误: {str(e)}")
except Exception as e:
if isinstance(e, GeneratorExit):
log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
else:
log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
logger.warning(log_content)
- # 确保资源被正确清理
try:
await response.release()
except Exception as cleanup_error:
logger.error(f"清理资源时发生错误: {cleanup_error}")
- # 返回已经累积的内容
content = accumulated_content
- if not content:
+ break
+ if not content and accumulated_content:
content = accumulated_content
         think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
         content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
- # 构建消息对象
message = {
"content": content,
"reasoning_content": reasoning_content,
}
- # 如果有工具调用,添加到消息中
if tool_calls:
message["tool_calls"] = tool_calls
@@ -428,285 +764,358 @@ class LLMRequest:
return result
async def _handle_error_response(
- self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
- ) -> Union[Dict[str, any]]:
- if response.status in policy["retry_codes"]:
- wait_time = policy["base_wait"] * (2**retry_count)
- logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
- if response.status == 413:
- logger.warning("请求体过大,尝试压缩...")
+ self, response: ClientResponse, retry_count: int, policy: Dict[str, Any], current_key: str = None
+ ) -> None:
+ """处理 HTTP 错误响应 (区分 403 和其他错误)"""
+ # (代码不变)
+ status = response.status
+ try:
+ error_text = await response.text()
+ except Exception as e:
+ error_text = f"(无法读取响应体: {e})"
+
+ if status == 403:
+ logger.error(
+ f"模型 {self.model_name}: 遇到 403 (权限拒绝) 错误。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
+ f"响应: {error_text[:200]}"
+ )
+ raise PermissionDeniedException(f"模型禁止访问 ({status})", key_identifier=current_key)
+
+ elif status in policy["retry_codes"] and status != 429:
+ if status == 413:
+ logger.warning(f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩...")
raise PayLoadTooLargeError("请求体过大")
- elif response.status in [500, 503]:
+ elif status in [500, 503]:
logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+ f"模型 {self.model_name}: 服务器内部错误或过载 ({status})。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
+ f"响应: {error_text[:200]}"
)
- raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
+ return
else:
- logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
- raise RuntimeError("请求限制(429)")
- elif response.status in policy["abort_codes"]:
- if response.status != 403:
- raise RequestAbortException("请求出现错误,中断处理", response)
- else:
- raise PermissionDeniedException("模型禁止访问")
+ logger.warning(f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}")
+ return
+
+ elif status in policy["abort_codes"]:
+ logger.error(
+ f"模型 {self.model_name}: 遇到需要中止的错误码: {status} - {error_code_mapping.get(status, '未知错误')}. "
+ f"Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
+ )
+ raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
+ else:
+ logger.error(f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}")
+ try:
+ response.raise_for_status()
+ raise RequestAbortException(f"未处理的错误状态码 {status}", response)
+ except aiohttp.ClientResponseError as e:
+ raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e
+
async def _handle_exception(
- self, exception, retry_count: int, request_content: Dict[str, Any]
+ self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
+ """处理非 HTTP 错误,支持使用合并后的参数重建 payload"""
+ # (代码不变)
policy = request_content["policy"]
payload = request_content["payload"]
wait_time = policy["base_wait"] * (2**retry_count)
keep_request = False
if retry_count < policy["max_retries"] - 1:
keep_request = True
- if isinstance(exception, RequestAbortException):
- response = exception.response
- logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
- )
- # 尝试获取并记录服务器返回的详细错误信息
- try:
- error_json = await response.json()
- if error_json and isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj: dict = error_item["error"]
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(
- f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- # 处理单个错误对象的情况
- error_obj = error_json.get("error", {})
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
+
+        # Fall back to the instance defaults; the old payload is not a parameter dict
+        params_for_rebuild = merged_params if merged_params is not None else self.params
+
+ if isinstance(exception, PayLoadTooLargeError):
+ if keep_request:
+ logger.warning("请求体过大 (PayLoadTooLargeError),尝试压缩图片...")
+ image_base64 = request_content.get("image_base64")
+ if image_base64:
+ compressed_image_base64 = compress_base64_image_by_scale(image_base64)
+ if compressed_image_base64 != image_base64:
+ new_payload = await self._build_payload(
+ request_content["prompt"], compressed_image_base64, request_content["image_format"], params_for_rebuild
+ )
+ logger.info("图片压缩成功,将使用压缩后的图片重试。")
+ return new_payload, 0
+ else:
+ logger.warning("图片压缩未改变大小或失败。")
else:
- # 记录原始错误响应内容
- logger.error(f"服务器错误响应: {error_json}")
- except Exception as e:
- logger.warning(f"无法解析服务器错误响应: {str(e)}")
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
-
- elif isinstance(exception, PermissionDeniedException):
- # 只针对硅基流动的V3和R1进行降级处理
- if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
- old_model_name = self.model_name
- self.model_name = self.model_name[4:] # 移除"Pro/"前缀
- logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
-
- # 对全局配置进行更新
- if global_config.llm_normal.get("name") == old_model_name:
- global_config.llm_normal["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
- if global_config.llm_reasoning.get("name") == old_model_name:
- global_config.llm_reasoning["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
-
- if payload and "model" in payload:
- payload["model"] = self.model_name
-
- await asyncio.sleep(wait_time)
- return payload, -1
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
-
- elif isinstance(exception, PayLoadTooLargeError):
- if keep_request:
- image_base64 = request_content["image_base64"]
- compressed_image_base64 = compress_base64_image_by_scale(image_base64)
- new_payload = await self._build_payload(
- request_content["prompt"], compressed_image_base64, request_content["image_format"]
- )
- return new_payload, 0
- else:
+ logger.warning("请求体过大但请求中不包含图片,无法压缩。")
return None, 0
+ else:
+ logger.error("达到最大重试次数,请求体仍然过大。")
+ raise RuntimeError("请求体过大,压缩或重试后仍然失败。") from exception
- elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
+ elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)):
if keep_request:
- logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
+ logger.error(f"模型 {self.model_name} 网络错误: {str(exception)}")
return None, 0
else:
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
- raise RuntimeError(f"网络请求失败: {str(exception)}")
+ raise RuntimeError(f"网络请求失败: {str(exception)}") from exception
elif isinstance(exception, aiohttp.ClientResponseError):
- # 处理aiohttp抛出的,除了policy中的status的响应错误
if keep_request:
logger.error(
- f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
+ f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}"
)
try:
- error_text = await exception.response.text()
- error_json = json.loads(error_text)
- if isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- error_obj = error_json.get("error", {})
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- else:
- logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
- except (json.JSONDecodeError, TypeError) as json_err:
- logger.warning(
- f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
- )
+ error_text = await exception.response.text() if hasattr(exception, 'response') else str(exception)
+ logger.error(f"服务器错误响应详情: {error_text[:500]}")
except Exception as parse_err:
- logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
-
- await asyncio.sleep(wait_time)
+ logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}")
return None, 0
else:
logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
)
- # 安全地检查和记录请求详情
+ current_key_placeholder = request_content.get("current_key", "******")
handled_payload = await _safely_record(request_content, payload)
- logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
+ logger.critical(f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
- )
+ ) from exception
else:
if keep_request:
- logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
+ logger.error(f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}")
return None, 0
else:
- logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
- # 安全地检查和记录请求详情
+ logger.critical(f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}")
+ current_key_placeholder = request_content.get("current_key", "******")
handled_payload = await _safely_record(request_content, payload)
- logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
- raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
+ logger.critical(f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}")
+ raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception
- async def _transform_parameters(self, params: dict) -> dict:
- """
- 根据模型名称转换参数:
- - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
- 并将 'max_tokens' 重命名为 'max_completion_tokens'
- """
- # 复制一份参数,避免直接修改原始数据
- new_params = dict(params)
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
- # 删除 'temperature' 参数(如果存在)
+ async def _transform_parameters(self, merged_params: dict) -> dict:
+ """根据模型名称转换合并后的参数,并移除内部参数"""
+ new_params = dict(merged_params)
+
+        # --- Strip internal-use parameters ---
+        new_params.pop("request_type", None)  # drop request_type
+        # Remove any other internal parameters here as well
+        # new_params.pop("internal_param_name", None)
+
+        # --- Model-specific parameter transformation ---
+ if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+            # OpenAI-specific transformation (example)
new_params.pop("temperature", None)
- # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
+ elif self.is_gemini:
+            # Gemini parameter mapping
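+            # Example: {"temperature": 0.7, "max_tokens": 512, "top_p": 0.9} becomes
+            # {"generationConfig": {"temperature": 0.7, "maxOutputTokens": 512, "topP": 0.9}}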
+ gen_config = new_params.get("generationConfig", {})
+ if "temperature" in new_params:
+ gen_config["temperature"] = new_params.pop("temperature")
+ if "max_tokens" in new_params:
+ gen_config["maxOutputTokens"] = new_params.pop("max_tokens")
+ if "top_p" in new_params:
+ gen_config["topP"] = new_params.pop("top_p")
+ if "top_k" in new_params:
+ gen_config["topK"] = new_params.pop("top_k")
+            # ... other Gemini-specific parameters ...
+
+ if gen_config:
+ new_params["generationConfig"] = gen_config
+
+            # Strip OpenAI-specific top-level parameters
+ new_params.pop("frequency_penalty", None)
+ new_params.pop("presence_penalty", None)
+ new_params.pop("max_completion_tokens", None)
+
return new_params
- async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
- """构建请求体"""
- # 复制一份参数,避免直接修改 self.params
- params_copy = await self._transform_parameters(self.params)
- if image_base64:
- messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": prompt},
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
- },
- ],
- }
- ]
+ async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None) -> dict:
+ """构建请求体 (区分 Gemini 和 OpenAI),使用合并和转换后的参数"""
+ # (代码不变)
+ if merged_params is None:
+ merged_params = self.params
+
+ params_copy = await self._transform_parameters(merged_params)
+
+ if self.is_gemini:
+ parts = []
+ if prompt:
+ parts.append({"text": prompt})
+ if image_base64:
+ mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}"
+ parts.append({
+ "inlineData": {
+ "mimeType": mime_type,
+ "data": image_base64
+ }
+ })
+ payload = {
+ "contents": [{"parts": parts}],
+ **params_copy
+ }
+ payload.pop("model", None)
+
else:
- messages = [{"role": "user", "content": prompt}]
- payload = {
- "model": self.model_name,
- "messages": messages,
- **params_copy,
- }
- if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- payload["max_tokens"] = global_config.model_max_output_length
- # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
- payload["max_completion_tokens"] = payload.pop("max_tokens")
+ if image_base64:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"},
+ },
+ ],
+ }
+ ]
+ else:
+ messages = [{"role": "user", "content": prompt}]
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ **params_copy,
+ }
+ if "max_tokens" not in payload and "max_completion_tokens" not in payload:
+ if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
+ payload["max_tokens"] = global_config.model_max_output_length
+ if "max_completion_tokens" in payload:
+ payload["max_tokens"] = payload.pop("max_completion_tokens")
+
return payload
def _default_response_handler(
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
) -> Tuple:
- """默认响应解析"""
- if "choices" in result and result["choices"]:
- message = result["choices"][0]["message"]
- content = message.get("content", "")
- content, reasoning = self._extract_reasoning(content)
- reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = message.get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = reasoning
+ """默认响应解析 (区分 Gemini 和 OpenAI)"""
+ # (代码不变)
+ content = "没有返回结果"
+ reasoning_content = ""
+ tool_calls = None
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
- # 提取工具调用信息
- tool_calls = message.get("tool_calls", None)
+ if self.is_gemini:
+ try:
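+                # Typical response shape (sketch): {"candidates": [{"content": {"parts":
+                # [{"text": "..."}]}, "finishReason": "STOP"}], "usageMetadata": {...}}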
+ if "candidates" in result and result["candidates"]:
+ candidate = result["candidates"][0]
+ if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]:
+ text_parts = [part.get("text", "") for part in candidate["content"]["parts"] if "text" in part]
+ raw_content = "".join(text_parts).strip()
+ content, reasoning = self._extract_reasoning(raw_content)
+ reasoning_content = reasoning
+ else:
+ content = "Gemini响应中缺少 content 或 parts"
+ logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}")
- # 记录token使用情况
- usage = result.get("usage", {})
- if usage:
- prompt_tokens = usage.get("prompt_tokens", 0)
- completion_tokens = usage.get("completion_tokens", 0)
- total_tokens = usage.get("total_tokens", 0)
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id=user_id,
- request_type=request_type if request_type is not None else self.request_type,
- endpoint=endpoint,
- )
+ finish_reason = candidate.get("finishReason")
+ if finish_reason == "SAFETY":
+ logger.warning(f"模型 {self.model_name}: Gemini 响应因安全设置被阻止。")
+ content = "响应内容因安全原因被过滤。"
+ elif finish_reason == "RECITATION":
+ logger.warning(f"模型 {self.model_name}: Gemini 响应因引用限制被阻止。")
+ content = "响应内容因引用限制被过滤。"
+ elif finish_reason == "OTHER":
+ logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。")
- # 只有当tool_calls存在且不为空时才返回
- if tool_calls:
- logger.debug(f"检测到工具调用: {tool_calls}")
- return content, reasoning_content, tool_calls
+ usage = result.get("usageMetadata", {})
+ if usage:
+ prompt_tokens = usage.get("promptTokenCount", 0)
+ completion_tokens = usage.get("candidatesTokenCount", 0)
+ total_tokens = usage.get("totalTokenCount", 0)
+ if completion_tokens == 0 and total_tokens > 0:
+ completion_tokens = total_tokens - prompt_tokens
+ else:
+ logger.warning(f"模型 {self.model_name} (Gemini) 的响应中缺少 'usageMetadata' 信息。")
+
+ except Exception as e:
+ logger.error(f"解析 Gemini 响应出错: {e} - 响应: {result}")
+ content = "解析 Gemini 响应时出错"
+
+ else:
+ if "choices" in result and result["choices"]:
+ message = result["choices"][0].get("message", {})
+ raw_content = message.get("content", "")
+ content, reasoning = self._extract_reasoning(raw_content if raw_content else "")
+
+ explicit_reasoning = message.get("model_extra", {}).get("reasoning_content", "")
+ if not explicit_reasoning:
+ explicit_reasoning = message.get("reasoning_content", "")
+ reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
+
+ tool_calls = message.get("tool_calls", None)
+
+ usage = result.get("usage", {})
+ if usage:
+ prompt_tokens = usage.get("prompt_tokens", 0)
+ completion_tokens = usage.get("completion_tokens", 0)
+ total_tokens = usage.get("total_tokens", 0)
+ else:
+ logger.warning(f"模型 {self.model_name} (OpenAI) 的响应中缺少 'usage' 信息。")
else:
- return content, reasoning_content
+ logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}")
+
+ if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0:
+ self._record_usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ user_id=user_id,
+ request_type=request_type,
+ endpoint=endpoint,
+ )
+ else:
+ logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
+
+ if tool_calls:
+ logger.debug(f"检测到工具调用: {tool_calls}")
+ return content, reasoning_content, tool_calls
+ else:
+ return content, reasoning_content
- return "没有返回结果", ""
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
-        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
-        content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
+ if not content:
+ return "", ""
+        match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
+        cleaned_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
if match:
reasoning = match.group(1).strip()
else:
reasoning = ""
- return content, reasoning
+ return cleaned_content, reasoning
- async def _build_headers(self, no_key: bool = False) -> dict:
- """构建请求头"""
+ async def _build_headers(self, api_key: str, no_key: bool = False) -> dict:
+ """构建请求头 (区分 Gemini 和 OpenAI)"""
+ # (代码不变)
if no_key:
- return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
+ if self.is_gemini:
+ return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
+ else:
+ return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
else:
- return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
- # 防止小朋友们截图自己的key
+ if not api_key:
+ logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
+ raise ValueError(f"无效的 API key 提供给 _build_headers。")
- async def generate_response(self, prompt: str) -> Tuple:
- """根据输入的提示生成模型的异步响应"""
+ if self.is_gemini:
+ return {"x-goog-api-key": api_key, "Content-Type": "application/json"}
+ else:
+ return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
- response = await self._execute_request(endpoint="/chat/completions", prompt=prompt)
- # 根据返回值的长度决定怎么处理
+
+ async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
+ """根据输入的提示生成模型的异步响应,支持覆盖参数"""
+ # (代码不变)
+ endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
+ response = await self._execute_request(
+ endpoint=endpoint,
+ prompt=prompt,
+ user_id=user_id,
+ request_type="chat",
+ **kwargs
+ )
if len(response) == 3:
content, reasoning_content, tool_calls = response
return content, reasoning_content, self.model_name, tool_calls
@@ -714,176 +1123,231 @@ class LLMRequest:
content, reasoning_content = response
return content, reasoning_content, self.model_name
- async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
- """根据输入的提示和图片生成模型的异步响应"""
-
+ async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs) -> Tuple:
+ """根据输入的提示和图片生成模型的异步响应,支持覆盖参数"""
+ # (代码不变)
+ endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
response = await self._execute_request(
- endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
+ endpoint=endpoint,
+ prompt=prompt,
+ image_base64=image_base64,
+ image_format=image_format,
+ user_id=user_id,
+ request_type="vision",
+ **kwargs
)
- # 根据返回值的长度决定怎么处理
- if len(response) == 3:
- content, reasoning_content, tool_calls = response
- return content, reasoning_content, tool_calls
- else:
- content, reasoning_content = response
- return content, reasoning_content
-
- async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
- """异步方式根据输入的提示生成模型的响应"""
- # 构建请求体,不硬编码max_tokens
- data = {
- "model": self.model_name,
- "messages": [{"role": "user", "content": prompt}],
- **self.params,
- **kwargs,
- }
-
- response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
- # 原样返回响应,不做处理
-
return response
- async def generate_response_tool_async(self, prompt: str, tools: list, **kwargs) -> tuple[str, str, list]:
- """异步方式根据输入的提示生成模型的响应"""
- # 构建请求体,不硬编码max_tokens
+ async def generate_response_async(self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs) -> Union[str, Tuple]:
+ """异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
+ # (代码不变)
+ endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
+ response = await self._execute_request(
+ endpoint=endpoint,
+ prompt=prompt,
+ payload=None,
+ retry_policy=None,
+ response_handler=None,
+ user_id=user_id,
+ request_type=request_type,
+ **kwargs
+ )
+ return response
+
+ async def generate_response_tool_async(self, prompt: str, tools: list, user_id: str = "system", **kwargs) -> tuple[str, str, list | None]:
+ """异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数 (Gemini 暂不完全适配)"""
+ # (代码不变)
+ if self.is_gemini:
+ logger.warning(f"模型 {self.model_name}: Gemini 的函数调用实现与 OpenAI 不同,当前实现可能不兼容。")
+ return "Gemini 函数调用暂未完全适配", "", None
+
+ endpoint = "/chat/completions"
+ merged_params = {**self.params, **kwargs}
+ transformed_params = await self._transform_parameters(merged_params)
+
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
- **self.params,
- **kwargs,
+ **transformed_params,
"tools": tools,
+ "tool_choice": transformed_params.get("tool_choice", "auto"),
}
+ if "max_completion_tokens" in data:
+ data["max_tokens"] = data.pop("max_completion_tokens")
+ if "max_tokens" not in data:
+ data["max_tokens"] = global_config.model_max_output_length
+
+ response = await self._execute_request(
+ endpoint=endpoint,
+ payload=data,
+ prompt=prompt,
+ user_id=user_id,
+ request_type="tool_call",
+ **kwargs
+ )
- response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具,返回结果: {response}")
- # 检查响应是否包含工具调用
- if len(response) == 3:
+
+ if isinstance(response, tuple) and len(response) == 3:
content, reasoning_content, tool_calls = response
- logger.debug(f"收到工具调用响应,包含 {len(tool_calls) if tool_calls else 0} 个工具调用")
+ if tool_calls:
+ logger.debug(f"收到工具调用响应,包含 {len(tool_calls)} 个工具调用")
+ else:
+ logger.debug("收到响应结构但无实际工具调用,视为普通响应")
return content, reasoning_content, tool_calls
- else:
+ elif isinstance(response, tuple) and len(response) == 2:
content, reasoning_content = response
logger.debug("收到普通响应,无工具调用")
return content, reasoning_content, None
+ else:
+ logger.error(f"收到来自 _execute_request 的意外响应格式: {response}")
+ return "处理响应时出错", "", None
- async def get_embedding(self, text: str) -> Union[list, None]:
- """异步方法:获取文本的embedding向量
-
- Args:
- text: 需要获取embedding的文本
-
- Returns:
- list: embedding向量,如果失败则返回None
- """
+ async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]:
+ """异步方法:获取文本的embedding向量,支持覆盖参数 (Gemini Embedding 需注意模型名称)"""
+ # (代码不变)
if len(text) < 1:
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
return None
+        # Strip internal parameters so they are not sent to the API
+ api_kwargs = {k: v for k, v in kwargs.items() if k != 'request_type'}
+
+
+ if self.is_gemini:
+ endpoint = ":embedContent"
+ payload = {
+ "model": f"models/{self.model_name}",
+ "content": {
+ "parts": [{"text": text}]
+ },
+                **api_kwargs  # merge the filtered kwargs
+ }
+ payload.pop("encoding_format", None)
+ payload.pop("input", None)
+
+ else:
+ endpoint = "/embeddings"
+ payload = {
+ "model": self.model_name,
+ "input": text,
+ "encoding_format": "float",
+                **api_kwargs  # merge the filtered kwargs
+ }
+ payload.pop("content", None)
+ payload.pop("taskType", None)
+
+
def embedding_handler(result):
- """处理响应"""
- if "data" in result and len(result["data"]) > 0:
- # 提取 token 使用信息
+ embedding_value = None
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+            if self.is_gemini:
+                # Gemini embedContent responses nest the vector under "values" (plural)
+                if "embedding" in result and "values" in result["embedding"]:
+                    embedding_value = result["embedding"]["values"]
+ logger.warning(f"模型 {self.model_name} (Gemini Embedding): 响应中未找到明确的 token 使用信息。")
+ else:
+ if "data" in result and len(result["data"]) > 0:
+ embedding_value = result["data"][0].get("embedding", None)
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
- completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", 0)
- # 记录 token 使用情况
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id="system", # 可以根据需要修改 user_id
- # request_type="embedding", # 请求类型为 embedding
- request_type=self.request_type, # 请求类型为 text
- endpoint="/embeddings", # API 端点
- )
- return result["data"][0].get("embedding", None)
- return result["data"][0].get("embedding", None)
- return None
+ else:
+ logger.warning(f"模型 {self.model_name} (OpenAI Embedding) 的响应中缺少 'usage' 信息。")
+
+ self._record_usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ user_id=user_id,
+ request_type="embedding",
+ endpoint=endpoint,
+ )
+ return embedding_value
embedding = await self._execute_request(
- endpoint="/embeddings",
+ endpoint=endpoint,
+ payload=payload,
prompt=text,
- payload={"model": self.model_name, "input": text, "encoding_format": "float"},
retry_policy={"max_retries": 2, "base_wait": 6},
response_handler=embedding_handler,
+ user_id=user_id,
+ request_type="embedding",
+            **api_kwargs  # pass the filtered kwargs through
)
return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
- """压缩base64格式的图片到指定大小
- Args:
- base64_data: base64编码的图片数据
- target_size: 目标文件大小(字节),默认0.8MB
- Returns:
- str: 压缩后的base64图片数据
- """
+ """压缩base64格式的图片到指定大小"""
+ # (代码不变)
try:
- # 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
-
- # 如果已经小于目标大小,直接返回原图
- if len(image_data) <= 2 * 1024 * 1024:
+ if len(image_data) <= target_size * 1.05:
+ logger.info(f"图片大小 {len(image_data) / 1024:.1f}KB 已足够小,无需压缩。")
return base64_data
-
- # 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
-
- # 获取原始尺寸
+ img_format = img.format
original_width, original_height = img.size
-
- # 计算缩放比例
- scale = min(1.0, (target_size / len(image_data)) ** 0.5)
-
- # 计算新的尺寸
- new_width = int(original_width * scale)
- new_height = int(original_height * scale)
-
- # 创建内存缓冲区
+ scale = max(0.2, min(1.0, (target_size / len(image_data)) ** 0.5))
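+        # Byte size scales roughly with pixel area, so the linear scale factor is the
+        # square root of the size ratio, clamped to [0.2, 1.0] to avoid degenerate resizes.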
+ new_width = max(1, int(original_width * scale))
+ new_height = max(1, int(original_height * scale))
output_buffer = io.BytesIO()
+ save_format = img_format # Default to original format
- # 如果是GIF,处理所有帧
- if getattr(img, "is_animated", False):
+ if getattr(img, "is_animated", False) and img.n_frames > 1:
frames = []
+ durations = []
+ loop = img.info.get('loop', 0)
+ disposal = img.info.get('disposal', 2)
+ logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...")
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
- new_frame = img.copy()
- new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
- frames.append(new_frame)
-
- # 保存到缓冲区
- frames[0].save(
- output_buffer,
- format="GIF",
- save_all=True,
- append_images=frames[1:],
- optimize=True,
- duration=img.info.get("duration", 100),
- loop=img.info.get("loop", 0),
- )
- else:
- # 处理静态图片
- resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-
- # 保存到缓冲区,保持原始格式
- if img.format == "PNG" and img.mode in ("RGBA", "LA"):
- resized_img.save(output_buffer, format="PNG", optimize=True)
+ current_duration = img.info.get('duration', 100)
+ durations.append(current_duration)
+ new_frame = img.convert("RGBA").copy()
+ resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
+ frames.append(resized_frame)
+ if frames:
+ frames[0].save(
+ output_buffer, format="GIF", save_all=True, append_images=frames[1:],
+ optimize=False, duration=durations, loop=loop, disposal=disposal,
+ transparency=img.info.get('transparency', None), background=img.info.get('background', None)
+ )
+ save_format = "GIF"
else:
- resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
+ logger.warning("未能处理 GIF 帧。")
+ return base64_data
+ else:
+ if img.mode in ("RGBA", "LA") or 'transparency' in img.info:
+ resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS)
+ save_format = "PNG"
+ save_params = {"optimize": True}
+ else:
+ resized_img = img.convert("RGB").resize((new_width, new_height), Image.Resampling.LANCZOS)
+ if img_format and img_format.upper() == "JPEG":
+ save_format = "JPEG"
+ save_params = {"quality": 85, "optimize": True}
+ else:
+ save_format = "PNG"
+ save_params = {"optimize": True}
+ resized_img.save(output_buffer, format=save_format, **save_params)
- # 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
- logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
- logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
-
- return base64.b64encode(compressed_data).decode("utf-8")
-
+ logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})")
+ logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)")
+ if len(compressed_data) < len(image_data) * 0.95:
+ return base64.b64encode(compressed_data).decode("utf-8")
+ else:
+ logger.info("压缩效果不明显或反而增大,返回原始图片。")
+ return base64_data
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
-
logger.error(traceback.format_exc())
- return base64_data
+        return base64_data