MaiBot/src/chat/models/utils_model.py

1517 lines
73 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import json
import random # 添加 random 模块导入
import re
from datetime import datetime
from typing import Tuple, Union, Dict, Any, Set # 引入 Set
import aiohttp
from aiohttp.client import ClientResponse
# 相对路径导入,根据你的项目结构调整
# 例如,如果 utils_model.py 在 src/utils/ 下,而 logger 在 src/common/ 下
# from ..common.logger import get_module_logger
# from ..common.database import db
# from ..config.config import global_config
# 假设它们在期望的路径
from src.common.logger import get_module_logger
from ...common.database import db
from ...config.config import global_config
import base64
from PIL import Image
import io
import os
from rich.traceback import install
install(extra_lines=3)
# 尝试加载 .env 文件中的环境变量 (如果项目结构需要)
# load_dotenv() # 如果你的 .env 文件不在标准位置,可能需要指定路径 load_dotenv(dotenv_path='path/to/.env')
logger = get_module_logger("model_utils")
class PayLoadTooLargeError(Exception):
    """Raised when the request body exceeds the provider's size limit (HTTP 413)."""

    def __init__(self, message: str):
        super().__init__(message)
        # Keep the original message available even though __str__ is fixed text.
        self.message = message

    def __str__(self) -> str:
        # Always present the user-facing hint, regardless of the raw message.
        return "请求体过大,请尝试压缩图片或减少输入内容。"
class RequestAbortException(Exception):
    """Raised when a request must be aborted; carries the offending HTTP response."""

    def __init__(self, message: str, response: ClientResponse):
        super().__init__(message)
        self.message = message
        # Keep the raw aiohttp response so callers can inspect status/body.
        self.response = response

    def __str__(self) -> str:
        return self.message
class PermissionDeniedException(Exception):
    """Raised on HTTP 403; optionally records which API key was rejected."""

    def __init__(self, message: str, key_identifier: str = None):
        super().__init__(message)
        self.message = message
        # The API key that triggered the 403, when known.
        self.key_identifier = key_identifier

    def __str__(self) -> str:
        return self.message
# Internal-only signal exception (not part of the public error surface).
class _SwitchKeyException(Exception):
    """Internal marker: rotate to the next API key immediately, skipping the
    standard backoff wait between attempts."""
# Mapping of common HTTP error codes to user-facing (Chinese) explanations.
error_code_mapping = {
    400: "参数不正确",
    # 401 may also indicate an invalid/revoked key, not only bad credentials.
    401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
    402: "账号余额不足",
    # 403 covers several provider-specific causes (verification, balance, key scope).
    403: "需要实名,或余额不足,或Key无权限",
    404: "Not Found",
    429: "请求过于频繁,请稍后再试",
    500: "服务器内部故障",
    503: "服务器负载过高",
}
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
"""安全地记录请求内容,隐藏敏感信息"""
# (代码不变)
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
safe_payload = json.loads(json.dumps(payload)) if payload else {}
if image_base64 and safe_payload and isinstance(safe_payload, dict):
if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
content = safe_payload["messages"][0]["content"]
if (
isinstance(content, list)
and len(content) > 1
and isinstance(content[1], dict)
and "image_url" in content[1]
):
safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0:
if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]:
parts = safe_payload["contents"][0]["parts"]
for i, part in enumerate(parts):
if isinstance(part, dict) and "inlineData" in part:
safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = (
f"{image_base64[:10]}...{image_base64[-10:]}"
)
break
return safe_payload
class LLMRequest:
    """Async client wrapper for OpenAI-compatible and Gemini chat/embedding APIs.

    Handles API-key pools with automatic 429/403 key rotation, optional HTTP
    proxying, usage accounting, and payload construction.
    """

    # OpenAI "o"-series reasoning models reject `temperature` and use
    # `max_completion_tokens` instead of `max_tokens` (see _transform_parameters).
    MODELS_NEEDING_TRANSFORMATION = [
        "o1",
        "o1-2024-12-17",
        "o1-mini",
        "o1-mini-2024-09-12",
        "o1-preview",
        "o1-preview-2024-09-12",
        "o1-pro",
        "o1-pro-2025-03-19",
        "o3",
        "o3-2025-04-16",
        "o3-mini",
        # Fix: "o3-mini-2025-01-31" and "o4-mini" were fused into one string
        # ("o3-mini-2025-01-31o4-mini") by a missing separator, so neither
        # model name ever matched this list.
        "o3-mini-2025-01-31",
        "o4-mini",
        "o4-mini-2025-04-16",
    ]

    # Keys that returned 403 during this process's lifetime; shared across
    # all LLMRequest instances so other models skip them too.
    _abandoned_keys_runtime: Set[str] = set()
    def __init__(self, model: dict, **kwargs):
        """Initialize an LLMRequest instance.

        Args:
            model: model config dict; must contain "key" (env var name holding
                the API key(s)) and "name"; may contain "base_url" (env var
                name), "stream", "pri_in"/"pri_out" (price per 1M tokens) and
                "request_type".
            **kwargs: default request parameters merged into every call.

        Raises:
            ValueError: when required env vars are missing, the key config is
                invalid, or no active (non-abandoned) key remains.
        """
        self.model_key_name = model["key"]
        self.model_name: str = model["name"]
        self.params = kwargs
        self.stream = model.get("stream", False)
        self.pri_in = model.get("pri_in", 0)
        self.pri_out = model.get("pri_out", 0)
        self.request_type = model.get("request_type", "default")
        try:
            raw_api_key_config = os.environ[self.model_key_name]
            self.base_url = os.environ[model["base_url"]]
            # Heuristic: any googleapis.com base URL is treated as the Gemini API.
            self.is_gemini = "googleapis.com" in self.base_url.lower()
            if self.is_gemini:
                logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})")
                if self.stream:
                    # Gemini streaming differs from OpenAI's SSE format; force-disable it.
                    logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。")
                    self.stream = False
            # --- Parse and filter API keys ---
            # The env var may hold a JSON list, a JSON string, or a raw string.
            parsed_keys = []
            is_list_config = False
            try:
                loaded_keys = json.loads(raw_api_key_config)
                if isinstance(loaded_keys, list):
                    parsed_keys = [str(key) for key in loaded_keys if key]
                    is_list_config = True
                elif isinstance(loaded_keys, str) and loaded_keys:
                    parsed_keys = [loaded_keys]
                else:
                    raise ValueError(f"Parsed API key for {self.model_key_name} is not a valid list or string.")
            except (json.JSONDecodeError, TypeError) as e:
                # Not JSON: fall back to treating the raw value as a list or plain string.
                if isinstance(raw_api_key_config, list):
                    parsed_keys = [str(key) for key in raw_api_key_config if key]
                    is_list_config = True
                elif isinstance(raw_api_key_config, str) and raw_api_key_config:
                    parsed_keys = [raw_api_key_config]
                else:
                    raise ValueError(
                        f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}"
                    ) from e
            if not parsed_keys:
                raise ValueError(f"No valid API keys found for {self.model_key_name}.")
            # Keys listed under "abandon_<key_name>" are permanently excluded.
            abandoned_key_name = f"abandon_{self.model_key_name}"
            abandoned_keys_set = set()
            raw_abandoned_keys = os.environ.get(abandoned_key_name)
            if raw_abandoned_keys:
                try:
                    loaded_abandoned = json.loads(raw_abandoned_keys)
                    if isinstance(loaded_abandoned, list):
                        abandoned_keys_set.update(str(key) for key in loaded_abandoned if key)
                    elif isinstance(loaded_abandoned, str) and loaded_abandoned:
                        abandoned_keys_set.add(loaded_abandoned)
                    logger.info(
                        f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。"
                    )
                except (json.JSONDecodeError, TypeError):
                    if isinstance(raw_abandoned_keys, list):
                        abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key)
                        logger.info(
                            f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。"
                        )
                    elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys:
                        abandoned_keys_set.add(raw_abandoned_keys)
                        logger.info(
                            f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。"
                        )
                    else:
                        logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}")
            # Also exclude keys abandoned at runtime (403s seen in this process).
            all_abandoned_keys = abandoned_keys_set.union(LLMRequest._abandoned_keys_runtime)
            active_keys = [key for key in parsed_keys if key not in all_abandoned_keys]
            if not active_keys:
                logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。")
                raise ValueError(
                    f"No active API keys available for {self.model_key_name} after filtering abandoned keys."
                )
            # A list is kept only when more than one key survives; otherwise a single str.
            if is_list_config and len(active_keys) > 1:
                self._api_key_config = active_keys
                logger.info(
                    f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
                )
            elif active_keys:
                self._api_key_config = active_keys[0]
                logger.info(
                    f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
                )
            else:
                raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
            # --- Load proxy configuration ---
            self.proxy_url = None
            self.proxy_models_set = set()
            proxy_host = os.environ.get("PROXY_HOST")
            proxy_port = os.environ.get("PROXY_PORT")
            proxy_models_str = os.environ.get("PROXY_MODELS", "")
            if proxy_host and proxy_port:
                try:
                    int(proxy_port)  # validate the port is numeric
                    self.proxy_url = f"http://{proxy_host}:{proxy_port}"
                    logger.debug(f"代理已配置: {self.proxy_url}")
                    if proxy_models_str:
                        try:
                            cleaned_str = proxy_models_str.strip("'\"")
                            self.proxy_models_set = {
                                model_name.strip() for model_name in cleaned_str.split(",") if model_name.strip()
                            }
                            logger.debug(f"以下模型将使用代理: {self.proxy_models_set}")
                        except Exception as e:
                            logger.error(
                                f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。"
                            )
                            self.proxy_models_set = set()
                except ValueError:
                    logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。")
                    self.proxy_url = None
                    self.proxy_models_set = set()
                except Exception as e:
                    logger.error(f"加载代理配置时发生错误: {e}")
                    self.proxy_url = None
                    self.proxy_models_set = set()
            else:
                logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。")
        except KeyError as e:
            # Distinguish which env var is missing for a clearer error message.
            missing_key = str(e).strip("'")
            if missing_key == self.model_key_name:
                logger.error(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'")
                raise ValueError(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") from e
            elif missing_key == model["base_url"]:
                logger.error(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'")
                raise ValueError(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") from e
            else:
                logger.error(f"配置错误:找不到环境变量 - {str(e)}")
                raise ValueError(f"配置错误:找不到环境变量 - {str(e)}") from e
        except AttributeError as e:
            logger.error(f"原始 model dict 信息:{model}")
            logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
            raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
        except ValueError as e:
            logger.error(f"API Key 或配置初始化错误 for {self.model_key_name}: {str(e)}")
            raise e
        self._init_database()
@staticmethod
def _init_database():
"""初始化数据库集合"""
# (代码不变)
try:
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}")
def _record_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
user_id: str = "system",
request_type: str = None,
endpoint: str = "/chat/completions",
):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
completion_tokens: 输出token数
total_tokens: 总token数
user_id: 用户ID默认为system
request_type: 请求类型
endpoint: API端点
"""
# 如果 request_type 为 None则使用实例变量中的值
if request_type is None:
request_type = self.request_type
actual_endpoint = endpoint
if self.is_gemini:
if endpoint == "/embeddings":
actual_endpoint = ":embedContent"
else:
actual_endpoint = ":generateContent"
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
"request_type": request_type or self.request_type,
"endpoint": actual_endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
"status": "success",
"timestamp": datetime.now(),
}
db.llm_usage.insert_one(usage_data)
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type or self.request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本"""
# (代码不变)
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
async def _prepare_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""配置请求参数,合并实例参数和调用时参数"""
default_retry = {
"max_retries": global_config.api_polling_max_retries,
"base_wait": 10,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
_actual_endpoint = endpoint
if self.is_gemini:
action = endpoint.lstrip("/")
api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}"
stream_mode = False
else:
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
stream_mode = self.stream
call_params = {k: v for k, v in kwargs.items() if k != "request_type"}
merged_params = {**self.params, **call_params}
if payload is None:
payload = await self._build_payload(prompt, image_base64, image_format, merged_params)
else:
logger.debug("使用外部提供的 payload忽略单次调用参数合并。")
if not self.is_gemini and stream_mode:
payload["stream"] = merged_params.get("stream", stream_mode)
return {
"policy": policy,
"payload": payload,
"api_url": api_url,
"stream_mode": payload.get("stream", False),
"image_base64": image_base64,
"image_format": image_format,
"prompt": prompt,
}
    async def _execute_request(
        self,
        endpoint: str,
        prompt: str = None,
        image_base64: str = None,
        image_format: str = None,
        payload: dict = None,
        retry_policy: dict = None,
        response_handler: callable = None,
        user_id: str = "system",
        request_type: str = None,
        **kwargs: Any,
    ):
        """Unified request entry point: key-pool rotation, proxying, retries.

        One POST per attempt. With a key list, a 429 rotates to the next key
        (once per distinct key, up to a limit) and a 403 abandons the key for
        the process lifetime before rotating. Other retryable codes back off
        exponentially; abort codes raise. A successful response is parsed
        (stream or JSON) and passed to response_handler, or the default
        handler when none is given.

        Raises:
            PermissionDeniedException: keys exhausted via 403.
            RequestAbortException / PayLoadTooLargeError: non-recoverable.
            RuntimeError: retries exhausted for any other reason.
        """
        final_request_type = request_type or kwargs.get("request_type") or self.request_type
        # "request_type" is internal bookkeeping, never an API parameter.
        api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
        request_content = await self._prepare_request(
            endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs
        )
        policy = request_content["policy"]
        api_url = request_content["api_url"]
        actual_payload = request_content["payload"]
        stream_mode = request_content["stream_mode"]
        # Decide whether this model routes through the configured HTTP proxy.
        use_proxy = False
        current_proxy_url = None
        if self.proxy_url and self.model_name in self.proxy_models_set:
            use_proxy = True
            current_proxy_url = self.proxy_url
            logger.debug(f"模型 {self.model_name}: 将通过代理 {current_proxy_url} 发送请求。")
        elif self.proxy_url:
            logger.debug(f"模型 {self.model_name}: 配置了代理,但此模型不在 PROXY_MODELS 列表中,将不使用代理。")
        else:
            logger.debug(f"模型 {self.model_name}: 未配置或不为此模型使用代理。")
        # Build the rotation pool; a plain string means single-key mode.
        current_key = None
        keys_failed_429 = set()
        keys_abandoned_runtime = set()
        key_switch_limit_429 = global_config.api_polling_max_retries
        key_switch_limit_403 = global_config.api_polling_max_retries
        available_keys_pool = []
        is_key_list = isinstance(self._api_key_config, list)
        if is_key_list:
            available_keys_pool = list(self._api_key_config)
            if not available_keys_pool:
                logger.error(f"模型 {self.model_name}: 初始化后无可用活动 Keys。")
                raise ValueError(f"模型 {self.model_name}: 无可用活动 Keys。")
            random.shuffle(available_keys_pool)  # spread load across keys
            # Cannot switch more times than there are keys.
            key_switch_limit_429 = min(key_switch_limit_429, len(available_keys_pool))
            key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool))
            logger.info(
                f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。"
            )
        elif isinstance(self._api_key_config, str):
            available_keys_pool = [self._api_key_config]
            key_switch_limit_429 = 1
            key_switch_limit_403 = 1
        else:
            logger.error(f"模型 {self.model_name}: 无效的 API Key 配置类型在执行时遇到: {type(self._api_key_config)}")
            raise TypeError(f"模型 {self.model_name}: 无效的 API Key 配置类型")
        last_exception = None
        for attempt in range(policy["max_retries"]):
            # Pick the next key from the pool; otherwise keep reusing the last one.
            if available_keys_pool:
                current_key = available_keys_pool.pop(0)
            elif current_key:
                logger.debug(
                    f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)"
                )
            else:
                # No key was ever selected: either everything is abandoned or
                # the configuration is in an unexpected state.
                if (
                    not self._api_key_config
                    or all(
                        k in LLMRequest._abandoned_keys_runtime
                        for k in self._api_key_config
                        if isinstance(self._api_key_config, list)
                    )
                    or (
                        isinstance(self._api_key_config, str)
                        and self._api_key_config in LLMRequest._abandoned_keys_runtime
                    )
                ):
                    final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
                    logger.critical(final_error_msg)
                    raise PermissionDeniedException(final_error_msg)
                else:
                    raise RuntimeError(f"模型 {self.model_name}: 无法选择 API key (第 {attempt + 1} 次尝试)")
            logger.debug(f"模型 {self.model_name}: 尝试使用 Key: ...{current_key[-4:]} (总第 {attempt + 1} 次尝试)")
            try:
                headers = await self._build_headers(current_key)
                if not self.is_gemini and stream_mode:
                    headers["Accept"] = "text/event-stream"
                async with aiohttp.ClientSession() as session:
                    post_kwargs = {"headers": headers, "json": actual_payload, "timeout": 60}
                    if use_proxy:
                        post_kwargs["proxy"] = current_proxy_url
                    async with session.post(api_url, **post_kwargs) as response:
                        if response.status == 429 and is_key_list:
                            # Rate-limited: rotate once per distinct key, up to the limit.
                            logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。")
                            response_text = await response.text()
                            logger.debug(
                                f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}"
                            )
                            if current_key not in keys_failed_429:
                                keys_failed_429.add(current_key)
                                logger.info(
                                    f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)"
                                )
                                if available_keys_pool and len(keys_failed_429) < key_switch_limit_429:
                                    logger.info(" 尝试因 429 切换到下一个可用 Key...")
                                    raise _SwitchKeyException()
                                else:
                                    # Fall through: raise_for_status below turns
                                    # this into the standard retry/backoff path.
                                    logger.warning(" 无更多 Key 可因 429 切换或已达上限。")
                            else:
                                logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429按标准重试流程。")
                        elif response.status == 403 and is_key_list:
                            logger.error(
                                f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。"
                            )
                            if current_key not in keys_abandoned_runtime:
                                keys_abandoned_runtime.add(current_key)
                                # Abandon process-wide so other instances skip this key too.
                                LLMRequest._abandoned_keys_runtime.add(current_key)
                                logger.critical(
                                    f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!"
                                )
                                if current_key in available_keys_pool:
                                    available_keys_pool.remove(current_key)
                                if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403:
                                    logger.info(" 尝试因 403 切换到下一个可用 Key...")
                                    raise _SwitchKeyException()
                                else:
                                    logger.error(" 无更多 Key 可因 403 切换或已达上限。将中止请求。")
                                    await response.read()
                                    raise PermissionDeniedException(
                                        f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。",
                                        key_identifier=current_key,
                                    )
                            else:
                                logger.error(f" Key ...{current_key[-4:]} 再次遇到 403这不应发生。中止请求。")
                                await response.read()
                                raise PermissionDeniedException(
                                    f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key
                                )
                        elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
                            # Delegate classification; may raise (413/403/abort codes) or return.
                            await self._handle_error_response(response, attempt, policy, current_key)
                            if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1:
                                # 429/403 are handled above without the standard wait.
                                if response.status not in [429, 403]:
                                    wait_time = policy["base_wait"] * (2**attempt)
                                    logger.warning(
                                        f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试..."
                                    )
                                    await asyncio.sleep(wait_time)
                                last_exception = RuntimeError(f"重试错误 {response.status}")
                                continue
                            if response.status in policy["abort_codes"] or (
                                response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1
                            ):
                                if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]:
                                    logger.error(
                                        f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}"
                                    )
                                # await self._handle_error_response(response, attempt, policy, current_key)
                                # await response.read()
                                # final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}"
                                # logger.error(final_error_msg)
                                # raise RequestAbortException(final_error_msg, response)
                        # Any remaining >=400 status raises ClientResponseError here,
                        # which is caught by the generic handler below and retried.
                        response.raise_for_status()
                        result = {}
                        if not self.is_gemini and stream_mode:
                            result = await self._handle_stream_output(response)
                        else:
                            result = await response.json()
                        return (
                            response_handler(result)
                            if response_handler
                            else self._default_response_handler(result, user_id, final_request_type, endpoint)
                        )
            except _SwitchKeyException:
                # Key rotation requested: skip the backoff and try the next key now.
                last_exception = _SwitchKeyException()
                logger.debug("捕获到 _SwitchKeyException立即进行下一次尝试。")
                continue
            except PermissionDeniedException as e:
                logger.error(f"模型 {self.model_name}: 因权限拒绝 (403) 中止请求: {e}")
                if is_key_list and not available_keys_pool and e.key_identifier:
                    logger.critical(f" 中止原因是 Key ...{e.key_identifier[-4:]} 触发 403 后已无其他 Key 可用。")
                raise e
            except aiohttp.ClientProxyConnectionError as e:
                logger.error(f"代理连接错误: {e} (代理地址: {current_proxy_url})")
                last_exception = e
                if attempt >= policy["max_retries"] - 1:
                    raise RuntimeError(f"代理连接失败达到最大重试次数: {e}") from e
                wait_time = policy["base_wait"] * (2**attempt)
                logger.warning(f"模型 {self.model_name}: 代理连接错误,等待 {wait_time} 秒后重试...")
                await asyncio.sleep(wait_time)
                continue
            except aiohttp.ClientConnectorError as e:
                logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})")
                last_exception = e
                if attempt >= policy["max_retries"] - 1:
                    raise RuntimeError(f"网络连接失败达到最大重试次数: {e}") from e
                wait_time = policy["base_wait"] * (2**attempt)
                logger.warning(f"模型 {self.model_name}: 网络连接错误,等待 {wait_time} 秒后重试...")
                await asyncio.sleep(wait_time)
                continue
            except (PayLoadTooLargeError, RequestAbortException) as e:
                logger.error(f"模型 {self.model_name}: 请求处理中遇到关键错误,将中止: {e}")
                raise e
            except Exception as e:
                last_exception = e
                logger.warning(
                    f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}"
                )
                if attempt >= policy["max_retries"] - 1:
                    logger.error(
                        f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。"
                    )
                else:
                    try:
                        # Re-pack the request context so _handle_exception can
                        # rebuild the payload (e.g. compress the image on 413).
                        temp_request_content = {
                            "policy": policy,
                            "payload": actual_payload,
                            "api_url": api_url,
                            "stream_mode": stream_mode,
                            "image_base64": image_base64,
                            "image_format": image_format,
                            "prompt": prompt,
                        }
                        handled_payload, count_delta = await self._handle_exception(
                            e, attempt, temp_request_content, merged_params=api_kwargs
                        )
                        if handled_payload:
                            actual_payload = handled_payload
                            logger.info(f"模型 {self.model_name}: 异常处理更新了 payload将使用当前 Key 重试。")
                        wait_time = policy["base_wait"] * (2**attempt)
                        logger.warning(f"模型 {self.model_name}: 等待 {wait_time} 秒后重试...")
                        await asyncio.sleep(wait_time)
                        continue
                    except (RequestAbortException, PermissionDeniedException) as abort_exception:
                        logger.error(f"模型 {self.model_name}: 异常处理判断需要中止请求: {abort_exception}")
                        raise abort_exception
                    except RuntimeError as rt_error:
                        logger.error(f"模型 {self.model_name}: 异常处理遇到运行时错误: {rt_error}")
                        raise rt_error
        # --- retry loop exhausted ---
        logger.error(f"模型 {self.model_name}: 所有重试尝试 ({policy['max_retries']} 次) 均失败。")
        if last_exception:
            if isinstance(last_exception, PermissionDeniedException):
                logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}")
                raise last_exception
            logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}")
            raise RuntimeError(
                f"模型 {self.model_name} 达到最大重试次数API 请求失败。最后错误: {str(last_exception)}"
            ) from last_exception
        else:
            if not available_keys_pool and keys_abandoned_runtime:
                final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
                logger.critical(final_error_msg)
                raise PermissionDeniedException(final_error_msg)
            else:
                raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API 请求失败,原因未知。")
    async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
        """Consume an OpenAI-compatible SSE stream and return a non-stream-shaped result.

        Accumulates delta content, reasoning content, and incremental tool_calls
        chunks, then returns a dict shaped like a regular chat-completion
        response: {"choices": [{"message": ...}], "usage": ...}.
        """
        # Set once finish_reason arrived without usage: keep reading only for
        # the trailing usage-bearing chunk.
        flag_delta_content_finished = False
        accumulated_content = ""
        usage = None
        reasoning_content = ""
        content = ""
        tool_calls = None
        async for line_bytes in response.content:
            try:
                line = line_bytes.decode("utf-8").strip()
                if not line:
                    continue
                if line.startswith("data:"):
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        if flag_delta_content_finished:
                            # Content already complete; only harvest usage.
                            chunk_usage = chunk.get("usage", None)
                            if chunk_usage:
                                usage = chunk_usage
                        else:
                            delta = chunk["choices"][0]["delta"]
                            delta_content = delta.get("content")
                            if delta_content is None:
                                delta_content = ""
                            accumulated_content += delta_content
                            if "tool_calls" in delta:
                                if tool_calls is None:
                                    # First tool_calls chunk: seed the list,
                                    # ensuring every entry has an "arguments" slot.
                                    tool_calls = []
                                    for tc in delta["tool_calls"]:
                                        new_tc = dict(tc)
                                        if "function" in new_tc and "arguments" not in new_tc["function"]:
                                            new_tc["function"]["arguments"] = ""
                                        tool_calls.append(new_tc)
                                else:
                                    # Subsequent chunks: append argument fragments
                                    # positionally, or add brand-new calls.
                                    for i, tc_delta in enumerate(delta["tool_calls"]):
                                        if (
                                            i < len(tool_calls)
                                            and "function" in tc_delta
                                            and "arguments" in tc_delta["function"]
                                        ):
                                            if "arguments" in tool_calls[i]["function"]:
                                                tool_calls[i]["function"]["arguments"] += tc_delta["function"][
                                                    "arguments"
                                                ]
                                            else:
                                                tool_calls[i]["function"]["arguments"] = tc_delta["function"][
                                                    "arguments"
                                                ]
                                        elif i >= len(tool_calls):
                                            new_tc = dict(tc_delta)
                                            if "function" in new_tc and "arguments" not in new_tc["function"]:
                                                new_tc["function"]["arguments"] = ""
                                            tool_calls.append(new_tc)
                            finish_reason = chunk["choices"][0].get("finish_reason")
                            if delta.get("reasoning_content", None):
                                reasoning_content += delta["reasoning_content"]
                            if finish_reason == "stop" or finish_reason == "tool_calls":
                                chunk_usage = chunk.get("usage", None)
                                if chunk_usage:
                                    usage = chunk_usage
                                    break
                                # No usage yet: keep reading for the usage chunk.
                                flag_delta_content_finished = True
                    except json.JSONDecodeError as e:
                        logger.error(f"模型 {self.model_name} 解析流式 JSON 错误: {e} - data: '{data_str}'")
                    except Exception as e:
                        logger.exception(f"模型 {self.model_name} 解析流式输出块错误: {str(e)}")
            except UnicodeDecodeError as e:
                logger.warning(f"模型 {self.model_name} 流式输出解码错误: {e} - bytes: {line_bytes[:50]}...")
            except Exception as e:
                # Covers GeneratorExit (consumer cancelled) and transport errors:
                # salvage what was accumulated and release the connection.
                if isinstance(e, GeneratorExit):
                    log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
                else:
                    log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
                logger.warning(log_content)
                try:
                    await response.release()
                except Exception as cleanup_error:
                    logger.error(f"清理资源时发生错误: {cleanup_error}")
                content = accumulated_content
                break
        if not content and accumulated_content:
            content = accumulated_content
        # Split out <think>...</think> blocks into reasoning_content.
        think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        if think_match:
            reasoning_content = think_match.group(1).strip()
            content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
        message = {
            "content": content,
            "reasoning_content": reasoning_content,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        result = {
            "choices": [{"message": message}],
            "usage": usage,
        }
        return result
async def _handle_error_response(
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any], current_key: str = None
) -> None:
"""处理 HTTP 错误响应 (区分 403 和其他错误)"""
# (代码不变)
status = response.status
try:
error_text = await response.text()
except Exception as e:
error_text = f"(无法读取响应体: {e})"
if status == 403:
logger.error(
f"模型 {self.model_name}: 遇到 403 (权限拒绝) 错误。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
f"响应: {error_text[:200]}"
)
raise PermissionDeniedException(f"模型禁止访问 ({status})", key_identifier=current_key)
elif status in policy["retry_codes"] and status != 429:
if status == 413:
logger.warning(
f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩..."
)
raise PayLoadTooLargeError("请求体过大")
elif status in [500, 503]:
logger.error(
f"模型 {self.model_name}: 服务器内部错误或过载 ({status})。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
f"响应: {error_text[:200]}"
)
return
else:
logger.warning(
f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}"
)
return
elif status in policy["abort_codes"]:
logger.error(
f"模型 {self.model_name}: 遇到需要中止的错误码: {status} - {error_code_mapping.get(status, '未知错误')}. "
f"Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
)
raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
else:
logger.error(
f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
)
try:
response.raise_for_status()
raise RequestAbortException(f"未处理的错误状态码 {status}", response)
except aiohttp.ClientResponseError as e:
raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e
    async def _handle_exception(
        self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None
    ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
        """Handle non-HTTP errors; may rebuild the payload for a retry.

        Returns:
            (new_payload, 0) when the payload was rebuilt (image compressed
            after a 413), or (None, 0) when the caller should retry as-is.

        Raises:
            RuntimeError: when retries are exhausted for the given exception.
        """
        policy = request_content["policy"]
        payload = request_content["payload"]
        _wait_time = policy["base_wait"] * (2**retry_count)
        # keep_request: at least one retry attempt remains.
        keep_request = False
        if retry_count < policy["max_retries"] - 1:
            keep_request = True
        # Prefer the merged call params for payload rebuilding; fall back to
        # the raw payload when no params were supplied.
        params_for_rebuild = merged_params if merged_params is not None else payload
        if isinstance(exception, PayLoadTooLargeError):
            if keep_request:
                logger.warning("请求体过大 (PayLoadTooLargeError),尝试压缩图片...")
                image_base64 = request_content.get("image_base64")
                if image_base64:
                    compressed_image_base64 = compress_base64_image_by_scale(image_base64)
                    if compressed_image_base64 != image_base64:
                        new_payload = await self._build_payload(
                            request_content["prompt"],
                            compressed_image_base64,
                            request_content["image_format"],
                            params_for_rebuild,
                        )
                        logger.info("图片压缩成功,将使用压缩后的图片重试。")
                        return new_payload, 0
                    else:
                        logger.warning("图片压缩未改变大小或失败。")
                else:
                    logger.warning("请求体过大但请求中不包含图片,无法压缩。")
                return None, 0
            else:
                logger.error("达到最大重试次数,请求体仍然过大。")
                raise RuntimeError("请求体过大,压缩或重试后仍然失败。") from exception
        elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)):
            if keep_request:
                logger.error(f"模型 {self.model_name} 网络错误: {str(exception)}")
                return None, 0
            else:
                logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
                raise RuntimeError(f"网络请求失败: {str(exception)}") from exception
        elif isinstance(exception, aiohttp.ClientResponseError):
            # NOTE(review): aiohttp.ClientResponseError subclasses ClientError,
            # so this branch appears unreachable — the branch above matches
            # first. Worth confirming and reordering if intended.
            if keep_request:
                logger.error(
                    f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}"
                )
                try:
                    error_text = await exception.response.text() if hasattr(exception, "response") else str(exception)
                    logger.error(f"服务器错误响应详情: {error_text[:500]}")
                except Exception as parse_err:
                    logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}")
                return None, 0
            else:
                logger.critical(
                    f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
                )
                current_key_placeholder = request_content.get("current_key", "******")
                handled_payload = await _safely_record(request_content, payload)
                logger.critical(
                    f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
                )
                raise RuntimeError(
                    f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
                ) from exception
        else:
            if keep_request:
                logger.error(
                    f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}"
                )
                return None, 0
            else:
                logger.critical(
                    f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}"
                )
                current_key_placeholder = request_content.get("current_key", "******")
                handled_payload = await _safely_record(request_content, payload)
                logger.critical(
                    f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
                )
                raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception
async def _transform_parameters(self, merged_params: dict) -> dict:
"""根据模型名称转换合并后的参数,并移除内部参数"""
# (代码不变)
new_params = dict(merged_params)
new_params.pop("request_type", None)
if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
new_params.pop("temperature", None)
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
elif self.is_gemini:
gen_config = new_params.get("generationConfig", {})
if "temperature" in new_params:
gen_config["temperature"] = new_params.pop("temperature")
if "max_tokens" in new_params:
gen_config["maxOutputTokens"] = new_params.pop("max_tokens")
if "top_p" in new_params:
gen_config["topP"] = new_params.pop("top_p")
if "top_k" in new_params:
gen_config["topK"] = new_params.pop("top_k")
if gen_config:
new_params["generationConfig"] = gen_config
new_params.pop("frequency_penalty", None)
new_params.pop("presence_penalty", None)
new_params.pop("max_completion_tokens", None)
return new_params
async def _build_payload(
    self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None
) -> dict:
    """Build the HTTP request body (Gemini vs. OpenAI-compatible) from merged, transformed parameters.

    Args:
        prompt: User prompt text (may be empty when only an image is sent).
        image_base64: Optional base64-encoded image payload.
        image_format: Image format hint (e.g. "png"); defaults to "jpeg" when absent.
        merged_params: Pre-merged request parameters; falls back to self.params.

    Returns:
        dict: Provider-specific JSON payload ready to be POSTed.
    """
    if merged_params is None:
        merged_params = self.params
    # Normalize parameter names for the target provider and drop internal keys.
    params_copy = await self._transform_parameters(merged_params)
    if self.is_gemini:
        # Gemini expects a "contents" list of parts (text and/or inline image data).
        parts = []
        if prompt:
            parts.append({"text": prompt})
        if image_base64:
            mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}"
            parts.append({"inlineData": {"mimeType": mime_type, "data": image_base64}})
        payload = {"contents": [{"parts": parts}], **params_copy}
        payload.pop("model", None)  # Gemini takes the model from the URL path, not the body
        # --- Gemini safety settings: disable filtering on every category ---
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
        ]
        payload["safetySettings"] = safety_settings
        logger.debug(f"模型 {self.model_name}: 已为 Gemini 函数调用请求添加 safetySettings (BLOCK_NONE)。")
        # --- end safety settings ---
    else:
        # OpenAI-compatible chat format; vision requests use a multi-part content list.
        if image_base64:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]
        else:
            messages = [{"role": "user", "content": prompt}]
        payload = {
            "model": self.model_name,
            "messages": messages,
            **params_copy,
        }
        # Apply the global default output cap only when the caller set no limit at all.
        if "max_tokens" not in payload and "max_completion_tokens" not in payload:
            if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
                payload["max_tokens"] = global_config.model_max_output_length
        # NOTE(review): this renames max_completion_tokens back to max_tokens, which
        # appears to undo the rename done in _transform_parameters for models in
        # MODELS_NEEDING_TRANSFORMATION — confirm this is intentional.
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload.pop("max_completion_tokens")
    return payload
def _default_response_handler(
    self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
) -> Tuple:
    """Parse a raw API response (Gemini or OpenAI-compatible) into a uniform tuple.

    Extracts text content, chain-of-thought reasoning and any tool/function
    calls, records token usage, and returns either
    ``(content, reasoning_content)`` or
    ``(content, reasoning_content, tool_calls)`` when the model requested a
    tool call. A Gemini ``functionCall`` is converted to the OpenAI
    ``tool_calls`` shape so callers see a single format.

    Args:
        result: Decoded JSON response body.
        user_id: Attributed user for usage accounting.
        request_type: Logical request category for usage accounting.
        endpoint: API endpoint used, recorded alongside usage.
    """
    content = "没有返回结果"
    reasoning_content = ""
    tool_calls = None  # OpenAI format
    function_call = None  # Gemini format
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0
    if self.is_gemini:
        # --- Parse Gemini response ---
        try:
            if "candidates" in result and result["candidates"]:
                candidate = result["candidates"][0]
                # Candidate must carry content.parts to hold text / functionCall.
                if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]:
                    # Look for a functionCall part first, otherwise collect text parts.
                    final_text_parts = []
                    for part in candidate["content"]["parts"]:
                        if "functionCall" in part:
                            function_call = part["functionCall"]  # capture Gemini's functionCall
                            # Gemini normally does not mix functionCall with text; stop scanning.
                            break
                        elif "text" in part:
                            final_text_parts.append(part.get("text", ""))
                    if not function_call:  # plain text answer: split out <think> reasoning
                        raw_content = "".join(final_text_parts).strip()
                        content, reasoning = self._extract_reasoning(raw_content)
                        reasoning_content = reasoning
                    # else: function_call captured; content keeps its default value
                else:
                    content = "Gemini响应中缺少 content 或 parts"
                    logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}")
                # Surface abnormal finish reasons; TOOL_CODE / FUNCTION_CALL are normal.
                finish_reason = candidate.get("finishReason")
                if finish_reason == "SAFETY":
                    logger.warning(f"模型 {self.model_name}: Gemini 响应因安全设置被阻止。")
                    content = "响应内容因安全原因被过滤。"
                elif finish_reason == "RECITATION":
                    logger.warning(f"模型 {self.model_name}: Gemini 响应因引用限制被阻止。")
                    content = "响应内容因引用限制被过滤。"
                elif finish_reason == "OTHER":
                    logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。")
            usage = result.get("usageMetadata", {})
            if usage:
                prompt_tokens = usage.get("promptTokenCount", 0)
                completion_tokens = usage.get("candidatesTokenCount", 0)
                total_tokens = usage.get("totalTokenCount", 0)
                # Some responses omit candidatesTokenCount; derive it from the total.
                if completion_tokens == 0 and total_tokens > 0:
                    completion_tokens = total_tokens - prompt_tokens
            else:
                logger.warning(f"模型 {self.model_name} (Gemini) 的响应中缺少 'usageMetadata' 信息。")
        except Exception as e:
            logger.error(f"解析 Gemini 响应出错: {e} - 响应: {result}")
            content = "解析 Gemini 响应时出错"
    else:
        # --- Parse OpenAI-compatible response ---
        if "choices" in result and result["choices"]:
            message = result["choices"][0].get("message", {})
            raw_content = message.get("content", "")
            content, reasoning = self._extract_reasoning(raw_content if raw_content else "")
            # Prefer an explicit reasoning_content field over the extracted <think> text.
            explicit_reasoning = message.get("model_extra", {}).get("reasoning_content", "")
            if not explicit_reasoning:
                explicit_reasoning = message.get("reasoning_content", "")
            reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
            tool_calls = message.get("tool_calls", None)  # OpenAI tool calls, if any
            usage = result.get("usage", {})
            if usage:
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)
                total_tokens = usage.get("total_tokens", 0)
            else:
                logger.warning(f"模型 {self.model_name} (OpenAI) 的响应中缺少 'usage' 信息。")
        else:
            logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}")
    # --- Record token usage ---
    if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0:
        self._record_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            user_id=user_id,
            request_type=request_type,
            endpoint=endpoint,
        )
    else:
        logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
    # --- Assemble the unified return value ---
    final_tool_calls = None
    if tool_calls:  # from OpenAI
        final_tool_calls = tool_calls
        logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}")
    elif function_call:  # from Gemini
        logger.debug(f"检测到 Gemini 函数调用: {function_call}")
        # Convert the Gemini functionCall into the OpenAI tool_calls shape.
        # Gemini provides no explicit id/type, so a synthetic id is generated.
        final_tool_calls = [
            {
                "id": f"call_{random.randint(1000, 9999)}",  # synthetic random id
                "type": "function",
                "function": {
                    "name": function_call.get("name"),
                    # Gemini passes args as a dict in 'args'; OpenAI expects a
                    # JSON string in 'arguments', so serialize here.
                    "arguments": json.dumps(function_call.get("args", {})),
                },
            }
        ]
        logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}")
    if final_tool_calls:
        # Tool/function call present: content is usually empty or holds reasoning.
        return content, reasoning_content, final_tool_calls
    else:
        # Plain text response, no tool/function call.
        return content, reasoning_content
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
# (代码不变)
if not content:
return "", ""
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
cleaned_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
if match:
reasoning = match.group(1).strip()
else:
reasoning = ""
return cleaned_content, reasoning
async def _build_headers(self, api_key: str, no_key: bool = False) -> dict:
"""构建请求头 (区分 Gemini 和 OpenAI)"""
# (代码不变)
if no_key:
if self.is_gemini:
return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
else:
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
else:
if not api_key:
logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
raise ValueError("无效的 API key 提供给 _build_headers。")
if self.is_gemini:
return {"x-goog-api-key": api_key, "Content-Type": "application/json"}
else:
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
    """Generate an async chat response for *prompt*; overrides passed via **kwargs.

    Returns ``(content, reasoning, model_name)`` or, when the model issued
    tool calls, ``(content, reasoning, model_name, tool_calls)``.
    """
    endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
    result = await self._execute_request(
        endpoint=endpoint, prompt=prompt, user_id=user_id, request_type="chat", **kwargs
    )
    if len(result) == 3:
        content, reasoning, tool_calls = result
        return content, reasoning, self.model_name, tool_calls
    content, reasoning = result
    return content, reasoning, self.model_name
async def generate_response_for_image(
    self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs
) -> Tuple:
    """Generate an async response for *prompt* plus an inline image.

    Returns ``(content, reasoning, tool_calls)`` when the handler yields
    three values (tool_calls may be None), otherwise ``(content, reasoning)``.
    """
    endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
    result = await self._execute_request(
        endpoint=endpoint,
        prompt=prompt,
        image_base64=image_base64,
        image_format=image_format,
        user_id=user_id,
        request_type="vision",
        **kwargs,
    )
    size = len(result)
    if size == 3:
        # content, reasoning, tool_calls (tool_calls may be None)
        return result
    if size == 2:
        content, reasoning = result
        return content, reasoning  # vision requests normally carry no tool calls
    logger.error(f"来自 _default_response_handler 的意外响应格式: {result}")
    return "处理响应出错", ""
async def generate_response_async(
    self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs
) -> Union[str, Tuple]:
    """Generic async generation entry point; forwards overrides unchanged.

    Returns whatever the response handler produced (text tuple, or a
    three-element tuple including tool calls).
    """
    endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
    return await self._execute_request(
        endpoint=endpoint,
        prompt=prompt,
        payload=None,
        retry_policy=None,
        response_handler=None,
        user_id=user_id,
        request_type=request_type,
        **kwargs,
    )
# Implements payload construction for Gemini function calling as well as OpenAI tools.
async def generate_response_tool_async(
    self, prompt: str, tools: list, user_id: str = "system", **kwargs
) -> tuple[str, str, list | None]:
    """Generate a response that may invoke tools/functions.

    Builds a provider-specific payload: OpenAI-style ``tools`` definitions
    are converted to Gemini ``functionDeclarations`` when targeting Gemini.

    Args:
        prompt: User prompt text.
        tools: Tool definitions in OpenAI format.
        user_id: Attributed user for usage accounting.
        **kwargs: Per-call parameter overrides, merged over self.params.

    Returns:
        (content, reasoning_content, tool_calls) — tool_calls is in the
        unified OpenAI format, or None when the model answered with text.
    """
    endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
    merged_params = {**self.params, **kwargs}
    transformed_params = await self._transform_parameters(merged_params)  # strips request_type etc.
    payload = None
    if self.is_gemini:
        # --- Build Gemini function-calling payload ---
        logger.debug(f"为 Gemini ({self.model_name}) 构建函数调用请求。")
        # 1. Convert tool definitions (OpenAI -> Gemini):
        # OpenAI tool format: [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
        # Gemini tool format: [{"functionDeclarations": [{"name": ..., "description": ..., "parameters": ...}]}]
        function_declarations = []
        if tools:
            for tool in tools:
                if tool.get("type") == "function" and "function" in tool:
                    func_def = tool["function"]
                    # Gemini parameters use an OpenAPI schema, largely OpenAI-compatible.
                    function_declarations.append(
                        {
                            "name": func_def.get("name"),
                            "description": func_def.get("description", ""),  # description is required for Gemini
                            "parameters": func_def.get(
                                "parameters", {"type": "object", "properties": {}}
                            ),  # ensure parameters exist
                        }
                    )
                else:
                    logger.warning(f"跳过不支持的工具类型或格式: {tool}")
        if not function_declarations:
            logger.error("没有有效的函数声明可用于 Gemini 请求。")
            return "没有提供有效的函数定义", "", None
        gemini_tools = [{"functionDeclarations": function_declarations}]
        # 2. Build the Gemini payload.
        payload = {
            "contents": [{"parts": [{"text": prompt}]}],  # the user prompt
            "tools": gemini_tools,
            # toolConfig defaults to AUTO; could be taken from kwargs or hard-coded,
            # e.g. "toolConfig": {"functionCallingConfig": {"mode": "ANY"}} to force a call.
            **transformed_params,  # merged, transformed extras (e.g. generationConfig)
        }
        payload.pop("model", None)  # Gemini does not take model in the body
        payload.pop("messages", None)  # OpenAI-specific field
        payload.pop("tool_choice", None)  # OpenAI-specific field
        logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}")
    else:
        # --- Build OpenAI tool-calling payload ---
        logger.debug(f"为 OpenAI 兼容模型 ({self.model_name}) 构建工具调用请求。")
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            **transformed_params,
            "tools": tools,
            "tool_choice": transformed_params.get("tool_choice", "auto"),
        }
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload.pop("max_completion_tokens")
        if "max_tokens" not in payload:
            payload["max_tokens"] = global_config.model_max_output_length
    # --- Execute the request ---
    if payload is None:
        logger.error("未能构建有效的 API 请求 payload。")
        return "内部错误:无法构建请求", "", None
    response = await self._execute_request(
        endpoint=endpoint,
        payload=payload,
        prompt=prompt,  # still needed for potential retries
        user_id=user_id,
        request_type="tool_call",
        **kwargs,  # original kwargs so a retry can re-merge parameters
    )
    # _default_response_handler normalizes a Gemini functionCall into OpenAI format.
    logger.debug(f"模型 {self.model_name} 工具/函数调用返回结果: {response}")
    if isinstance(response, tuple) and len(response) == 3:
        content, reasoning_content, final_tool_calls = response
        # final_tool_calls is already in the unified OpenAI format.
        return content, reasoning_content, final_tool_calls
    elif isinstance(response, tuple) and len(response) == 2:
        content, reasoning_content = response
        logger.debug("收到普通响应,无工具/函数调用")
        return content, reasoning_content, None
    else:
        logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}")
        return "处理响应时出错", "", None
async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]:
    """Fetch an embedding vector for *text* (Gemini or OpenAI-compatible endpoint).

    Args:
        text: Input text; empty input short-circuits to None without a request.
        user_id: Attributed user for usage accounting.
        **kwargs: Extra API parameters (the internal "request_type" key is stripped).

    Returns:
        The embedding vector as returned by the provider, or None when the
        input is empty or the response carried no embedding data.
    """
    if len(text) < 1:
        logger.debug("该消息没有长度不再发送获取embedding向量的请求")
        return None
    # Strip the internal routing hint before forwarding kwargs to the API.
    api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
    if self.is_gemini:
        endpoint = ":embedContent"
        payload = {"model": f"models/{self.model_name}", "content": {"parts": [{"text": text}]}, **api_kwargs}
        # Remove OpenAI-specific fields that may have leaked in via kwargs.
        payload.pop("encoding_format", None)
        payload.pop("input", None)
    else:
        endpoint = "/embeddings"
        payload = {"model": self.model_name, "input": text, "encoding_format": "float", **api_kwargs}
        # Remove Gemini-specific fields that may have leaked in via kwargs.
        payload.pop("content", None)
        payload.pop("taskType", None)

    def embedding_handler(result):
        # Parse the provider response, record token usage, and return the raw vector.
        embedding_value = None
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        if self.is_gemini:
            if "embedding" in result and "value" in result["embedding"]:
                embedding_value = result["embedding"]["value"]
            # Gemini embedding responses expose no token counts; usage is recorded as zero.
            logger.warning(f"模型 {self.model_name} (Gemini Embedding): 响应中未找到明确的 token 使用信息。")
        else:
            if "data" in result and len(result["data"]) > 0:
                embedding_value = result["data"][0].get("embedding", None)
            usage = result.get("usage", {})
            if usage:
                prompt_tokens = usage.get("prompt_tokens", 0)
                total_tokens = usage.get("total_tokens", 0)
            else:
                logger.warning(f"模型 {self.model_name} (OpenAI Embedding) 的响应中缺少 'usage' 信息。")
        self._record_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            user_id=user_id,
            request_type="embedding",
            endpoint=endpoint,
        )
        return embedding_value

    embedding = await self._execute_request(
        endpoint=endpoint,
        payload=payload,
        prompt=text,
        retry_policy={"max_retries": 2, "base_wait": 6},
        response_handler=embedding_handler,
        user_id=user_id,
        request_type="embedding",
        **api_kwargs,
    )
    return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: float = 0.8 * 1024 * 1024) -> str:
    """Compress a base64-encoded image down toward *target_size* bytes by rescaling.

    Animated GIFs are resized frame-by-frame and re-saved as GIF; still images
    are resized and saved as PNG (when transparency is present or the source
    was not JPEG) or JPEG at quality 85. On any failure, or when compression
    would not help, the original base64 string is returned unchanged.

    Args:
        base64_data: Base64-encoded source image.
        target_size: Desired output size in bytes (default ~0.8 MiB).

    Returns:
        str: Base64-encoded compressed image, or the original input.
    """
    try:
        image_data = base64.b64decode(base64_data)
        # Small enough already (5% tolerance) — skip the work entirely.
        if len(image_data) <= target_size * 1.05:
            logger.info(f"图片大小 {len(image_data) / 1024:.1f}KB 已足够小,无需压缩。")
            return base64_data
        img = Image.open(io.BytesIO(image_data))
        img_format = img.format
        original_width, original_height = img.size
        # Linear scale factor from the byte-size ratio (sqrt because area ~ bytes),
        # clamped to [0.2, 1.0] so we never upscale or shrink below 20%.
        scale = max(0.2, min(1.0, (target_size / len(image_data)) ** 0.5))
        new_width = max(1, int(original_width * scale))
        new_height = max(1, int(original_height * scale))
        output_buffer = io.BytesIO()
        save_format = img_format  # Default to original format
        if getattr(img, "is_animated", False) and img.n_frames > 1:
            # Animated image: resize every frame, preserving timing/loop metadata.
            frames = []
            durations = []
            loop = img.info.get("loop", 0)
            disposal = img.info.get("disposal", 2)
            logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...")
            for frame_idx in range(img.n_frames):
                img.seek(frame_idx)
                current_duration = img.info.get("duration", 100)  # per-frame delay, ms
                durations.append(current_duration)
                new_frame = img.convert("RGBA").copy()
                resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
                frames.append(resized_frame)
            if frames:
                # NOTE(review): transparency/background may be None here; some Pillow
                # versions reject explicit None for these GIF options — confirm.
                frames[0].save(
                    output_buffer,
                    format="GIF",
                    save_all=True,
                    append_images=frames[1:],
                    optimize=False,
                    duration=durations,
                    loop=loop,
                    disposal=disposal,
                    transparency=img.info.get("transparency", None),
                    background=img.info.get("background", None),
                )
                save_format = "GIF"
            else:
                logger.warning("未能处理 GIF 帧。")
                return base64_data
        else:
            # Still image: keep alpha as PNG, otherwise JPEG stays JPEG, rest becomes PNG.
            if img.mode in ("RGBA", "LA") or "transparency" in img.info:
                resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS)
                save_format = "PNG"
                save_params = {"optimize": True}
            else:
                resized_img = img.convert("RGB").resize((new_width, new_height), Image.Resampling.LANCZOS)
                if img_format and img_format.upper() == "JPEG":
                    save_format = "JPEG"
                    save_params = {"quality": 85, "optimize": True}
                else:
                    save_format = "PNG"
                    save_params = {"optimize": True}
            resized_img.save(output_buffer, format=save_format, **save_params)
        compressed_data = output_buffer.getvalue()
        logger.success(
            f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})"
        )
        logger.info(
            f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)"
        )
        # Only keep the result if it actually saved at least 5%.
        if len(compressed_data) < len(image_data) * 0.95:
            return base64.b64encode(compressed_data).decode("utf-8")
        else:
            logger.info("压缩效果不明显或反而增大,返回原始图片。")
            return base64_data
    except Exception as e:
        # Best-effort utility: any failure falls back to the original image.
        logger.error(f"压缩图片失败: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        return base64_data