diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 7930a035..b4e2c127 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -19,13 +19,20 @@ logger = get_module_logger("model_utils")
 class LLMRequest:
     # List of models whose parameters need transformation, kept as a class variable to avoid duplication
     MODELS_NEEDING_TRANSFORMATION = [
-        "o3-mini",
-        "o1-mini",
-        "o1-preview",
+        "o1",
         "o1-2024-12-17",
-        "o1-preview-2024-09-12",
-        "o3-mini-2025-01-31",
+        "o1-mini",
         "o1-mini-2024-09-12",
+        "o1-preview",
+        "o1-preview-2024-09-12",
+        "o1-pro",
+        "o1-pro-2025-03-19",
+        "o3",
+        "o3-2025-04-16",
+        "o3-mini",
+        "o3-mini-2025-01-31",
+        "o4-mini",
+        "o4-mini-2025-04-16"
     ]
 
     def __init__(self, model, **kwargs):
@@ -38,7 +45,9 @@ class LLMRequest:
             logger.error(f"Configuration error: missing config entry - {str(e)}")
             raise ValueError(f"Configuration error: missing config entry - {str(e)}") from e
         self.model_name = model["name"]
-        self.params = kwargs
+        # Extract request_type from kwargs, defaulting to "default" if not provided
+        self.request_type = kwargs.pop("request_type", "default")
+        self.params = self._init_transform_params(kwargs)
 
         self.stream = model.get("stream", False)
         self.pri_in = model.get("pri_in", 0)
@@ -47,8 +56,27 @@ class LLMRequest:
         # Get the database instance
         self._init_database()
 
-        # Extract request_type from kwargs, defaulting to "default" if not provided
-        self.request_type = kwargs.pop("request_type", "default")
+    def _init_transform_params(self, raw: dict) -> dict:
+        """
+        Transform the raw kwargs:
+          - if the model needs transformation, rename max_tokens to max_completion_tokens and drop temperature
+          - keep everything else unchanged
+        """
+        need = {m.lower() for m in self.MODELS_NEEDING_TRANSFORMATION}
+        is_o = any(self.model_name.lower().startswith(m) for m in need)
+
+        newp = {}
+        for k, v in raw.items():
+            # o-series models do not accept temperature, so drop it
+            if is_o and k == "temperature":
+                continue
+            # o-series models expect max_completion_tokens instead of max_tokens
+            if is_o and k == "max_tokens":
+                newp["max_completion_tokens"] = v
+            # keep everything else as-is
+            else:
+                newp[k] = v
+        return newp
 
     @staticmethod
     def _init_database():
@@ -510,55 +538,37 @@ class LLMRequest:
             logger.error(f"Model {self.model_name} reached the maximum retry count; the request still failed")
             raise RuntimeError(f"Model {self.model_name} reached the maximum retry count; the API request still failed")
 
-    async def _transform_parameters(self, params: dict) -> dict:
-        """
-        Transform parameters based on the model name:
-          - for OpenAI CoT-series models that need transformation (e.g. "o3-mini"),
-            remove the 'temperature' parameter and rename 'max_tokens' to 'max_completion_tokens'
-        """
-        # Copy the parameters to avoid mutating the original data
-        new_params = dict(params)
-
-        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
-            # Remove the 'temperature' parameter (if present)
-            new_params.pop("temperature", None)
-            # If 'max_tokens' is present, rename it to 'max_completion_tokens'
-            if "max_tokens" in new_params:
-                new_params["max_completion_tokens"] = new_params.pop("max_tokens")
-        return new_params
-
     async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
         """Build the request payload"""
-        # Copy the parameters to avoid mutating self.params directly
-        params_copy = await self._transform_parameters(self.params)
+        # self.params was already transformed in __init__, so it can be used directly
+        params_copy = self.params
+        # Prepare the messages
         if image_base64:
-            payload = {
-                "model": self.model_name,
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "text", "text": prompt},
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
-                            },
-                        ],
-                    }
-                ],
-                "max_tokens": global_config.max_response_length,
-                **params_copy,
-            }
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{image_format.lower()};base64,{image_base64}"
+                            },
+                        },
+                    ],
+                }
+            ]
         else:
-            payload = {
-                "model": self.model_name,
-                "messages": [{"role": "user", "content": prompt}],
-                "max_tokens": global_config.max_response_length,
-                **params_copy,
-            }
-        # If the payload still contains max_tokens and the model needs transformation, check once more here
-        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
-            payload["max_completion_tokens"] = payload.pop("max_tokens")
+            messages = [{"role": "user", "content": prompt}]
+        # Base payload (max_tokens not added yet)
+        payload = {
+            "model": self.model_name,
+            "messages": messages,
+            **params_copy,
+        }
+        # If the caller passed neither max_tokens nor max_completion_tokens, fill in a default
+        if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
+            payload["max_tokens"] = global_config.max_response_length
         return payload
 
     def _default_response_handler(
@@ -648,16 +658,20 @@ class LLMRequest:
 
     async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
         """Asynchronously generate the model's response for the given prompt"""
-        # Build the request payload
+        # Build the base payload without hard-coding max_tokens
         data = {
             "model": self.model_name,
             "messages": [{"role": "user", "content": prompt}],
-            "max_tokens": global_config.max_response_length,
             **self.params,
             **kwargs,
         }
-
-        response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
+        # If the caller passed neither max_tokens nor max_completion_tokens, fill in the default
+        if "max_tokens" not in data and "max_completion_tokens" not in data:
+            data["max_tokens"] = global_config.max_response_length
+        # Send the request
+        response = await self._execute_request(
+            endpoint="/chat/completions", payload=data, prompt=prompt
+        )
 
         # Return the response as-is, without further processing
         return response
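
A minimal standalone sketch of the behavior this patch introduces, for illustration only: it mirrors the logic of _init_transform_params and the new max_tokens fallback outside the class. The abridged model list, the prompt, and MAX_RESPONSE_LENGTH (a stand-in for global_config.max_response_length) are assumptions for the sketch, not part of the patch.

# Sketch only -- mirrors the patched logic; MAX_RESPONSE_LENGTH is a hypothetical
# stand-in for global_config.max_response_length.
MODELS_NEEDING_TRANSFORMATION = ["o1", "o3", "o3-mini", "o4-mini"]  # abridged list
MAX_RESPONSE_LENGTH = 1024

def transform_params(model_name: str, raw: dict) -> dict:
    # Same rule as _init_transform_params: o-series models (matched by prefix)
    # lose temperature and have max_tokens renamed to max_completion_tokens.
    need = {m.lower() for m in MODELS_NEEDING_TRANSFORMATION}
    is_o = any(model_name.lower().startswith(m) for m in need)
    newp = {}
    for k, v in raw.items():
        if is_o and k == "temperature":
            continue
        if is_o and k == "max_tokens":
            newp["max_completion_tokens"] = v
        else:
            newp[k] = v
    return newp

def build_payload(model_name: str, prompt: str, params: dict) -> dict:
    # Same rule as the patched _build_payload: only inject the default
    # max_tokens when the caller supplied neither variant of the parameter.
    payload = {"model": model_name, "messages": [{"role": "user", "content": prompt}], **params}
    if "max_tokens" not in params and "max_completion_tokens" not in params:
        payload["max_tokens"] = MAX_RESPONSE_LENGTH
    return payload

# o-series model: temperature dropped, max_tokens renamed, no duplicate default injected
p = transform_params("o3-mini", {"temperature": 0.7, "max_tokens": 512})
assert p == {"max_completion_tokens": 512}
assert "max_tokens" not in build_payload("o3-mini", "hi", p)

# non-o model: params pass through untouched, default max_tokens filled in
p = transform_params("gpt-4o", {"temperature": 0.7})
assert build_payload("gpt-4o", "hi", p)["max_tokens"] == MAX_RESPONSE_LENGTH

Note that the prefix match (startswith) is what lets the class-variable list cover both base names like "o3" and dated snapshots like "o3-2025-04-16", which the old exact-membership check in _transform_parameters would have missed.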