修复兼容

pull/937/head
Bakadax 2025-05-04 02:06:15 +08:00
parent a253022818
commit 7778b464b1
1 changed files with 169 additions and 102 deletions

View File

@ -37,7 +37,7 @@ logger = get_module_logger("model_utils")
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
# (代码不变)
def __init__(self, message: str):
super().__init__(message)
self.message = message
@ -48,7 +48,7 @@ class PayLoadTooLargeError(Exception):
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
# (代码不变)
def __init__(self, message: str, response: ClientResponse):
super().__init__(message)
self.message = message
@ -60,7 +60,7 @@ class RequestAbortException(Exception):
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
# (代码不变)
def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符
super().__init__(message)
self.message = message
@ -73,11 +73,13 @@ class PermissionDeniedException(Exception):
# 新增:用于内部标记需要切换 Key 的异常
class _SwitchKeyException(Exception):
"""内部异常用于标记需要切换Key并且跳过标准等待时间."""
# (代码不变)
pass
# 常见Error Code Mapping
error_code_mapping = {
# (代码不变)
400: "参数不正确",
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效
402: "账号余额不足",
@ -91,12 +93,10 @@ error_code_mapping = {
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
"""安全地记录请求内容,隐藏敏感信息"""
# (代码不变)
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
# 检查是否为 Gemini 载荷
is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
# 创建 payload 的副本进行修改,避免影响原始对象
safe_payload = json.loads(json.dumps(payload)) if payload else {}
if (
@ -104,7 +104,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
and safe_payload
and isinstance(safe_payload, dict)
):
# OpenAI 格式处理
if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
content = safe_payload["messages"][0]["content"]
@ -113,49 +112,29 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# Gemini 格式处理 (假设图片在 parts 里)
elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0:
if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]:
parts = safe_payload["contents"][0]["parts"]
# 查找图片部分 (通常是 inlineData)
for i, part in enumerate(parts):
if isinstance(part, dict) and "inlineData" in part:
# 假设 inlineData 包含 base64 和 mime_type
safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = f"{image_base64[:10]}...{image_base64[-10:]}"
break # 只处理第一个找到的图片
break
return safe_payload
class LLMRequest:
# 定义需要转换的模型列表,作为类变量避免重复 (OpenAI 特有参数转换)
# (代码不变)
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31o4-mini",
"o4-mini-2025-04-16",
"o1", "o1-2024-12-17", "o1-mini", "o1-mini-2024-09-12", "o1-preview",
"o1-preview-2024-09-12", "o1-pro", "o1-pro-2025-03-19", "o3", "o3-2025-04-16",
"o3-mini", "o3-mini-2025-01-31o4-mini", "o4-mini-2025-04-16",
]
# 类变量,用于存储运行时发现的已失效 Key (避免在同一次运行中重复尝试)
_abandoned_keys_runtime: Set[str] = set()
def __init__(self, model: dict, **kwargs):
"""
初始化 LLMRequest 实例
Args:
model (dict): 包含模型配置的字典应包含 'name', 'key', 'base_url' 等键
**kwargs: 其他传递给模型 API 的参数 ( temperature, max_tokens)作为默认参数
"""
"""初始化 LLMRequest 实例"""
# (代码不变)
self.model_key_name = model["key"]
self.model_name: str = model["name"]
self.params = kwargs
@ -165,11 +144,8 @@ class LLMRequest:
self.request_type = model.get("request_type", "default")
try:
# --- 加载 API Key 和 Base URL ---
raw_api_key_config = os.environ[self.model_key_name]
self.base_url = os.environ[model["base_url"]]
# --- 判断是否为 Gemini 模型 ---
self.is_gemini = "googleapis.com" in self.base_url.lower()
if self.is_gemini:
logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})")
@ -177,8 +153,7 @@ class LLMRequest:
logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。")
self.stream = False
# --- 解析和过滤 API Keys ---
# (代码不变)
# 解析和过滤 API Keys (代码不变)
parsed_keys = []
is_list_config = False
try:
@ -241,8 +216,7 @@ class LLMRequest:
raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
# --- 加载代理配置 ---
# (代码不变)
# 加载代理配置 (代码不变)
self.proxy_url = None
self.proxy_models_set = set()
proxy_host = os.environ.get("PROXY_HOST")
@ -253,13 +227,13 @@ class LLMRequest:
try:
int(proxy_port)
self.proxy_url = f"http://{proxy_host}:{proxy_port}"
logger.info(f"代理已配置: {self.proxy_url}")
logger.debug(f"代理已配置: {self.proxy_url}")
if proxy_models_str:
try:
cleaned_str = proxy_models_str.strip('\'"')
self.proxy_models_set = {model_name.strip() for model_name in cleaned_str.split(',') if model_name.strip()}
logger.info(f"以下模型将使用代理: {self.proxy_models_set}")
logger.debug(f"以下模型将使用代理: {self.proxy_models_set}")
except Exception as e:
logger.error(f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。")
self.proxy_models_set = set()
@ -594,7 +568,7 @@ class LLMRequest:
await asyncio.sleep(wait_time)
continue
except aiohttp.ClientConnectorError as e:
# (代码不变)
# (代码不变)
logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})")
last_exception = e
if attempt >= policy["max_retries"] - 1:
@ -889,21 +863,15 @@ class LLMRequest:
async def _transform_parameters(self, merged_params: dict) -> dict:
"""根据模型名称转换合并后的参数,并移除内部参数"""
# (代码不变)
new_params = dict(merged_params)
new_params.pop("request_type", None)
# --- 移除内部使用的参数 ---
new_params.pop("request_type", None) # 移除 request_type
# 如果还有其他内部参数,也在这里移除
# new_params.pop("internal_param_name", None)
# --- 模型特定参数转换 ---
if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
# OpenAI 特有转换 (示例)
new_params.pop("temperature", None)
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
elif self.is_gemini:
# Gemini 参数转换
gen_config = new_params.get("generationConfig", {})
if "temperature" in new_params:
gen_config["temperature"] = new_params.pop("temperature")
@ -913,12 +881,10 @@ class LLMRequest:
gen_config["topP"] = new_params.pop("top_p")
if "top_k" in new_params:
gen_config["topK"] = new_params.pop("top_k")
# ... 其他 Gemini 特定参数 ...
if gen_config:
new_params["generationConfig"] = gen_config
# 移除 OpenAI 特有的顶层参数
new_params.pop("frequency_penalty", None)
new_params.pop("presence_penalty", None)
new_params.pop("max_completion_tokens", None)
@ -984,24 +950,38 @@ class LLMRequest:
def _default_response_handler(
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
) -> Tuple:
"""默认响应解析 (区分 Gemini 和 OpenAI)"""
# (代码不变)
"""默认响应解析 (区分 Gemini 和 OpenAI),并处理函数/工具调用"""
content = "没有返回结果"
reasoning_content = ""
tool_calls = None
tool_calls = None # OpenAI 格式
function_call = None # Gemini 格式
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
if self.is_gemini:
# --- 解析 Gemini 响应 ---
try:
if "candidates" in result and result["candidates"]:
candidate = result["candidates"][0]
# 检查是否有 content 和 parts
if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]:
text_parts = [part.get("text", "") for part in candidate["content"]["parts"] if "text" in part]
raw_content = "".join(text_parts).strip()
content, reasoning = self._extract_reasoning(raw_content)
reasoning_content = reasoning
# 查找 functionCall 或 text 部分
final_text_parts = []
for part in candidate["content"]["parts"]:
if "functionCall" in part:
function_call = part["functionCall"] # 获取 Gemini 的 functionCall
# Gemini functionCall 通常不与 text 一起返回,这里假设只处理 functionCall
break # 找到 functionCall 就停止处理 parts
elif "text" in part:
final_text_parts.append(part.get("text", ""))
if not function_call: # 如果没有 functionCall处理 text
raw_content = "".join(final_text_parts).strip()
content, reasoning = self._extract_reasoning(raw_content)
reasoning_content = reasoning
# else: function_call 已获取content 留空或设为特定值
else:
content = "Gemini响应中缺少 content 或 parts"
logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}")
@ -1015,6 +995,7 @@ class LLMRequest:
content = "响应内容因引用限制被过滤。"
elif finish_reason == "OTHER":
logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。")
# finishReason == "TOOL_CODE" or "FUNCTION_CALL" 是正常情况
usage = result.get("usageMetadata", {})
if usage:
@ -1031,6 +1012,8 @@ class LLMRequest:
content = "解析 Gemini 响应时出错"
else:
# --- 解析 OpenAI 兼容响应 ---
# (代码不变)
if "choices" in result and result["choices"]:
message = result["choices"][0].get("message", {})
raw_content = message.get("content", "")
@ -1041,7 +1024,7 @@ class LLMRequest:
explicit_reasoning = message.get("reasoning_content", "")
reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
tool_calls = message.get("tool_calls", None)
tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls
usage = result.get("usage", {})
if usage:
@ -1053,6 +1036,8 @@ class LLMRequest:
else:
logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}")
# --- 记录 Token 使用情况 ---
# (代码不变)
if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0:
self._record_usage(
prompt_tokens=prompt_tokens,
@ -1065,10 +1050,35 @@ class LLMRequest:
else:
logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
if tool_calls:
logger.debug(f"检测到工具调用: {tool_calls}")
return content, reasoning_content, tool_calls
# --- 返回结果 (统一格式) ---
final_tool_calls = None
if tool_calls: # 来自 OpenAI
final_tool_calls = tool_calls
logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}")
elif function_call: # 来自 Gemini
logger.debug(f"检测到 Gemini 函数调用: {function_call}")
# 将 Gemini functionCall 转换为 OpenAI tool_calls 格式
# 注意: Gemini 的 functionCall 没有显式的 id 和 type需要模拟
final_tool_calls = [
{
"id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID
"type": "function",
"function": {
"name": function_call.get("name"),
# Gemini 的参数在 'args' 中OpenAI 在 'arguments' (通常是 JSON 字符串)
# 需要将 Gemini 的 dict 参数转换为 JSON 字符串
"arguments": json.dumps(function_call.get("args", {}))
}
}
]
logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}")
if final_tool_calls:
# 如果有工具/函数调用,通常 content 为空或包含思考过程,这里返回转换后的调用信息
return content, reasoning_content, final_tool_calls
else:
# 没有工具/函数调用,返回普通文本响应
return content, reasoning_content
@ -1136,7 +1146,16 @@ class LLMRequest:
request_type="vision",
**kwargs
)
return response
# _default_response_handler 现在总是返回至少2个值
if len(response) == 3:
return response # content, reasoning, tool_calls (tool_calls 可能为 None)
elif len(response) == 2:
content, reasoning = response
return content, reasoning # 对于 vision 请求,通常没有 tool_calls
else:
logger.error(f"来自 _default_response_handler 的意外响应格式: {response}")
return "处理响应出错", ""
async def generate_response_async(self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
@ -1154,53 +1173,102 @@ class LLMRequest:
)
return response
# 修改:实现 Gemini Function Calling 的 Payload 构建
async def generate_response_tool_async(self, prompt: str, tools: list, user_id: str = "system", **kwargs) -> tuple[str, str, list | None]:
"""异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数 (Gemini 暂不完全适配)"""
# (代码不变)
if self.is_gemini:
logger.warning(f"模型 {self.model_name}: Gemini 的函数调用实现与 OpenAI 不同,当前实现可能不兼容。")
return "Gemini 函数调用暂未完全适配", "", None
"""异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数和 Gemini 函数调用"""
endpoint = "/chat/completions"
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
merged_params = {**self.params, **kwargs}
transformed_params = await self._transform_parameters(merged_params)
transformed_params = await self._transform_parameters(merged_params) # 清理 request_type 等
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
**transformed_params,
"tools": tools,
"tool_choice": transformed_params.get("tool_choice", "auto"),
}
if "max_completion_tokens" in data:
data["max_tokens"] = data.pop("max_completion_tokens")
if "max_tokens" not in data:
data["max_tokens"] = global_config.model_max_output_length
payload = None
if self.is_gemini:
# --- 构建 Gemini Function Calling Payload ---
logger.debug(f"为 Gemini ({self.model_name}) 构建函数调用请求。")
# 1. 转换工具定义 (OpenAI -> Gemini)
# OpenAI tool format: [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
# Gemini tool format: [{"functionDeclarations": [{"name": ..., "description": ..., "parameters": ...}]}]
function_declarations = []
if tools:
for tool in tools:
if tool.get("type") == "function" and "function" in tool:
func_def = tool["function"]
# Gemini parameters 使用 OpenAPI Schema与 OpenAI 基本兼容
function_declarations.append({
"name": func_def.get("name"),
"description": func_def.get("description", ""), # Description is required for Gemini
"parameters": func_def.get("parameters", {"type": "object", "properties": {}}) # Ensure parameters exist
})
else:
logger.warning(f"跳过不支持的工具类型或格式: {tool}")
if not function_declarations:
logger.error("没有有效的函数声明可用于 Gemini 请求。")
return "没有提供有效的函数定义", "", None
gemini_tools = [{"functionDeclarations": function_declarations}]
# 2. 构建 Gemini Payload
# parts = [{"text": prompt}] # 初始 parts
payload = {
"contents": [{"parts": [{"text": prompt}]}], # 包含用户提示
"tools": gemini_tools,
# toolConfig 默认是 AUTO可以根据需要从 kwargs 获取或硬编码
# "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # 例如强制调用
**transformed_params # 合并其他转换后的参数 (如 generationConfig)
}
payload.pop("model", None) # Gemini 不在顶层传 model
payload.pop("messages", None) # 移除 OpenAI 特有的 messages
payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice
logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}")
else:
# --- 构建 OpenAI Tool Calling Payload ---
# (逻辑不变)
logger.debug(f"为 OpenAI 兼容模型 ({self.model_name}) 构建工具调用请求。")
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
**transformed_params,
"tools": tools,
"tool_choice": transformed_params.get("tool_choice", "auto"),
}
if "max_completion_tokens" in payload:
payload["max_tokens"] = payload.pop("max_completion_tokens")
if "max_tokens" not in payload:
payload["max_tokens"] = global_config.model_max_output_length
# --- 执行请求 ---
if payload is None:
logger.error("未能构建有效的 API 请求 payload。")
return "内部错误:无法构建请求", "", None
response = await self._execute_request(
endpoint=endpoint,
payload=data,
prompt=prompt,
payload=payload,
prompt=prompt, # prompt 仍然需要,用于可能的重试
user_id=user_id,
request_type="tool_call",
**kwargs
**kwargs # 传递原始 kwargs 以便在重试时重新合并
)
logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具,返回结果: {response}")
# _default_response_handler 现在会处理 Gemini functionCall 并统一格式
logger.debug(f"模型 {self.model_name} 工具/函数调用返回结果: {response}")
if isinstance(response, tuple) and len(response) == 3:
content, reasoning_content, tool_calls = response
if tool_calls:
logger.debug(f"收到工具调用响应,包含 {len(tool_calls)} 个工具调用")
else:
logger.debug("收到响应结构但无实际工具调用,视为普通响应")
return content, reasoning_content, tool_calls
content, reasoning_content, final_tool_calls = response
# final_tool_calls 已经是统一的 OpenAI 格式
return content, reasoning_content, final_tool_calls
elif isinstance(response, tuple) and len(response) == 2:
content, reasoning_content = response
logger.debug("收到普通响应,无工具调用")
logger.debug("收到普通响应,无工具/函数调用")
return content, reasoning_content, None
else:
logger.error(f"收到来自 _execute_request 的意外响应格式: {response}")
logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}")
return "处理响应时出错", "", None
@ -1211,10 +1279,8 @@ class LLMRequest:
logger.debug("该消息没有长度不再发送获取embedding向量的请求")
return None
# 移除内部参数,避免发送给 API
api_kwargs = {k: v for k, v in kwargs.items() if k != 'request_type'}
if self.is_gemini:
endpoint = ":embedContent"
payload = {
@ -1222,7 +1288,7 @@ class LLMRequest:
"content": {
"parts": [{"text": text}]
},
**api_kwargs # 合并过滤后的 kwargs
**api_kwargs
}
payload.pop("encoding_format", None)
payload.pop("input", None)
@ -1233,7 +1299,7 @@ class LLMRequest:
"model": self.model_name,
"input": text,
"encoding_format": "float",
**api_kwargs # 合并过滤后的 kwargs
**api_kwargs
}
payload.pop("content", None)
payload.pop("taskType", None)
@ -1278,7 +1344,7 @@ class LLMRequest:
response_handler=embedding_handler,
user_id=user_id,
request_type="embedding",
**api_kwargs # 传递过滤后的 kwargs
**api_kwargs
)
return embedding
@ -1350,4 +1416,5 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
return base64_data