优化处理

pull/1261/head
foxplaying 2025-09-24 13:05:33 +08:00 committed by GitHub
parent a1e9893ac1
commit e2d277fd7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 30 deletions

View File

@ -339,6 +339,19 @@ def _default_normal_response_parser(
return api_response, _usage_record return api_response, _usage_record
def resolve_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
"""
统一解析并裁剪 thinking_budget
"""
tb = THINKING_BUDGET_AUTO
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
)
return GeminiClient.clamp_thinking_budget(tb, model_id)
@client_registry.register_client_class("gemini") @client_registry.register_client_class("gemini")
class GeminiClient(BaseClient): class GeminiClient(BaseClient):
@ -362,6 +375,7 @@ class GeminiClient(BaseClient):
http_options_kwargs["api_version"] = parts[1] http_options_kwargs["api_version"] = parts[1]
else: else:
http_options_kwargs["base_url"] = api_provider.base_url http_options_kwargs["base_url"] = api_provider.base_url
http_options_kwargs["api_version"] = None
self.client = genai.Client( self.client = genai.Client(
http_options=HttpOptions(**http_options_kwargs), http_options=HttpOptions(**http_options_kwargs),
api_key=api_provider.api_key, api_key=api_provider.api_key,
@ -456,18 +470,8 @@ class GeminiClient(BaseClient):
messages = _convert_messages(message_list) messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式 # 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None tools = _convert_tool_options(tool_options) if tool_options else None
# 解析并裁剪 thinking_budget
tb = THINKING_BUDGET_AUTO tb = resolve_thinking_budget(extra_params, model_info.model_identifier)
# 空处理
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
)
# 裁剪到模型支持的范围
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
# 将response_format转换为Gemini API所需的格式 # 将response_format转换为Gemini API所需的格式
generation_config_dict = { generation_config_dict = {
@ -590,41 +594,51 @@ class GeminiClient(BaseClient):
return response return response
def get_audio_transcriptions( async def get_audio_transcriptions(
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: int = 1024,
extra_params: dict[str, Any] | None = None,
) -> APIResponse: ) -> APIResponse:
""" """
获取音频转录 获取音频转录
:param model_info: 模型信息 :param model_info: 模型信息
:param audio_base64: 音频文件的Base64编码字符串 :param audio_base64: 音频文件的Base64编码字符串
:param max_tokens: 最大输出token数默认1024
:param extra_params: 额外参数可选 :param extra_params: 额外参数可选
:return: 转录响应 :return: 转录响应
""" """
# 解析并裁剪 thinking_budget
tb = resolve_thinking_budget(extra_params, model_info.model_identifier)
# 构造 prompt + 音频输入
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
contents = [
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
]
generation_config_dict = { generation_config_dict = {
"max_output_tokens": 2048, "max_output_tokens": max_tokens,
"response_modalities": ["TEXT"], "response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig( "thinking_config": ThinkingConfig(
include_thoughts=True, include_thoughts=True,
thinking_budget=( thinking_budget=tb,
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
),
), ),
"safety_settings": gemini_safe_settings, "safety_settings": gemini_safe_settings,
} }
generate_content_config = GenerateContentConfig(**generation_config_dict) generate_content_config = GenerateContentConfig(**generation_config_dict)
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
try: try:
raw_response: GenerateContentResponse = self.client.models.generate_content( raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_info.model_identifier, model=model_info.model_identifier,
contents=[ contents=contents,
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
],
config=generate_content_config, config=generate_content_config,
) )
resp, usage_record = _default_normal_response_parser(raw_response) resp, usage_record = _default_normal_response_parser(raw_response)