mirror of https://github.com/Mai-with-u/MaiBot.git
优化处理
parent
a1e9893ac1
commit
e2d277fd7c
|
|
@ -339,6 +339,19 @@ def _default_normal_response_parser(
|
||||||
|
|
||||||
return api_response, _usage_record
|
return api_response, _usage_record
|
||||||
|
|
||||||
|
def resolve_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
|
||||||
|
"""
|
||||||
|
统一解析并裁剪 thinking_budget
|
||||||
|
"""
|
||||||
|
tb = THINKING_BUDGET_AUTO
|
||||||
|
if extra_params and "thinking_budget" in extra_params:
|
||||||
|
try:
|
||||||
|
tb = int(extra_params["thinking_budget"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
|
||||||
|
)
|
||||||
|
return GeminiClient.clamp_thinking_budget(tb, model_id)
|
||||||
|
|
||||||
@client_registry.register_client_class("gemini")
|
@client_registry.register_client_class("gemini")
|
||||||
class GeminiClient(BaseClient):
|
class GeminiClient(BaseClient):
|
||||||
|
|
@ -362,6 +375,7 @@ class GeminiClient(BaseClient):
|
||||||
http_options_kwargs["api_version"] = parts[1]
|
http_options_kwargs["api_version"] = parts[1]
|
||||||
else:
|
else:
|
||||||
http_options_kwargs["base_url"] = api_provider.base_url
|
http_options_kwargs["base_url"] = api_provider.base_url
|
||||||
|
http_options_kwargs["api_version"] = None
|
||||||
self.client = genai.Client(
|
self.client = genai.Client(
|
||||||
http_options=HttpOptions(**http_options_kwargs),
|
http_options=HttpOptions(**http_options_kwargs),
|
||||||
api_key=api_provider.api_key,
|
api_key=api_provider.api_key,
|
||||||
|
|
@ -456,19 +470,9 @@ class GeminiClient(BaseClient):
|
||||||
messages = _convert_messages(message_list)
|
messages = _convert_messages(message_list)
|
||||||
# 将tool_options转换为Gemini API所需的格式
|
# 将tool_options转换为Gemini API所需的格式
|
||||||
tools = _convert_tool_options(tool_options) if tool_options else None
|
tools = _convert_tool_options(tool_options) if tool_options else None
|
||||||
|
# 解析并裁剪 thinking_budget
|
||||||
tb = THINKING_BUDGET_AUTO
|
tb = resolve_thinking_budget(extra_params, model_info.model_identifier)
|
||||||
# 空处理
|
|
||||||
if extra_params and "thinking_budget" in extra_params:
|
|
||||||
try:
|
|
||||||
tb = int(extra_params["thinking_budget"])
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
logger.warning(
|
|
||||||
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
|
|
||||||
)
|
|
||||||
# 裁剪到模型支持的范围
|
|
||||||
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
|
||||||
|
|
||||||
# 将response_format转换为Gemini API所需的格式
|
# 将response_format转换为Gemini API所需的格式
|
||||||
generation_config_dict = {
|
generation_config_dict = {
|
||||||
"max_output_tokens": max_tokens,
|
"max_output_tokens": max_tokens,
|
||||||
|
|
@ -590,41 +594,51 @@ class GeminiClient(BaseClient):
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_audio_transcriptions(
|
async def get_audio_transcriptions(
|
||||||
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
|
self,
|
||||||
|
model_info: ModelInfo,
|
||||||
|
audio_base64: str,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
extra_params: dict[str, Any] | None = None,
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""
|
"""
|
||||||
获取音频转录
|
获取音频转录
|
||||||
:param model_info: 模型信息
|
:param model_info: 模型信息
|
||||||
:param audio_base64: 音频文件的Base64编码字符串
|
:param audio_base64: 音频文件的Base64编码字符串
|
||||||
|
:param max_tokens: 最大输出token数(默认1024)
|
||||||
:param extra_params: 额外参数(可选)
|
:param extra_params: 额外参数(可选)
|
||||||
:return: 转录响应
|
:return: 转录响应
|
||||||
"""
|
"""
|
||||||
|
# 解析并裁剪 thinking_budget
|
||||||
|
tb = resolve_thinking_budget(extra_params, model_info.model_identifier)
|
||||||
|
|
||||||
|
# 构造 prompt + 音频输入
|
||||||
|
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||||
|
contents = [
|
||||||
|
Content(
|
||||||
|
role="user",
|
||||||
|
parts=[
|
||||||
|
Part.from_text(text=prompt),
|
||||||
|
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
generation_config_dict = {
|
generation_config_dict = {
|
||||||
"max_output_tokens": 2048,
|
"max_output_tokens": max_tokens,
|
||||||
"response_modalities": ["TEXT"],
|
"response_modalities": ["TEXT"],
|
||||||
"thinking_config": ThinkingConfig(
|
"thinking_config": ThinkingConfig(
|
||||||
include_thoughts=True,
|
include_thoughts=True,
|
||||||
thinking_budget=(
|
thinking_budget=tb,
|
||||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"safety_settings": gemini_safe_settings,
|
"safety_settings": gemini_safe_settings,
|
||||||
}
|
}
|
||||||
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
|
||||||
try:
|
try:
|
||||||
raw_response: GenerateContentResponse = self.client.models.generate_content(
|
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
contents=[
|
contents=contents,
|
||||||
Content(
|
|
||||||
role="user",
|
|
||||||
parts=[
|
|
||||||
Part.from_text(text=prompt),
|
|
||||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
],
|
|
||||||
config=generate_content_config,
|
config=generate_content_config,
|
||||||
)
|
)
|
||||||
resp, usage_record = _default_normal_response_parser(raw_response)
|
resp, usage_record = _default_normal_response_parser(raw_response)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue