def resolve_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
    """
    Parse the requested thinking budget and clamp it for the target model.

    :param extra_params: optional extra request parameters; only the
        "thinking_budget" key is read here
    :param model_id: model identifier, forwarded unchanged to
        GeminiClient.clamp_thinking_budget
    :return: the value returned by GeminiClient.clamp_thinking_budget for the
        parsed budget (THINKING_BUDGET_AUTO is passed when the key is absent
        or its value cannot be converted to int)
    """
    # Start from the automatic mode so a missing or unusable value has a fallback.
    tb = THINKING_BUDGET_AUTO
    if extra_params and "thinking_budget" in extra_params:
        try:
            tb = int(extra_params["thinking_budget"])
        except (ValueError, TypeError, OverflowError):
            # ValueError: non-numeric string or float NaN; TypeError: None or
            # another non-convertible object; OverflowError: float infinity.
            # In every case tb still holds THINKING_BUDGET_AUTO at this point.
            logger.warning(
                f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
            )
    # GeminiClient is defined after this function in the module; the name is
    # resolved at call time, so the forward reference is fine.
    # NOTE(review): calling through the class with two positional arguments
    # assumes clamp_thinking_budget is a staticmethod - confirm.
    return GeminiClient.clamp_thinking_budget(tb, model_id)
(ValueError, TypeError): - logger.warning( - f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" - ) - # 裁剪到模型支持的范围 - tb = self.clamp_thinking_budget(tb, model_info.model_identifier) - + # 解析并裁剪 thinking_budget + tb = resolve_thinking_budget(extra_params, model_info.model_identifier) + # 将response_format转换为Gemini API所需的格式 generation_config_dict = { "max_output_tokens": max_tokens, @@ -590,41 +594,51 @@ class GeminiClient(BaseClient): return response - def get_audio_transcriptions( - self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None + async def get_audio_transcriptions( + self, + model_info: ModelInfo, + audio_base64: str, + max_tokens: int = 1024, + extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ 获取音频转录 :param model_info: 模型信息 :param audio_base64: 音频文件的Base64编码字符串 + :param max_tokens: 最大输出token数(默认1024) :param extra_params: 额外参数(可选) :return: 转录响应 """ + # 解析并裁剪 thinking_budget + tb = resolve_thinking_budget(extra_params, model_info.model_identifier) + + # 构造 prompt + 音频输入 + prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." + contents = [ + Content( + role="user", + parts=[ + Part.from_text(text=prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + ], + ) + ] + generation_config_dict = { - "max_output_tokens": 2048, + "max_output_tokens": max_tokens, "response_modalities": ["TEXT"], "thinking_config": ThinkingConfig( include_thoughts=True, - thinking_budget=( - extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 - ), + thinking_budget=tb, ), "safety_settings": gemini_safe_settings, } generate_content_config = GenerateContentConfig(**generation_config_dict) - prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." 
+ try: - raw_response: GenerateContentResponse = self.client.models.generate_content( + raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( model=model_info.model_identifier, - contents=[ - Content( - role="user", - parts=[ - Part.from_text(text=prompt), - Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), - ], - ) - ], + contents=contents, config=generate_content_config, ) resp, usage_record = _default_normal_response_parser(raw_response)