diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index c23ba90a..f28857c7 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -362,18 +362,29 @@ class GeminiClient(BaseClient): http_options_kwargs["api_version"] = parts[1] else: http_options_kwargs["base_url"] = api_provider.base_url + http_options_kwargs["api_version"] = None self.client = genai.Client( http_options=HttpOptions(**http_options_kwargs), api_key=api_provider.api_key, ) # 这里和openai不一样,gemini会自己决定自己是否需要retry @staticmethod - def clamp_thinking_budget(tb: int, model_id: str) -> int: + def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int: """ 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) """ limits = None + # 参数传入处理 + tb = THINKING_BUDGET_AUTO + if extra_params and "thinking_budget" in extra_params: + try: + tb = int(extra_params["thinking_budget"]) + except (ValueError, TypeError): + logger.warning( + f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" + ) + # 优先尝试精确匹配 if model_id in THINKING_BUDGET_LIMITS: limits = THINKING_BUDGET_LIMITS[model_id] @@ -456,19 +467,9 @@ class GeminiClient(BaseClient): messages = _convert_messages(message_list) # 将tool_options转换为Gemini API所需的格式 tools = _convert_tool_options(tool_options) if tool_options else None - - tb = THINKING_BUDGET_AUTO - # 空处理 - if extra_params and "thinking_budget" in extra_params: - try: - tb = int(extra_params["thinking_budget"]) - except (ValueError, TypeError): - logger.warning( - f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" - ) - # 裁剪到模型支持的范围 - tb = self.clamp_thinking_budget(tb, model_info.model_identifier) - + # 解析并裁剪 thinking_budget + tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) + # 将response_format转换为Gemini API所需的格式 generation_config_dict = { "max_output_tokens": max_tokens, @@ -590,41 +591,51 @@ class GeminiClient(BaseClient): return response - def get_audio_transcriptions( - self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None + async def get_audio_transcriptions( + self, + model_info: ModelInfo, + audio_base64: str, + max_tokens: int = 2048, + extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ 获取音频转录 :param model_info: 模型信息 :param audio_base64: 音频文件的Base64编码字符串 + :param max_tokens: 最大输出token数(默认2048) :param extra_params: 额外参数(可选) :return: 转录响应 """ + # 解析并裁剪 thinking_budget + tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) + + # 构造 prompt + 音频输入 + prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." + contents = [ + Content( + role="user", + parts=[ + Part.from_text(text=prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + ], + ) + ] + generation_config_dict = { - "max_output_tokens": 2048, + "max_output_tokens": max_tokens, "response_modalities": ["TEXT"], "thinking_config": ThinkingConfig( include_thoughts=True, - thinking_budget=( - extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 - ), + thinking_budget=tb, ), "safety_settings": gemini_safe_settings, } generate_content_config = GenerateContentConfig(**generation_config_dict) - prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." + try: - raw_response: GenerateContentResponse = self.client.models.generate_content( + raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( model=model_info.model_identifier, - contents=[ - Content( - role="user", - parts=[ - Part.from_text(text=prompt), - Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), - ], - ) - ], + contents=contents, config=generate_content_config, ) resp, usage_record = _default_normal_response_parser(raw_response)