mirror of https://github.com/Mai-with-u/MaiBot.git
修复Openai o系列模型支持
parent
1516d91e90
commit
92b780249c
|
|
@ -19,13 +19,20 @@ logger = get_module_logger("model_utils")
|
|||
class LLMRequest:
|
||||
# 定义需要转换的模型列表,作为类变量避免重复
|
||||
MODELS_NEEDING_TRANSFORMATION = [
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o1",
|
||||
"o1-2024-12-17",
|
||||
"o1-preview-2024-09-12",
|
||||
"o3-mini-2025-01-31",
|
||||
"o1-mini",
|
||||
"o1-mini-2024-09-12",
|
||||
"o1-preview",
|
||||
"o1-preview-2024-09-12",
|
||||
"o1-pro",
|
||||
"o1-pro-2025-03-19",
|
||||
"o3",
|
||||
"o3-2025-04-16",
|
||||
"o3-mini",
|
||||
"o3-mini-2025-01-31",
|
||||
"o4-mini",
|
||||
"o4-mini-2025-04-16"
|
||||
]
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
|
|
@ -38,7 +45,9 @@ class LLMRequest:
|
|||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||||
self.model_name = model["name"]
|
||||
self.params = kwargs
|
||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||
self.request_type = kwargs.pop("request_type", "default")
|
||||
self.params = self._init_transform_params(kwargs)
|
||||
|
||||
self.stream = model.get("stream", False)
|
||||
self.pri_in = model.get("pri_in", 0)
|
||||
|
|
@ -47,8 +56,27 @@ class LLMRequest:
|
|||
# 获取数据库实例
|
||||
self._init_database()
|
||||
|
||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||
self.request_type = kwargs.pop("request_type", "default")
|
||||
def _init_transform_params(self, raw: dict) -> dict:
|
||||
"""
|
||||
把 raw kwargs:
|
||||
- 如果 model 需要转换,就把 max_tokens→max_completion_tokens,并丢掉 temperature
|
||||
- 其它一律保留
|
||||
"""
|
||||
need = {m.lower() for m in self.MODELS_NEEDING_TRANSFORMATION}
|
||||
is_o = any(self.model_name.lower().startswith(m) for m in need)
|
||||
|
||||
newp = {}
|
||||
for k, v in raw.items():
|
||||
# 对于 o 系列要丢掉 temperature
|
||||
if is_o and k == "temperature":
|
||||
continue
|
||||
# 对于 o 系列把 max_tokens 改名
|
||||
if is_o and k == "max_tokens":
|
||||
newp["max_completion_tokens"] = v
|
||||
# 其余都原样保留
|
||||
else:
|
||||
newp[k] = v
|
||||
return newp
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
|
|
@ -510,55 +538,37 @@ class LLMRequest:
|
|||
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
|
||||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
|
||||
|
||||
async def _transform_parameters(self, params: dict) -> dict:
|
||||
"""
|
||||
根据模型名称转换参数:
|
||||
- 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
|
||||
并将 'max_tokens' 重命名为 'max_completion_tokens'
|
||||
"""
|
||||
# 复制一份参数,避免直接修改原始数据
|
||||
new_params = dict(params)
|
||||
|
||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
|
||||
# 删除 'temperature' 参数(如果存在)
|
||||
new_params.pop("temperature", None)
|
||||
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
|
||||
if "max_tokens" in new_params:
|
||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||
return new_params
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
params_copy = await self._transform_parameters(self.params)
|
||||
params_copy = self.params
|
||||
# 准备 messages
|
||||
if image_base64:
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{image_format.lower()};base64,{image_base64}"
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
**params_copy,
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
else:
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
**params_copy,
|
||||
}
|
||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
# 基础 payload(暂不加 max_tokens)
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
**params_copy,
|
||||
}
|
||||
# 如果用户既没传 max_tokens,也没传 max_completion_tokens,就补上一个默认值
|
||||
if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
|
||||
payload["max_tokens"] = global_config.max_response_length
|
||||
return payload
|
||||
|
||||
def _default_response_handler(
|
||||
|
|
@ -648,16 +658,20 @@ class LLMRequest:
|
|||
|
||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
# 构建请求体
|
||||
# 构建基础 payload,不硬编码 max_tokens
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
|
||||
# 如果调用者既没传 max_tokens 也没传 max_completion_tokens,就补默认值
|
||||
if "max_tokens" not in data and "max_completion_tokens" not in data:
|
||||
data["max_tokens"] = global_config.max_response_length
|
||||
# 发请求
|
||||
response = await self._execute_request(
|
||||
endpoint="/chat/completions", payload=data, prompt=prompt
|
||||
)
|
||||
# 原样返回响应,不做处理
|
||||
return response
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue