mirror of https://github.com/Mai-with-u/MaiBot.git
1517 lines
73 KiB
Python
1517 lines
73 KiB
Python
import asyncio
|
||
import json
|
||
import random # 添加 random 模块导入
|
||
import re
|
||
from datetime import datetime
|
||
from typing import Tuple, Union, Dict, Any, Set # 引入 Set
|
||
|
||
import aiohttp
|
||
from aiohttp.client import ClientResponse
|
||
|
||
# 相对路径导入,根据你的项目结构调整
|
||
# 例如,如果 utils_model.py 在 src/utils/ 下,而 logger 在 src/common/ 下
|
||
# from ..common.logger import get_module_logger
|
||
# from ..common.database import db
|
||
# from ..config.config import global_config
|
||
# 假设它们在期望的路径
|
||
from src.common.logger import get_module_logger
|
||
from ...common.database import db
|
||
from ...config.config import global_config
|
||
|
||
|
||
import base64
|
||
from PIL import Image
|
||
import io
|
||
import os
|
||
|
||
from rich.traceback import install
|
||
|
||
install(extra_lines=3)
|
||
|
||
# 尝试加载 .env 文件中的环境变量 (如果项目结构需要)
|
||
# load_dotenv() # 如果你的 .env 文件不在标准位置,可能需要指定路径 load_dotenv(dotenv_path='path/to/.env')
|
||
|
||
logger = get_module_logger("model_utils")
|
||
|
||
|
||
class PayLoadTooLargeError(Exception):
|
||
"""自定义异常类,用于处理请求体过大错误"""
|
||
|
||
# (代码不变)
|
||
def __init__(self, message: str):
|
||
super().__init__(message)
|
||
self.message = message
|
||
|
||
def __str__(self):
|
||
return "请求体过大,请尝试压缩图片或减少输入内容。"
|
||
|
||
|
||
class RequestAbortException(Exception):
|
||
"""自定义异常类,用于处理请求中断异常"""
|
||
|
||
# (代码不变)
|
||
def __init__(self, message: str, response: ClientResponse):
|
||
super().__init__(message)
|
||
self.message = message
|
||
self.response = response
|
||
|
||
def __str__(self):
|
||
return self.message
|
||
|
||
|
||
class PermissionDeniedException(Exception):
|
||
"""自定义异常类,用于处理访问拒绝的异常"""
|
||
|
||
# (代码不变)
|
||
def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符
|
||
super().__init__(message)
|
||
self.message = message
|
||
self.key_identifier = key_identifier # 存储导致 403 的 key
|
||
|
||
def __str__(self):
|
||
return self.message
|
||
|
||
|
||
# 新增:用于内部标记需要切换 Key 的异常
|
||
class _SwitchKeyException(Exception):
|
||
"""内部异常,用于标记需要切换Key并且跳过标准等待时间."""
|
||
|
||
# (代码不变)
|
||
pass
|
||
|
||
|
||
# 常见Error Code Mapping
|
||
error_code_mapping = {
|
||
# (代码不变)
|
||
400: "参数不正确",
|
||
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效
|
||
402: "账号余额不足",
|
||
403: "需要实名,或余额不足,或Key无权限", # 扩展 403 的含义
|
||
404: "Not Found",
|
||
429: "请求过于频繁,请稍后再试",
|
||
500: "服务器内部故障",
|
||
503: "服务器负载过高",
|
||
}
|
||
|
||
|
||
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||
"""安全地记录请求内容,隐藏敏感信息"""
|
||
# (代码不变)
|
||
image_base64: str = request_content.get("image_base64")
|
||
image_format: str = request_content.get("image_format")
|
||
is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
|
||
safe_payload = json.loads(json.dumps(payload)) if payload else {}
|
||
|
||
if image_base64 and safe_payload and isinstance(safe_payload, dict):
|
||
if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
|
||
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
|
||
content = safe_payload["messages"][0]["content"]
|
||
if (
|
||
isinstance(content, list)
|
||
and len(content) > 1
|
||
and isinstance(content[1], dict)
|
||
and "image_url" in content[1]
|
||
):
|
||
safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||
)
|
||
elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0:
|
||
if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]:
|
||
parts = safe_payload["contents"][0]["parts"]
|
||
for i, part in enumerate(parts):
|
||
if isinstance(part, dict) and "inlineData" in part:
|
||
safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = (
|
||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||
)
|
||
break
|
||
|
||
return safe_payload
|
||
|
||
|
||
class LLMRequest:
|
||
# (代码不变)
|
||
MODELS_NEEDING_TRANSFORMATION = [
|
||
"o1",
|
||
"o1-2024-12-17",
|
||
"o1-mini",
|
||
"o1-mini-2024-09-12",
|
||
"o1-preview",
|
||
"o1-preview-2024-09-12",
|
||
"o1-pro",
|
||
"o1-pro-2025-03-19",
|
||
"o3",
|
||
"o3-2025-04-16",
|
||
"o3-mini",
|
||
"o3-mini-2025-01-31o4-mini",
|
||
"o4-mini-2025-04-16",
|
||
]
|
||
_abandoned_keys_runtime: Set[str] = set()
|
||
|
||
def __init__(self, model: dict, **kwargs):
|
||
"""初始化 LLMRequest 实例"""
|
||
# (代码不变)
|
||
self.model_key_name = model["key"]
|
||
self.model_name: str = model["name"]
|
||
self.params = kwargs
|
||
self.stream = model.get("stream", False)
|
||
self.pri_in = model.get("pri_in", 0)
|
||
self.pri_out = model.get("pri_out", 0)
|
||
self.request_type = model.get("request_type", "default")
|
||
|
||
try:
|
||
raw_api_key_config = os.environ[self.model_key_name]
|
||
self.base_url = os.environ[model["base_url"]]
|
||
self.is_gemini = "googleapis.com" in self.base_url.lower()
|
||
if self.is_gemini:
|
||
logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})")
|
||
if self.stream:
|
||
logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。")
|
||
self.stream = False
|
||
|
||
# 解析和过滤 API Keys (代码不变)
|
||
parsed_keys = []
|
||
is_list_config = False
|
||
try:
|
||
loaded_keys = json.loads(raw_api_key_config)
|
||
if isinstance(loaded_keys, list):
|
||
parsed_keys = [str(key) for key in loaded_keys if key]
|
||
is_list_config = True
|
||
elif isinstance(loaded_keys, str) and loaded_keys:
|
||
parsed_keys = [loaded_keys]
|
||
else:
|
||
raise ValueError(f"Parsed API key for {self.model_key_name} is not a valid list or string.")
|
||
except (json.JSONDecodeError, TypeError) as e:
|
||
if isinstance(raw_api_key_config, list):
|
||
parsed_keys = [str(key) for key in raw_api_key_config if key]
|
||
is_list_config = True
|
||
elif isinstance(raw_api_key_config, str) and raw_api_key_config:
|
||
parsed_keys = [raw_api_key_config]
|
||
else:
|
||
raise ValueError(
|
||
f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}"
|
||
) from e
|
||
|
||
if not parsed_keys:
|
||
raise ValueError(f"No valid API keys found for {self.model_key_name}.")
|
||
|
||
abandoned_key_name = f"abandon_{self.model_key_name}"
|
||
abandoned_keys_set = set()
|
||
raw_abandoned_keys = os.environ.get(abandoned_key_name)
|
||
|
||
if raw_abandoned_keys:
|
||
try:
|
||
loaded_abandoned = json.loads(raw_abandoned_keys)
|
||
if isinstance(loaded_abandoned, list):
|
||
abandoned_keys_set.update(str(key) for key in loaded_abandoned if key)
|
||
elif isinstance(loaded_abandoned, str) and loaded_abandoned:
|
||
abandoned_keys_set.add(loaded_abandoned)
|
||
logger.info(
|
||
f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。"
|
||
)
|
||
except (json.JSONDecodeError, TypeError):
|
||
if isinstance(raw_abandoned_keys, list):
|
||
abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key)
|
||
logger.info(
|
||
f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。"
|
||
)
|
||
elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys:
|
||
abandoned_keys_set.add(raw_abandoned_keys)
|
||
logger.info(
|
||
f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。"
|
||
)
|
||
else:
|
||
logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}")
|
||
|
||
all_abandoned_keys = abandoned_keys_set.union(LLMRequest._abandoned_keys_runtime)
|
||
active_keys = [key for key in parsed_keys if key not in all_abandoned_keys]
|
||
|
||
if not active_keys:
|
||
logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。")
|
||
raise ValueError(
|
||
f"No active API keys available for {self.model_key_name} after filtering abandoned keys."
|
||
)
|
||
|
||
if is_list_config and len(active_keys) > 1:
|
||
self._api_key_config = active_keys
|
||
logger.info(
|
||
f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
|
||
)
|
||
elif active_keys:
|
||
self._api_key_config = active_keys[0]
|
||
logger.info(
|
||
f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
|
||
)
|
||
else:
|
||
raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
|
||
|
||
# 加载代理配置 (代码不变)
|
||
self.proxy_url = None
|
||
self.proxy_models_set = set()
|
||
proxy_host = os.environ.get("PROXY_HOST")
|
||
proxy_port = os.environ.get("PROXY_PORT")
|
||
proxy_models_str = os.environ.get("PROXY_MODELS", "")
|
||
|
||
if proxy_host and proxy_port:
|
||
try:
|
||
int(proxy_port)
|
||
self.proxy_url = f"http://{proxy_host}:{proxy_port}"
|
||
logger.debug(f"代理已配置: {self.proxy_url}")
|
||
|
||
if proxy_models_str:
|
||
try:
|
||
cleaned_str = proxy_models_str.strip("'\"")
|
||
self.proxy_models_set = {
|
||
model_name.strip() for model_name in cleaned_str.split(",") if model_name.strip()
|
||
}
|
||
logger.debug(f"以下模型将使用代理: {self.proxy_models_set}")
|
||
except Exception as e:
|
||
logger.error(
|
||
f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。"
|
||
)
|
||
self.proxy_models_set = set()
|
||
except ValueError:
|
||
logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。")
|
||
self.proxy_url = None
|
||
self.proxy_models_set = set()
|
||
except Exception as e:
|
||
logger.error(f"加载代理配置时发生错误: {e}")
|
||
self.proxy_url = None
|
||
self.proxy_models_set = set()
|
||
else:
|
||
logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。")
|
||
|
||
except KeyError as e:
|
||
# (代码不变)
|
||
missing_key = str(e).strip("'")
|
||
if missing_key == self.model_key_name:
|
||
logger.error(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'")
|
||
raise ValueError(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") from e
|
||
elif missing_key == model["base_url"]:
|
||
logger.error(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'")
|
||
raise ValueError(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") from e
|
||
else:
|
||
logger.error(f"配置错误:找不到环境变量 - {str(e)}")
|
||
raise ValueError(f"配置错误:找不到环境变量 - {str(e)}") from e
|
||
except AttributeError as e:
|
||
# (代码不变)
|
||
logger.error(f"原始 model dict 信息:{model}")
|
||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||
except ValueError as e:
|
||
# (代码不变)
|
||
logger.error(f"API Key 或配置初始化错误 for {self.model_key_name}: {str(e)}")
|
||
raise e
|
||
|
||
self._init_database()
|
||
|
||
@staticmethod
|
||
def _init_database():
|
||
"""初始化数据库集合"""
|
||
# (代码不变)
|
||
try:
|
||
db.llm_usage.create_index([("timestamp", 1)])
|
||
db.llm_usage.create_index([("model_name", 1)])
|
||
db.llm_usage.create_index([("user_id", 1)])
|
||
db.llm_usage.create_index([("request_type", 1)])
|
||
except Exception as e:
|
||
logger.error(f"创建数据库索引失败: {str(e)}")
|
||
|
||
def _record_usage(
|
||
self,
|
||
prompt_tokens: int,
|
||
completion_tokens: int,
|
||
total_tokens: int,
|
||
user_id: str = "system",
|
||
request_type: str = None,
|
||
endpoint: str = "/chat/completions",
|
||
):
|
||
"""记录模型使用情况到数据库
|
||
Args:
|
||
prompt_tokens: 输入token数
|
||
completion_tokens: 输出token数
|
||
total_tokens: 总token数
|
||
user_id: 用户ID,默认为system
|
||
request_type: 请求类型
|
||
endpoint: API端点
|
||
"""
|
||
# 如果 request_type 为 None,则使用实例变量中的值
|
||
if request_type is None:
|
||
request_type = self.request_type
|
||
|
||
actual_endpoint = endpoint
|
||
if self.is_gemini:
|
||
if endpoint == "/embeddings":
|
||
actual_endpoint = ":embedContent"
|
||
else:
|
||
actual_endpoint = ":generateContent"
|
||
|
||
try:
|
||
usage_data = {
|
||
"model_name": self.model_name,
|
||
"user_id": user_id,
|
||
"request_type": request_type or self.request_type,
|
||
"endpoint": actual_endpoint,
|
||
"prompt_tokens": prompt_tokens,
|
||
"completion_tokens": completion_tokens,
|
||
"total_tokens": total_tokens,
|
||
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
|
||
"status": "success",
|
||
"timestamp": datetime.now(),
|
||
}
|
||
db.llm_usage.insert_one(usage_data)
|
||
logger.trace(
|
||
f"Token使用情况 - 模型: {self.model_name}, "
|
||
f"用户: {user_id}, 类型: {request_type or self.request_type}, "
|
||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||
f"总计: {total_tokens}"
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||
|
||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||
"""计算API调用成本"""
|
||
# (代码不变)
|
||
input_cost = (prompt_tokens / 1000000) * self.pri_in
|
||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||
return round(input_cost + output_cost, 6)
|
||
|
||
async def _prepare_request(
|
||
self,
|
||
endpoint: str,
|
||
prompt: str = None,
|
||
image_base64: str = None,
|
||
image_format: str = None,
|
||
payload: dict = None,
|
||
retry_policy: dict = None,
|
||
**kwargs: Any,
|
||
) -> Dict[str, Any]:
|
||
"""配置请求参数,合并实例参数和调用时参数"""
|
||
default_retry = {
|
||
"max_retries": global_config.api_polling_max_retries,
|
||
"base_wait": 10,
|
||
"retry_codes": [429, 413, 500, 503],
|
||
"abort_codes": [400, 401, 402, 403],
|
||
}
|
||
policy = {**default_retry, **(retry_policy or {})}
|
||
|
||
_actual_endpoint = endpoint
|
||
if self.is_gemini:
|
||
action = endpoint.lstrip("/")
|
||
api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}"
|
||
stream_mode = False
|
||
else:
|
||
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||
stream_mode = self.stream
|
||
|
||
call_params = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||
merged_params = {**self.params, **call_params}
|
||
|
||
if payload is None:
|
||
payload = await self._build_payload(prompt, image_base64, image_format, merged_params)
|
||
else:
|
||
logger.debug("使用外部提供的 payload,忽略单次调用参数合并。")
|
||
|
||
if not self.is_gemini and stream_mode:
|
||
payload["stream"] = merged_params.get("stream", stream_mode)
|
||
|
||
return {
|
||
"policy": policy,
|
||
"payload": payload,
|
||
"api_url": api_url,
|
||
"stream_mode": payload.get("stream", False),
|
||
"image_base64": image_base64,
|
||
"image_format": image_format,
|
||
"prompt": prompt,
|
||
}
|
||
|
||
async def _execute_request(
|
||
self,
|
||
endpoint: str,
|
||
prompt: str = None,
|
||
image_base64: str = None,
|
||
image_format: str = None,
|
||
payload: dict = None,
|
||
retry_policy: dict = None,
|
||
response_handler: callable = None,
|
||
user_id: str = "system",
|
||
request_type: str = None,
|
||
**kwargs: Any,
|
||
):
|
||
"""统一请求执行入口, 支持列表 key 切换、代理和单次调用参数覆盖"""
|
||
final_request_type = request_type or kwargs.get("request_type") or self.request_type
|
||
api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||
|
||
request_content = await self._prepare_request(
|
||
endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs
|
||
)
|
||
policy = request_content["policy"]
|
||
api_url = request_content["api_url"]
|
||
actual_payload = request_content["payload"]
|
||
stream_mode = request_content["stream_mode"]
|
||
|
||
use_proxy = False
|
||
current_proxy_url = None
|
||
if self.proxy_url and self.model_name in self.proxy_models_set:
|
||
use_proxy = True
|
||
current_proxy_url = self.proxy_url
|
||
logger.debug(f"模型 {self.model_name}: 将通过代理 {current_proxy_url} 发送请求。")
|
||
elif self.proxy_url:
|
||
logger.debug(f"模型 {self.model_name}: 配置了代理,但此模型不在 PROXY_MODELS 列表中,将不使用代理。")
|
||
else:
|
||
logger.debug(f"模型 {self.model_name}: 未配置或不为此模型使用代理。")
|
||
|
||
current_key = None
|
||
keys_failed_429 = set()
|
||
keys_abandoned_runtime = set()
|
||
key_switch_limit_429 = global_config.api_polling_max_retries
|
||
key_switch_limit_403 = global_config.api_polling_max_retries
|
||
|
||
available_keys_pool = []
|
||
is_key_list = isinstance(self._api_key_config, list)
|
||
|
||
if is_key_list:
|
||
available_keys_pool = list(self._api_key_config)
|
||
if not available_keys_pool:
|
||
logger.error(f"模型 {self.model_name}: 初始化后无可用活动 Keys。")
|
||
raise ValueError(f"模型 {self.model_name}: 无可用活动 Keys。")
|
||
random.shuffle(available_keys_pool)
|
||
key_switch_limit_429 = min(key_switch_limit_429, len(available_keys_pool))
|
||
key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool))
|
||
logger.info(
|
||
f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。"
|
||
)
|
||
elif isinstance(self._api_key_config, str):
|
||
available_keys_pool = [self._api_key_config]
|
||
key_switch_limit_429 = 1
|
||
key_switch_limit_403 = 1
|
||
else:
|
||
logger.error(f"模型 {self.model_name}: 无效的 API Key 配置类型在执行时遇到: {type(self._api_key_config)}")
|
||
raise TypeError(f"模型 {self.model_name}: 无效的 API Key 配置类型")
|
||
|
||
last_exception = None
|
||
|
||
for attempt in range(policy["max_retries"]):
|
||
if available_keys_pool:
|
||
current_key = available_keys_pool.pop(0)
|
||
elif current_key:
|
||
logger.debug(
|
||
f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)"
|
||
)
|
||
else:
|
||
if (
|
||
not self._api_key_config
|
||
or all(
|
||
k in LLMRequest._abandoned_keys_runtime
|
||
for k in self._api_key_config
|
||
if isinstance(self._api_key_config, list)
|
||
)
|
||
or (
|
||
isinstance(self._api_key_config, str)
|
||
and self._api_key_config in LLMRequest._abandoned_keys_runtime
|
||
)
|
||
):
|
||
final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
|
||
logger.critical(final_error_msg)
|
||
raise PermissionDeniedException(final_error_msg)
|
||
else:
|
||
raise RuntimeError(f"模型 {self.model_name}: 无法选择 API key (第 {attempt + 1} 次尝试)")
|
||
|
||
logger.debug(f"模型 {self.model_name}: 尝试使用 Key: ...{current_key[-4:]} (总第 {attempt + 1} 次尝试)")
|
||
|
||
try:
|
||
headers = await self._build_headers(current_key)
|
||
if not self.is_gemini and stream_mode:
|
||
headers["Accept"] = "text/event-stream"
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
post_kwargs = {"headers": headers, "json": actual_payload, "timeout": 60}
|
||
if use_proxy:
|
||
post_kwargs["proxy"] = current_proxy_url
|
||
|
||
async with session.post(api_url, **post_kwargs) as response:
|
||
if response.status == 429 and is_key_list:
|
||
logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。")
|
||
response_text = await response.text()
|
||
logger.debug(
|
||
f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}"
|
||
)
|
||
if current_key not in keys_failed_429:
|
||
keys_failed_429.add(current_key)
|
||
logger.info(
|
||
f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)"
|
||
)
|
||
if available_keys_pool and len(keys_failed_429) < key_switch_limit_429:
|
||
logger.info(" 尝试因 429 切换到下一个可用 Key...")
|
||
raise _SwitchKeyException()
|
||
else:
|
||
logger.warning(" 无更多 Key 可因 429 切换或已达上限。")
|
||
else:
|
||
logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429,按标准重试流程。")
|
||
|
||
elif response.status == 403 and is_key_list:
|
||
logger.error(
|
||
f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。"
|
||
)
|
||
if current_key not in keys_abandoned_runtime:
|
||
keys_abandoned_runtime.add(current_key)
|
||
LLMRequest._abandoned_keys_runtime.add(current_key)
|
||
logger.critical(
|
||
f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!"
|
||
)
|
||
if current_key in available_keys_pool:
|
||
available_keys_pool.remove(current_key)
|
||
if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403:
|
||
logger.info(" 尝试因 403 切换到下一个可用 Key...")
|
||
raise _SwitchKeyException()
|
||
else:
|
||
logger.error(" 无更多 Key 可因 403 切换或已达上限。将中止请求。")
|
||
await response.read()
|
||
raise PermissionDeniedException(
|
||
f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。",
|
||
key_identifier=current_key,
|
||
)
|
||
else:
|
||
logger.error(f" Key ...{current_key[-4:]} 再次遇到 403,这不应发生。中止请求。")
|
||
await response.read()
|
||
raise PermissionDeniedException(
|
||
f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key
|
||
)
|
||
|
||
elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
|
||
await self._handle_error_response(response, attempt, policy, current_key)
|
||
|
||
if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1:
|
||
if response.status not in [429, 403]:
|
||
wait_time = policy["base_wait"] * (2**attempt)
|
||
logger.warning(
|
||
f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试..."
|
||
)
|
||
await asyncio.sleep(wait_time)
|
||
last_exception = RuntimeError(f"重试错误 {response.status}")
|
||
continue
|
||
|
||
if response.status in policy["abort_codes"] or (
|
||
response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1
|
||
):
|
||
if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。"
|
||
)
|
||
# await self._handle_error_response(response, attempt, policy, current_key)
|
||
# await response.read()
|
||
# final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}"
|
||
# logger.error(final_error_msg)
|
||
# raise RequestAbortException(final_error_msg, response)
|
||
|
||
response.raise_for_status()
|
||
result = {}
|
||
if not self.is_gemini and stream_mode:
|
||
result = await self._handle_stream_output(response)
|
||
else:
|
||
result = await response.json()
|
||
|
||
return (
|
||
response_handler(result)
|
||
if response_handler
|
||
else self._default_response_handler(result, user_id, final_request_type, endpoint)
|
||
)
|
||
|
||
except _SwitchKeyException:
|
||
last_exception = _SwitchKeyException()
|
||
logger.debug("捕获到 _SwitchKeyException,立即进行下一次尝试。")
|
||
continue
|
||
except PermissionDeniedException as e:
|
||
logger.error(f"模型 {self.model_name}: 因权限拒绝 (403) 中止请求: {e}")
|
||
if is_key_list and not available_keys_pool and e.key_identifier:
|
||
logger.critical(f" 中止原因是 Key ...{e.key_identifier[-4:]} 触发 403 后已无其他 Key 可用。")
|
||
raise e
|
||
except aiohttp.ClientProxyConnectionError as e:
|
||
logger.error(f"代理连接错误: {e} (代理地址: {current_proxy_url})")
|
||
last_exception = e
|
||
if attempt >= policy["max_retries"] - 1:
|
||
raise RuntimeError(f"代理连接失败达到最大重试次数: {e}") from e
|
||
wait_time = policy["base_wait"] * (2**attempt)
|
||
logger.warning(f"模型 {self.model_name}: 代理连接错误,等待 {wait_time} 秒后重试...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
except aiohttp.ClientConnectorError as e:
|
||
logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})")
|
||
last_exception = e
|
||
if attempt >= policy["max_retries"] - 1:
|
||
raise RuntimeError(f"网络连接失败达到最大重试次数: {e}") from e
|
||
wait_time = policy["base_wait"] * (2**attempt)
|
||
logger.warning(f"模型 {self.model_name}: 网络连接错误,等待 {wait_time} 秒后重试...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
except (PayLoadTooLargeError, RequestAbortException) as e:
|
||
# (代码不变)
|
||
logger.error(f"模型 {self.model_name}: 请求处理中遇到关键错误,将中止: {e}")
|
||
raise e
|
||
except Exception as e:
|
||
# (代码不变)
|
||
last_exception = e
|
||
logger.warning(
|
||
f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}"
|
||
)
|
||
|
||
if attempt >= policy["max_retries"] - 1:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。"
|
||
)
|
||
else:
|
||
try:
|
||
temp_request_content = {
|
||
"policy": policy,
|
||
"payload": actual_payload,
|
||
"api_url": api_url,
|
||
"stream_mode": stream_mode,
|
||
"image_base64": image_base64,
|
||
"image_format": image_format,
|
||
"prompt": prompt,
|
||
}
|
||
handled_payload, count_delta = await self._handle_exception(
|
||
e, attempt, temp_request_content, merged_params=api_kwargs
|
||
)
|
||
if handled_payload:
|
||
actual_payload = handled_payload
|
||
logger.info(f"模型 {self.model_name}: 异常处理更新了 payload,将使用当前 Key 重试。")
|
||
|
||
wait_time = policy["base_wait"] * (2**attempt)
|
||
logger.warning(f"模型 {self.model_name}: 等待 {wait_time} 秒后重试...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
|
||
except (RequestAbortException, PermissionDeniedException) as abort_exception:
|
||
logger.error(f"模型 {self.model_name}: 异常处理判断需要中止请求: {abort_exception}")
|
||
raise abort_exception
|
||
except RuntimeError as rt_error:
|
||
logger.error(f"模型 {self.model_name}: 异常处理遇到运行时错误: {rt_error}")
|
||
raise rt_error
|
||
|
||
# --- 循环结束 ---
|
||
logger.error(f"模型 {self.model_name}: 所有重试尝试 ({policy['max_retries']} 次) 均失败。")
|
||
if last_exception:
|
||
if isinstance(last_exception, PermissionDeniedException):
|
||
logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}")
|
||
raise last_exception
|
||
logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}")
|
||
raise RuntimeError(
|
||
f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}"
|
||
) from last_exception
|
||
else:
|
||
if not available_keys_pool and keys_abandoned_runtime:
|
||
final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
|
||
logger.critical(final_error_msg)
|
||
raise PermissionDeniedException(final_error_msg)
|
||
else:
|
||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败,原因未知。")
|
||
|
||
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
|
||
"""处理 OpenAI 兼容的流式输出"""
|
||
# (代码不变)
|
||
flag_delta_content_finished = False
|
||
accumulated_content = ""
|
||
usage = None
|
||
reasoning_content = ""
|
||
content = ""
|
||
tool_calls = None
|
||
|
||
async for line_bytes in response.content:
|
||
try:
|
||
line = line_bytes.decode("utf-8").strip()
|
||
if not line:
|
||
continue
|
||
if line.startswith("data:"):
|
||
data_str = line[5:].strip()
|
||
if data_str == "[DONE]":
|
||
break
|
||
try:
|
||
chunk = json.loads(data_str)
|
||
if flag_delta_content_finished:
|
||
chunk_usage = chunk.get("usage", None)
|
||
if chunk_usage:
|
||
usage = chunk_usage
|
||
else:
|
||
delta = chunk["choices"][0]["delta"]
|
||
delta_content = delta.get("content")
|
||
if delta_content is None:
|
||
delta_content = ""
|
||
accumulated_content += delta_content
|
||
|
||
if "tool_calls" in delta:
|
||
if tool_calls is None:
|
||
tool_calls = []
|
||
for tc in delta["tool_calls"]:
|
||
new_tc = dict(tc)
|
||
if "function" in new_tc and "arguments" not in new_tc["function"]:
|
||
new_tc["function"]["arguments"] = ""
|
||
tool_calls.append(new_tc)
|
||
else:
|
||
for i, tc_delta in enumerate(delta["tool_calls"]):
|
||
if (
|
||
i < len(tool_calls)
|
||
and "function" in tc_delta
|
||
and "arguments" in tc_delta["function"]
|
||
):
|
||
if "arguments" in tool_calls[i]["function"]:
|
||
tool_calls[i]["function"]["arguments"] += tc_delta["function"][
|
||
"arguments"
|
||
]
|
||
else:
|
||
tool_calls[i]["function"]["arguments"] = tc_delta["function"][
|
||
"arguments"
|
||
]
|
||
elif i >= len(tool_calls):
|
||
new_tc = dict(tc_delta)
|
||
if "function" in new_tc and "arguments" not in new_tc["function"]:
|
||
new_tc["function"]["arguments"] = ""
|
||
tool_calls.append(new_tc)
|
||
|
||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||
if delta.get("reasoning_content", None):
|
||
reasoning_content += delta["reasoning_content"]
|
||
if finish_reason == "stop" or finish_reason == "tool_calls":
|
||
chunk_usage = chunk.get("usage", None)
|
||
if chunk_usage:
|
||
usage = chunk_usage
|
||
break
|
||
flag_delta_content_finished = True
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"模型 {self.model_name} 解析流式 JSON 错误: {e} - data: '{data_str}'")
|
||
except Exception as e:
|
||
logger.exception(f"模型 {self.model_name} 解析流式输出块错误: {str(e)}")
|
||
except UnicodeDecodeError as e:
|
||
logger.warning(f"模型 {self.model_name} 流式输出解码错误: {e} - bytes: {line_bytes[:50]}...")
|
||
except Exception as e:
|
||
if isinstance(e, GeneratorExit):
|
||
log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
|
||
else:
|
||
log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
|
||
logger.warning(log_content)
|
||
try:
|
||
await response.release()
|
||
except Exception as cleanup_error:
|
||
logger.error(f"清理资源时发生错误: {cleanup_error}")
|
||
content = accumulated_content
|
||
break
|
||
if not content and accumulated_content:
|
||
content = accumulated_content
|
||
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||
if think_match:
|
||
reasoning_content = think_match.group(1).strip()
|
||
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||
|
||
message = {
|
||
"content": content,
|
||
"reasoning_content": reasoning_content,
|
||
}
|
||
|
||
if tool_calls:
|
||
message["tool_calls"] = tool_calls
|
||
|
||
result = {
|
||
"choices": [{"message": message}],
|
||
"usage": usage,
|
||
}
|
||
return result
|
||
|
||
async def _handle_error_response(
|
||
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any], current_key: str = None
|
||
) -> None:
|
||
"""处理 HTTP 错误响应 (区分 403 和其他错误)"""
|
||
# (代码不变)
|
||
status = response.status
|
||
try:
|
||
error_text = await response.text()
|
||
except Exception as e:
|
||
error_text = f"(无法读取响应体: {e})"
|
||
|
||
if status == 403:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 遇到 403 (权限拒绝) 错误。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
|
||
f"响应: {error_text[:200]}"
|
||
)
|
||
raise PermissionDeniedException(f"模型禁止访问 ({status})", key_identifier=current_key)
|
||
|
||
elif status in policy["retry_codes"] and status != 429:
|
||
if status == 413:
|
||
logger.warning(
|
||
f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩..."
|
||
)
|
||
raise PayLoadTooLargeError("请求体过大")
|
||
elif status in [500, 503]:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 服务器内部错误或过载 ({status})。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
|
||
f"响应: {error_text[:200]}"
|
||
)
|
||
return
|
||
else:
|
||
logger.warning(
|
||
f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}"
|
||
)
|
||
return
|
||
|
||
elif status in policy["abort_codes"]:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 遇到需要中止的错误码: {status} - {error_code_mapping.get(status, '未知错误')}. "
|
||
f"Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
|
||
)
|
||
raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
|
||
else:
|
||
logger.error(
|
||
f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
|
||
)
|
||
try:
|
||
response.raise_for_status()
|
||
raise RequestAbortException(f"未处理的错误状态码 {status}", response)
|
||
except aiohttp.ClientResponseError as e:
|
||
raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e
|
||
|
||
async def _handle_exception(
|
||
self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None
|
||
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
|
||
"""处理非 HTTP 错误,支持使用合并后的参数重建 payload"""
|
||
policy = request_content["policy"]
|
||
payload = request_content["payload"]
|
||
_wait_time = policy["base_wait"] * (2**retry_count)
|
||
keep_request = False
|
||
if retry_count < policy["max_retries"] - 1:
|
||
keep_request = True
|
||
|
||
params_for_rebuild = merged_params if merged_params is not None else payload
|
||
|
||
if isinstance(exception, PayLoadTooLargeError):
|
||
if keep_request:
|
||
logger.warning("请求体过大 (PayLoadTooLargeError),尝试压缩图片...")
|
||
image_base64 = request_content.get("image_base64")
|
||
if image_base64:
|
||
compressed_image_base64 = compress_base64_image_by_scale(image_base64)
|
||
if compressed_image_base64 != image_base64:
|
||
new_payload = await self._build_payload(
|
||
request_content["prompt"],
|
||
compressed_image_base64,
|
||
request_content["image_format"],
|
||
params_for_rebuild,
|
||
)
|
||
logger.info("图片压缩成功,将使用压缩后的图片重试。")
|
||
return new_payload, 0
|
||
else:
|
||
logger.warning("图片压缩未改变大小或失败。")
|
||
else:
|
||
logger.warning("请求体过大但请求中不包含图片,无法压缩。")
|
||
return None, 0
|
||
else:
|
||
logger.error("达到最大重试次数,请求体仍然过大。")
|
||
raise RuntimeError("请求体过大,压缩或重试后仍然失败。") from exception
|
||
|
||
elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)):
|
||
if keep_request:
|
||
logger.error(f"模型 {self.model_name} 网络错误: {str(exception)}")
|
||
return None, 0
|
||
else:
|
||
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
|
||
raise RuntimeError(f"网络请求失败: {str(exception)}") from exception
|
||
|
||
elif isinstance(exception, aiohttp.ClientResponseError):
|
||
if keep_request:
|
||
logger.error(
|
||
f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}"
|
||
)
|
||
try:
|
||
error_text = await exception.response.text() if hasattr(exception, "response") else str(exception)
|
||
logger.error(f"服务器错误响应详情: {error_text[:500]}")
|
||
except Exception as parse_err:
|
||
logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}")
|
||
return None, 0
|
||
else:
|
||
logger.critical(
|
||
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
|
||
)
|
||
current_key_placeholder = request_content.get("current_key", "******")
|
||
handled_payload = await _safely_record(request_content, payload)
|
||
logger.critical(
|
||
f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
|
||
)
|
||
raise RuntimeError(
|
||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||
) from exception
|
||
|
||
else:
|
||
if keep_request:
|
||
logger.error(
|
||
f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}"
|
||
)
|
||
return None, 0
|
||
else:
|
||
logger.critical(
|
||
f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}"
|
||
)
|
||
current_key_placeholder = request_content.get("current_key", "******")
|
||
handled_payload = await _safely_record(request_content, payload)
|
||
logger.critical(
|
||
f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
|
||
)
|
||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception
|
||
|
||
async def _transform_parameters(self, merged_params: dict) -> dict:
|
||
"""根据模型名称转换合并后的参数,并移除内部参数"""
|
||
# (代码不变)
|
||
new_params = dict(merged_params)
|
||
new_params.pop("request_type", None)
|
||
|
||
if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
|
||
new_params.pop("temperature", None)
|
||
if "max_tokens" in new_params:
|
||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||
elif self.is_gemini:
|
||
gen_config = new_params.get("generationConfig", {})
|
||
if "temperature" in new_params:
|
||
gen_config["temperature"] = new_params.pop("temperature")
|
||
if "max_tokens" in new_params:
|
||
gen_config["maxOutputTokens"] = new_params.pop("max_tokens")
|
||
if "top_p" in new_params:
|
||
gen_config["topP"] = new_params.pop("top_p")
|
||
if "top_k" in new_params:
|
||
gen_config["topK"] = new_params.pop("top_k")
|
||
|
||
if gen_config:
|
||
new_params["generationConfig"] = gen_config
|
||
|
||
new_params.pop("frequency_penalty", None)
|
||
new_params.pop("presence_penalty", None)
|
||
new_params.pop("max_completion_tokens", None)
|
||
|
||
return new_params
|
||
|
||
async def _build_payload(
|
||
self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None
|
||
) -> dict:
|
||
"""构建请求体 (区分 Gemini 和 OpenAI),使用合并和转换后的参数"""
|
||
# (代码不变)
|
||
if merged_params is None:
|
||
merged_params = self.params
|
||
|
||
params_copy = await self._transform_parameters(merged_params)
|
||
|
||
if self.is_gemini:
|
||
parts = []
|
||
if prompt:
|
||
parts.append({"text": prompt})
|
||
if image_base64:
|
||
mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}"
|
||
parts.append({"inlineData": {"mimeType": mime_type, "data": image_base64}})
|
||
payload = {"contents": [{"parts": parts}], **params_copy}
|
||
payload.pop("model", None)
|
||
# --- 添加 Gemini 安全设置 ---
|
||
safety_settings = [
|
||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||
]
|
||
payload["safetySettings"] = safety_settings
|
||
logger.debug(f"模型 {self.model_name}: 已为 Gemini 函数调用请求添加 safetySettings (BLOCK_NONE)。")
|
||
# --- 结束添加安全设置 ---
|
||
|
||
else:
|
||
if image_base64:
|
||
messages = [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"
|
||
},
|
||
},
|
||
],
|
||
}
|
||
]
|
||
else:
|
||
messages = [{"role": "user", "content": prompt}]
|
||
|
||
payload = {
|
||
"model": self.model_name,
|
||
"messages": messages,
|
||
**params_copy,
|
||
}
|
||
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||
if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
|
||
payload["max_tokens"] = global_config.model_max_output_length
|
||
if "max_completion_tokens" in payload:
|
||
payload["max_tokens"] = payload.pop("max_completion_tokens")
|
||
|
||
return payload
|
||
|
||
def _default_response_handler(
|
||
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
|
||
) -> Tuple:
|
||
"""默认响应解析 (区分 Gemini 和 OpenAI),并处理函数/工具调用"""
|
||
content = "没有返回结果"
|
||
reasoning_content = ""
|
||
tool_calls = None # OpenAI 格式
|
||
function_call = None # Gemini 格式
|
||
prompt_tokens = 0
|
||
completion_tokens = 0
|
||
total_tokens = 0
|
||
|
||
if self.is_gemini:
|
||
# --- 解析 Gemini 响应 ---
|
||
try:
|
||
if "candidates" in result and result["candidates"]:
|
||
candidate = result["candidates"][0]
|
||
# 检查是否有 content 和 parts
|
||
if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]:
|
||
# 查找 functionCall 或 text 部分
|
||
final_text_parts = []
|
||
for part in candidate["content"]["parts"]:
|
||
if "functionCall" in part:
|
||
function_call = part["functionCall"] # 获取 Gemini 的 functionCall
|
||
# Gemini functionCall 通常不与 text 一起返回,这里假设只处理 functionCall
|
||
break # 找到 functionCall 就停止处理 parts
|
||
elif "text" in part:
|
||
final_text_parts.append(part.get("text", ""))
|
||
|
||
if not function_call: # 如果没有 functionCall,处理 text
|
||
raw_content = "".join(final_text_parts).strip()
|
||
content, reasoning = self._extract_reasoning(raw_content)
|
||
reasoning_content = reasoning
|
||
# else: function_call 已获取,content 留空或设为特定值
|
||
|
||
else:
|
||
content = "Gemini响应中缺少 content 或 parts"
|
||
logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}")
|
||
|
||
finish_reason = candidate.get("finishReason")
|
||
if finish_reason == "SAFETY":
|
||
logger.warning(f"模型 {self.model_name}: Gemini 响应因安全设置被阻止。")
|
||
content = "响应内容因安全原因被过滤。"
|
||
elif finish_reason == "RECITATION":
|
||
logger.warning(f"模型 {self.model_name}: Gemini 响应因引用限制被阻止。")
|
||
content = "响应内容因引用限制被过滤。"
|
||
elif finish_reason == "OTHER":
|
||
logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。")
|
||
# finishReason == "TOOL_CODE" or "FUNCTION_CALL" 是正常情况
|
||
|
||
usage = result.get("usageMetadata", {})
|
||
if usage:
|
||
prompt_tokens = usage.get("promptTokenCount", 0)
|
||
completion_tokens = usage.get("candidatesTokenCount", 0)
|
||
total_tokens = usage.get("totalTokenCount", 0)
|
||
if completion_tokens == 0 and total_tokens > 0:
|
||
completion_tokens = total_tokens - prompt_tokens
|
||
else:
|
||
logger.warning(f"模型 {self.model_name} (Gemini) 的响应中缺少 'usageMetadata' 信息。")
|
||
|
||
except Exception as e:
|
||
logger.error(f"解析 Gemini 响应出错: {e} - 响应: {result}")
|
||
content = "解析 Gemini 响应时出错"
|
||
|
||
else:
|
||
# --- 解析 OpenAI 兼容响应 ---
|
||
# (代码不变)
|
||
if "choices" in result and result["choices"]:
|
||
message = result["choices"][0].get("message", {})
|
||
raw_content = message.get("content", "")
|
||
content, reasoning = self._extract_reasoning(raw_content if raw_content else "")
|
||
|
||
explicit_reasoning = message.get("model_extra", {}).get("reasoning_content", "")
|
||
if not explicit_reasoning:
|
||
explicit_reasoning = message.get("reasoning_content", "")
|
||
reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
|
||
|
||
tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls
|
||
|
||
usage = result.get("usage", {})
|
||
if usage:
|
||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||
completion_tokens = usage.get("completion_tokens", 0)
|
||
total_tokens = usage.get("total_tokens", 0)
|
||
else:
|
||
logger.warning(f"模型 {self.model_name} (OpenAI) 的响应中缺少 'usage' 信息。")
|
||
else:
|
||
logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}")
|
||
|
||
# --- 记录 Token 使用情况 ---
|
||
# (代码不变)
|
||
if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0:
|
||
self._record_usage(
|
||
prompt_tokens=prompt_tokens,
|
||
completion_tokens=completion_tokens,
|
||
total_tokens=total_tokens,
|
||
user_id=user_id,
|
||
request_type=request_type,
|
||
endpoint=endpoint,
|
||
)
|
||
else:
|
||
logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
|
||
|
||
# --- 返回结果 (统一格式) ---
|
||
final_tool_calls = None
|
||
if tool_calls: # 来自 OpenAI
|
||
final_tool_calls = tool_calls
|
||
logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}")
|
||
elif function_call: # 来自 Gemini
|
||
logger.debug(f"检测到 Gemini 函数调用: {function_call}")
|
||
# 将 Gemini functionCall 转换为 OpenAI tool_calls 格式
|
||
# 注意: Gemini 的 functionCall 没有显式的 id 和 type,需要模拟
|
||
final_tool_calls = [
|
||
{
|
||
"id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID
|
||
"type": "function",
|
||
"function": {
|
||
"name": function_call.get("name"),
|
||
# Gemini 的参数在 'args' 中,OpenAI 在 'arguments' (通常是 JSON 字符串)
|
||
# 需要将 Gemini 的 dict 参数转换为 JSON 字符串
|
||
"arguments": json.dumps(function_call.get("args", {})),
|
||
},
|
||
}
|
||
]
|
||
logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}")
|
||
|
||
if final_tool_calls:
|
||
# 如果有工具/函数调用,通常 content 为空或包含思考过程,这里返回转换后的调用信息
|
||
return content, reasoning_content, final_tool_calls
|
||
else:
|
||
# 没有工具/函数调用,返回普通文本响应
|
||
return content, reasoning_content
|
||
|
||
@staticmethod
|
||
def _extract_reasoning(content: str) -> Tuple[str, str]:
|
||
"""CoT思维链提取"""
|
||
# (代码不变)
|
||
if not content:
|
||
return "", ""
|
||
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||
cleaned_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
||
if match:
|
||
reasoning = match.group(1).strip()
|
||
else:
|
||
reasoning = ""
|
||
return cleaned_content, reasoning
|
||
|
||
async def _build_headers(self, api_key: str, no_key: bool = False) -> dict:
|
||
"""构建请求头 (区分 Gemini 和 OpenAI)"""
|
||
# (代码不变)
|
||
if no_key:
|
||
if self.is_gemini:
|
||
return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
|
||
else:
|
||
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
|
||
else:
|
||
if not api_key:
|
||
logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
|
||
raise ValueError("无效的 API key 提供给 _build_headers。")
|
||
|
||
if self.is_gemini:
|
||
return {"x-goog-api-key": api_key, "Content-Type": "application/json"}
|
||
else:
|
||
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||
|
||
async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
|
||
"""根据输入的提示生成模型的异步响应,支持覆盖参数"""
|
||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||
response = await self._execute_request(
|
||
endpoint=endpoint, prompt=prompt, user_id=user_id, request_type="chat", **kwargs
|
||
)
|
||
if len(response) == 3:
|
||
content, reasoning_content, tool_calls = response
|
||
return content, reasoning_content, self.model_name, tool_calls
|
||
else:
|
||
content, reasoning_content = response
|
||
return content, reasoning_content, self.model_name
|
||
|
||
async def generate_response_for_image(
|
||
self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs
|
||
) -> Tuple:
|
||
"""根据输入的提示和图片生成模型的异步响应,支持覆盖参数"""
|
||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||
response = await self._execute_request(
|
||
endpoint=endpoint,
|
||
prompt=prompt,
|
||
image_base64=image_base64,
|
||
image_format=image_format,
|
||
user_id=user_id,
|
||
request_type="vision",
|
||
**kwargs,
|
||
)
|
||
# _default_response_handler 现在总是返回至少2个值
|
||
if len(response) == 3:
|
||
return response # content, reasoning, tool_calls (tool_calls 可能为 None)
|
||
elif len(response) == 2:
|
||
content, reasoning = response
|
||
return content, reasoning # 对于 vision 请求,通常没有 tool_calls
|
||
else:
|
||
logger.error(f"来自 _default_response_handler 的意外响应格式: {response}")
|
||
return "处理响应出错", ""
|
||
|
||
async def generate_response_async(
|
||
self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs
|
||
) -> Union[str, Tuple]:
|
||
"""异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
|
||
# (代码不变)
|
||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||
response = await self._execute_request(
|
||
endpoint=endpoint,
|
||
prompt=prompt,
|
||
payload=None,
|
||
retry_policy=None,
|
||
response_handler=None,
|
||
user_id=user_id,
|
||
request_type=request_type,
|
||
**kwargs,
|
||
)
|
||
return response
|
||
|
||
# 修改:实现 Gemini Function Calling 的 Payload 构建
|
||
async def generate_response_tool_async(
|
||
self, prompt: str, tools: list, user_id: str = "system", **kwargs
|
||
) -> tuple[str, str, list | None]:
|
||
"""异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数和 Gemini 函数调用"""
|
||
|
||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||
merged_params = {**self.params, **kwargs}
|
||
transformed_params = await self._transform_parameters(merged_params) # 清理 request_type 等
|
||
|
||
payload = None
|
||
|
||
if self.is_gemini:
|
||
# --- 构建 Gemini Function Calling Payload ---
|
||
logger.debug(f"为 Gemini ({self.model_name}) 构建函数调用请求。")
|
||
# 1. 转换工具定义 (OpenAI -> Gemini)
|
||
# OpenAI tool format: [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
|
||
# Gemini tool format: [{"functionDeclarations": [{"name": ..., "description": ..., "parameters": ...}]}]
|
||
function_declarations = []
|
||
if tools:
|
||
for tool in tools:
|
||
if tool.get("type") == "function" and "function" in tool:
|
||
func_def = tool["function"]
|
||
# Gemini parameters 使用 OpenAPI Schema,与 OpenAI 基本兼容
|
||
function_declarations.append(
|
||
{
|
||
"name": func_def.get("name"),
|
||
"description": func_def.get("description", ""), # Description is required for Gemini
|
||
"parameters": func_def.get(
|
||
"parameters", {"type": "object", "properties": {}}
|
||
), # Ensure parameters exist
|
||
}
|
||
)
|
||
else:
|
||
logger.warning(f"跳过不支持的工具类型或格式: {tool}")
|
||
|
||
if not function_declarations:
|
||
logger.error("没有有效的函数声明可用于 Gemini 请求。")
|
||
return "没有提供有效的函数定义", "", None
|
||
|
||
gemini_tools = [{"functionDeclarations": function_declarations}]
|
||
|
||
# 2. 构建 Gemini Payload
|
||
# parts = [{"text": prompt}] # 初始 parts
|
||
payload = {
|
||
"contents": [{"parts": [{"text": prompt}]}], # 包含用户提示
|
||
"tools": gemini_tools,
|
||
# toolConfig 默认是 AUTO,可以根据需要从 kwargs 获取或硬编码
|
||
# "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # 例如强制调用
|
||
**transformed_params, # 合并其他转换后的参数 (如 generationConfig)
|
||
}
|
||
payload.pop("model", None) # Gemini 不在顶层传 model
|
||
payload.pop("messages", None) # 移除 OpenAI 特有的 messages
|
||
payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice
|
||
|
||
logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}")
|
||
|
||
else:
|
||
# --- 构建 OpenAI Tool Calling Payload ---
|
||
# (逻辑不变)
|
||
logger.debug(f"为 OpenAI 兼容模型 ({self.model_name}) 构建工具调用请求。")
|
||
payload = {
|
||
"model": self.model_name,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
**transformed_params,
|
||
"tools": tools,
|
||
"tool_choice": transformed_params.get("tool_choice", "auto"),
|
||
}
|
||
if "max_completion_tokens" in payload:
|
||
payload["max_tokens"] = payload.pop("max_completion_tokens")
|
||
if "max_tokens" not in payload:
|
||
payload["max_tokens"] = global_config.model_max_output_length
|
||
|
||
# --- 执行请求 ---
|
||
if payload is None:
|
||
logger.error("未能构建有效的 API 请求 payload。")
|
||
return "内部错误:无法构建请求", "", None
|
||
|
||
response = await self._execute_request(
|
||
endpoint=endpoint,
|
||
payload=payload,
|
||
prompt=prompt, # prompt 仍然需要,用于可能的重试
|
||
user_id=user_id,
|
||
request_type="tool_call",
|
||
**kwargs, # 传递原始 kwargs 以便在重试时重新合并
|
||
)
|
||
|
||
# _default_response_handler 现在会处理 Gemini functionCall 并统一格式
|
||
logger.debug(f"模型 {self.model_name} 工具/函数调用返回结果: {response}")
|
||
|
||
if isinstance(response, tuple) and len(response) == 3:
|
||
content, reasoning_content, final_tool_calls = response
|
||
# final_tool_calls 已经是统一的 OpenAI 格式
|
||
return content, reasoning_content, final_tool_calls
|
||
elif isinstance(response, tuple) and len(response) == 2:
|
||
content, reasoning_content = response
|
||
logger.debug("收到普通响应,无工具/函数调用")
|
||
return content, reasoning_content, None
|
||
else:
|
||
logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}")
|
||
return "处理响应时出错", "", None
|
||
|
||
async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]:
|
||
"""异步方法:获取文本的embedding向量,支持覆盖参数 (Gemini Embedding 需注意模型名称)"""
|
||
# (代码不变)
|
||
if len(text) < 1:
|
||
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
|
||
return None
|
||
|
||
api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||
|
||
if self.is_gemini:
|
||
endpoint = ":embedContent"
|
||
payload = {"model": f"models/{self.model_name}", "content": {"parts": [{"text": text}]}, **api_kwargs}
|
||
payload.pop("encoding_format", None)
|
||
payload.pop("input", None)
|
||
|
||
else:
|
||
endpoint = "/embeddings"
|
||
payload = {"model": self.model_name, "input": text, "encoding_format": "float", **api_kwargs}
|
||
payload.pop("content", None)
|
||
payload.pop("taskType", None)
|
||
|
||
def embedding_handler(result):
|
||
# (代码不变)
|
||
embedding_value = None
|
||
prompt_tokens = 0
|
||
completion_tokens = 0
|
||
total_tokens = 0
|
||
|
||
if self.is_gemini:
|
||
if "embedding" in result and "value" in result["embedding"]:
|
||
embedding_value = result["embedding"]["value"]
|
||
logger.warning(f"模型 {self.model_name} (Gemini Embedding): 响应中未找到明确的 token 使用信息。")
|
||
else:
|
||
if "data" in result and len(result["data"]) > 0:
|
||
embedding_value = result["data"][0].get("embedding", None)
|
||
usage = result.get("usage", {})
|
||
if usage:
|
||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||
total_tokens = usage.get("total_tokens", 0)
|
||
else:
|
||
logger.warning(f"模型 {self.model_name} (OpenAI Embedding) 的响应中缺少 'usage' 信息。")
|
||
|
||
self._record_usage(
|
||
prompt_tokens=prompt_tokens,
|
||
completion_tokens=completion_tokens,
|
||
total_tokens=total_tokens,
|
||
user_id=user_id,
|
||
request_type="embedding",
|
||
endpoint=endpoint,
|
||
)
|
||
return embedding_value
|
||
|
||
embedding = await self._execute_request(
|
||
endpoint=endpoint,
|
||
payload=payload,
|
||
prompt=text,
|
||
retry_policy={"max_retries": 2, "base_wait": 6},
|
||
response_handler=embedding_handler,
|
||
user_id=user_id,
|
||
request_type="embedding",
|
||
**api_kwargs,
|
||
)
|
||
return embedding
|
||
|
||
|
||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||
"""压缩base64格式的图片到指定大小"""
|
||
# (代码不变)
|
||
try:
|
||
image_data = base64.b64decode(base64_data)
|
||
if len(image_data) <= target_size * 1.05:
|
||
logger.info(f"图片大小 {len(image_data) / 1024:.1f}KB 已足够小,无需压缩。")
|
||
return base64_data
|
||
img = Image.open(io.BytesIO(image_data))
|
||
img_format = img.format
|
||
original_width, original_height = img.size
|
||
scale = max(0.2, min(1.0, (target_size / len(image_data)) ** 0.5))
|
||
new_width = max(1, int(original_width * scale))
|
||
new_height = max(1, int(original_height * scale))
|
||
output_buffer = io.BytesIO()
|
||
save_format = img_format # Default to original format
|
||
|
||
if getattr(img, "is_animated", False) and img.n_frames > 1:
|
||
frames = []
|
||
durations = []
|
||
loop = img.info.get("loop", 0)
|
||
disposal = img.info.get("disposal", 2)
|
||
logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...")
|
||
for frame_idx in range(img.n_frames):
|
||
img.seek(frame_idx)
|
||
current_duration = img.info.get("duration", 100)
|
||
durations.append(current_duration)
|
||
new_frame = img.convert("RGBA").copy()
|
||
resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||
frames.append(resized_frame)
|
||
if frames:
|
||
frames[0].save(
|
||
output_buffer,
|
||
format="GIF",
|
||
save_all=True,
|
||
append_images=frames[1:],
|
||
optimize=False,
|
||
duration=durations,
|
||
loop=loop,
|
||
disposal=disposal,
|
||
transparency=img.info.get("transparency", None),
|
||
background=img.info.get("background", None),
|
||
)
|
||
save_format = "GIF"
|
||
else:
|
||
logger.warning("未能处理 GIF 帧。")
|
||
return base64_data
|
||
else:
|
||
if img.mode in ("RGBA", "LA") or "transparency" in img.info:
|
||
resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||
save_format = "PNG"
|
||
save_params = {"optimize": True}
|
||
else:
|
||
resized_img = img.convert("RGB").resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||
if img_format and img_format.upper() == "JPEG":
|
||
save_format = "JPEG"
|
||
save_params = {"quality": 85, "optimize": True}
|
||
else:
|
||
save_format = "PNG"
|
||
save_params = {"optimize": True}
|
||
resized_img.save(output_buffer, format=save_format, **save_params)
|
||
|
||
compressed_data = output_buffer.getvalue()
|
||
logger.success(
|
||
f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})"
|
||
)
|
||
logger.info(
|
||
f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)"
|
||
)
|
||
if len(compressed_data) < len(image_data) * 0.95:
|
||
return base64.b64encode(compressed_data).decode("utf-8")
|
||
else:
|
||
logger.info("压缩效果不明显或反而增大,返回原始图片。")
|
||
return base64_data
|
||
except Exception as e:
|
||
logger.error(f"压缩图片失败: {str(e)}")
|
||
import traceback
|
||
|
||
logger.error(traceback.format_exc())
|
||
return base64_data
|