修复Openai o系列模型支持

pull/791/head
sky2002 2025-04-18 17:58:54 +08:00
parent 1516d91e90
commit 92b780249c
1 changed files with 69 additions and 55 deletions

View File

@ -19,13 +19,20 @@ logger = get_module_logger("model_utils")
class LLMRequest:
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o3-mini",
"o1-mini",
"o1-preview",
"o1",
"o1-2024-12-17",
"o1-preview-2024-09-12",
"o3-mini-2025-01-31",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16"
]
def __init__(self, model, **kwargs):
@ -38,7 +45,9 @@ class LLMRequest:
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"]
self.params = kwargs
# 从 kwargs 中提取 request_type如果没有提供则默认为 "default"
self.request_type = kwargs.pop("request_type", "default")
self.params = self._init_transform_params(kwargs)
self.stream = model.get("stream", False)
self.pri_in = model.get("pri_in", 0)
@ -47,8 +56,27 @@ class LLMRequest:
# 获取数据库实例
self._init_database()
# 从 kwargs 中提取 request_type如果没有提供则默认为 "default"
self.request_type = kwargs.pop("request_type", "default")
def _init_transform_params(self, raw: dict) -> dict:
"""
raw kwargs
- 如果 model 需要转换就把 max_tokensmax_completion_tokens并丢掉 temperature
- 其它一律保留
"""
need = {m.lower() for m in self.MODELS_NEEDING_TRANSFORMATION}
is_o = any(self.model_name.lower().startswith(m) for m in need)
newp = {}
for k, v in raw.items():
# 对于 o 系列要丢掉 temperature
if is_o and k == "temperature":
continue
# 对于 o 系列把 max_tokens 改名
if is_o and k == "max_tokens":
newp["max_completion_tokens"] = v
# 其余都原样保留
else:
newp[k] = v
return newp
@staticmethod
def _init_database():
@ -510,55 +538,37 @@ class LLMRequest:
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
async def _transform_parameters(self, params: dict) -> dict:
"""
根据模型名称转换参数
- 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"删除 'temperature' 参数
并将 'max_tokens' 重命名为 'max_completion_tokens'
"""
# 复制一份参数,避免直接修改原始数据
new_params = dict(params)
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
# 删除 'temperature' 参数(如果存在)
new_params.pop("temperature", None)
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体"""
# 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params)
params_copy = self.params
# 准备 messages
if image_base64:
payload = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/{image_format.lower()};base64,{image_base64}"
},
],
}
],
"max_tokens": global_config.max_response_length,
**params_copy,
}
},
],
}
]
else:
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**params_copy,
}
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
messages = [{"role": "user", "content": prompt}]
# 基础 payload暂不加 max_tokens
payload = {
"model": self.model_name,
"messages": messages,
**params_copy,
}
# 如果用户既没传 max_tokens也没传 max_completion_tokens就补上一个默认值
if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
payload["max_tokens"] = global_config.max_response_length
return payload
def _default_response_handler(
@ -648,16 +658,20 @@ class LLMRequest:
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应"""
# 构建请求体
# 构建基础 payload不硬编码 max_tokens
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**self.params,
**kwargs,
}
response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
# 如果调用者既没传 max_tokens 也没传 max_completion_tokens就补默认值
if "max_tokens" not in data and "max_completion_tokens" not in data:
data["max_tokens"] = global_config.max_response_length
# 发请求
response = await self._execute_request(
endpoint="/chat/completions", payload=data, prompt=prompt
)
# 原样返回响应,不做处理
return response