mirror of https://github.com/Mai-with-u/MaiBot.git
commit 3a2685cf26
@@ -72,8 +72,8 @@ class BaseClient(ABC):
         model_info: ModelInfo,
         message_list: list[Message],
         tool_options: list[ToolOption] | None = None,
-        max_tokens: int = 1024,
-        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
         response_format: RespFormat | None = None,
         stream_response_handler: Optional[
             Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
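Note: changing the base-class defaults to None means "no caller preference", so each concrete client (or the provider API itself) can apply its own fallback. A minimal sketch of that resolution pattern, with illustrative constants that are not names from this PR:

    from typing import Optional

    # Illustrative fallbacks; the real defaults live in each concrete client.
    DEFAULT_MAX_TOKENS = 1024
    DEFAULT_TEMPERATURE = 0.7

    def resolve_generation_params(
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> tuple[int, float]:
        # None means the caller expressed no preference, so the client default applies.
        return (
            max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS,
            temperature if temperature is not None else DEFAULT_TEMPERATURE,
        )

    print(resolve_generation_params())               # (1024, 0.7)
    print(resolve_generation_params(max_tokens=64))  # (64, 0.7)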
@@ -117,6 +117,7 @@ class BaseClient(ABC):
         self,
         model_info: ModelInfo,
         audio_base64: str,
+        max_tokens: Optional[int] = None,
         extra_params: dict[str, Any] | None = None,
     ) -> APIResponse:
         """
@@ -182,7 +182,15 @@ def _process_delta(
     if delta.text:
         fc_delta_buffer.write(delta.text)

+    # Handle thought (a Gemini-specific field)
+    for c in getattr(delta, "candidates", []):
+        if c.content and getattr(c.content, "parts", None):
+            for p in c.content.parts:
+                if getattr(p, "thought", False) and getattr(p, "text", None):
+                    # Write the thought into the buffer so resp.content is never left empty
+                    fc_delta_buffer.write(p.text)
+
     if delta.function_calls:  # why no hasattr check: this attribute always exists, even when empty
         for call in delta.function_calls:
             try:
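For context, Gemini streaming chunks can carry reasoning ("thought") parts alongside regular text parts, and the new loop folds them into the same buffer. A runnable toy illustration of the attribute shape being walked, using stand-in objects rather than the google-genai types:

    from types import SimpleNamespace

    # Hypothetical stand-ins mirroring the attribute shape the loop expects.
    part = SimpleNamespace(thought=True, text="step 1: parse the question")
    content = SimpleNamespace(parts=[part])
    candidate = SimpleNamespace(content=content)
    delta = SimpleNamespace(candidates=[candidate])

    for c in getattr(delta, "candidates", []):
        if c.content and getattr(c.content, "parts", None):
            for p in c.content.parts:
                if getattr(p, "thought", False) and getattr(p, "text", None):
                    print(p.text)  # would be written to fc_delta_buffer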
@@ -204,6 +212,7 @@ def _process_delta(
 def _build_stream_api_resp(
     _fc_delta_buffer: io.StringIO,
     _tool_calls_buffer: list[tuple[str, str, dict]],
+    last_resp: GenerateContentResponse | None = None,  # pass in last_resp
 ) -> APIResponse:
     # sourcery skip: simplify-len-comparison, use-assigned-variable
     resp = APIResponse()
@@ -228,6 +237,24 @@ def _build_stream_api_resp(
             resp.tool_calls.append(ToolCall(call_id, function_name, arguments))

+    # Check whether the response was truncated by max_tokens
+    reason = None
+    if last_resp and getattr(last_resp, "candidates", None):
+        c0 = last_resp.candidates[0]
+        reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
+
+    if str(reason).endswith("MAX_TOKENS"):
+        if resp.content and resp.content.strip():
+            logger.warning(
+                "⚠ The Gemini response was partially truncated by the max_tokens limit,\n"
+                "  which may affect the reply content; consider adjusting the model's max_tokens setting!"
+            )
+        else:
+            logger.warning(
+                "⚠ The Gemini response was truncated by the max_tokens limit,\n"
+                "  please adjust the model's max_tokens setting!"
+            )
+
     if not resp.content and not resp.tool_calls:
         raise EmptyResponseException()
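The double lookup of finish_reason / finishReason guards against snake_case vs. camelCase field naming across SDK versions, and matching with str(reason).endswith("MAX_TOKENS") handles both a plain string and an enum member. A small illustration, using a hypothetical stand-in enum rather than the google-genai type:

    from enum import Enum

    class FinishReason(Enum):  # hypothetical stand-in for the SDK enum
        MAX_TOKENS = "MAX_TOKENS"
        STOP = "STOP"

    print(str("MAX_TOKENS").endswith("MAX_TOKENS"))             # True (plain string)
    print(str(FinishReason.MAX_TOKENS).endswith("MAX_TOKENS"))  # True ("FinishReason.MAX_TOKENS")
    print(str(FinishReason.STOP).endswith("MAX_TOKENS"))        # False
    print(str(None).endswith("MAX_TOKENS"))                     # False (reason never set)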
@@ -246,12 +273,14 @@ async def _default_stream_response_handler(
     _fc_delta_buffer = io.StringIO()  # buffer for the formal content received so far
     _tool_calls_buffer: list[tuple[str, str, dict]] = []  # buffer for the tool calls received so far
     _usage_record = None  # usage record
+    last_resp: GenerateContentResponse | None = None  # keep the last chunk

     def _insure_buffer_closed():
         if _fc_delta_buffer and not _fc_delta_buffer.closed:
             _fc_delta_buffer.close()

     async for chunk in resp_stream:
+        last_resp = chunk  # save the latest response
         # Check for an interrupt flag
         if interrupt_flag and interrupt_flag.is_set():
             # If the interrupt flag is set, raise ReqAbortException
@@ -270,10 +299,12 @@ async def _default_stream_response_handler(
                 (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
                 chunk.usage_metadata.total_token_count or 0,
             )

     try:
         return _build_stream_api_resp(
             _fc_delta_buffer,
             _tool_calls_buffer,
+            last_resp=last_resp,
         ), _usage_record
     except Exception:
         # Make sure the buffer is closed
@@ -333,6 +364,38 @@ def _default_normal_response_parser(
     api_response.raw_data = resp

+    # Check whether the response was truncated by max_tokens
+    try:
+        if resp.candidates:
+            c0 = resp.candidates[0]
+            reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
+            if reason and "MAX_TOKENS" in str(reason):
+                # Check whether the second and later parts contain any content
+                has_real_output = False
+                if getattr(c0, "content", None) and getattr(c0.content, "parts", None):
+                    for p in c0.content.parts[1:]:  # skip the first (thought) part
+                        if getattr(p, "text", None) and p.text.strip():
+                            has_real_output = True
+                            break
+
+                if not has_real_output and getattr(resp, "text", None):
+                    has_real_output = True
+
+                if has_real_output:
+                    logger.warning(
+                        "⚠ The Gemini response was partially truncated by the max_tokens limit,\n"
+                        "  which may affect the reply content; consider adjusting the model's max_tokens setting!"
+                    )
+                else:
+                    logger.warning(
+                        "⚠ The Gemini response was truncated by the max_tokens limit,\n"
+                        "  please adjust the model's max_tokens setting!"
+                    )
+
+                return api_response, _usage_record
+    except Exception as e:
+        logger.debug(f"Exception while checking for MAX_TOKENS truncation: {e}")
+
     # The final, single empty-response check
     if not api_response.content and not api_response.tool_calls:
         raise EmptyResponseException("the response contains neither text content nor tool calls")
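The non-streaming check above treats the first part as a possible reasoning ("thought") part, so genuine output is any non-empty text after it, with the aggregate resp.text as a fallback. A runnable toy illustration of that scan, with stand-in objects:

    from types import SimpleNamespace

    # Hypothetical candidate whose first part is a thought and second is real output.
    parts = [
        SimpleNamespace(thought=True, text="reasoning..."),
        SimpleNamespace(thought=False, text="Final answer."),
    ]
    c0 = SimpleNamespace(content=SimpleNamespace(parts=parts))

    has_real_output = any(
        getattr(p, "text", None) and p.text.strip() for p in c0.content.parts[1:]
    )
    print(has_real_output)  # True -> the "partially truncated" warning path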
@@ -362,18 +425,29 @@ class GeminiClient(BaseClient):
                 http_options_kwargs["api_version"] = parts[1]
             else:
                 http_options_kwargs["base_url"] = api_provider.base_url
+                http_options_kwargs["api_version"] = None
         self.client = genai.Client(
             http_options=HttpOptions(**http_options_kwargs),
             api_key=api_provider.api_key,
         )  # Unlike the openai client, gemini decides on its own whether it needs to retry

     @staticmethod
-    def clamp_thinking_budget(tb: int, model_id: str) -> int:
+    def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
         """
         Clamp the thinking budget to the model's supported range; only the specified models are supported (including newer versions with numeric suffixes)
         """
         limits = None

+        # Handle the incoming parameter
+        tb = THINKING_BUDGET_AUTO
+        if extra_params and "thinking_budget" in extra_params:
+            try:
+                tb = int(extra_params["thinking_budget"])
+            except (ValueError, TypeError):
+                logger.warning(
+                    f"Invalid thinking_budget value {extra_params['thinking_budget']}; using the model's automatic budget mode {tb}"
+                )
+
         # Try an exact match first
         if model_id in THINKING_BUDGET_LIMITS:
             limits = THINKING_BUDGET_LIMITS[model_id]
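Moving the extra_params parsing into clamp_thinking_budget gives a single place that turns raw config into a validated budget. A minimal sketch of the overall shape, assuming a hypothetical limits table and sentinel value (the real THINKING_BUDGET_LIMITS and THINKING_BUDGET_AUTO live in this module and may behave differently):

    from typing import Any

    THINKING_BUDGET_AUTO = -1  # hypothetical sentinel: let the model choose
    THINKING_BUDGET_LIMITS = {  # hypothetical (min, max) table per model
        "gemini-2.5-pro": (128, 32768),
        "gemini-2.5-flash": (0, 24576),
    }

    def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
        tb = THINKING_BUDGET_AUTO
        if extra_params and "thinking_budget" in extra_params:
            try:
                tb = int(extra_params["thinking_budget"])
            except (ValueError, TypeError):
                pass  # keep the automatic-budget sentinel
        limits = THINKING_BUDGET_LIMITS.get(model_id)
        if limits is None or tb == THINKING_BUDGET_AUTO:
            return tb
        lo, hi = limits
        return max(lo, min(tb, hi))

    print(clamp_thinking_budget({"thinking_budget": "99999"}, "gemini-2.5-pro"))  # 32768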
@@ -416,8 +490,8 @@ class GeminiClient(BaseClient):
         model_info: ModelInfo,
         message_list: list[Message],
         tool_options: list[ToolOption] | None = None,
-        max_tokens: int = 1024,
-        temperature: float = 0.4,
+        max_tokens: Optional[int] = 1024,
+        temperature: Optional[float] = 0.4,
         response_format: RespFormat | None = None,
         stream_response_handler: Optional[
             Callable[
@@ -456,19 +530,9 @@ class GeminiClient(BaseClient):
         messages = _convert_messages(message_list)
         # Convert tool_options into the format required by the Gemini API
         tools = _convert_tool_options(tool_options) if tool_options else None
-        tb = THINKING_BUDGET_AUTO
-        # Handle the empty case
-        if extra_params and "thinking_budget" in extra_params:
-            try:
-                tb = int(extra_params["thinking_budget"])
-            except (ValueError, TypeError):
-                logger.warning(
-                    f"Invalid thinking_budget value {extra_params['thinking_budget']}; using the model's automatic budget mode {tb}"
-                )
-        # Clamp to the range the model supports
-        tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
+        # Parse and clamp the thinking_budget
+        tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)

         # Convert response_format into the format required by the Gemini API
         generation_config_dict = {
             "max_output_tokens": max_tokens,
@@ -526,15 +590,20 @@ class GeminiClient(BaseClient):
             resp, usage_record = async_response_parser(req_task.result())
         except (ClientError, ServerError) as e:
             # Re-wrap ClientError and ServerError as RespNotOkException
             raise RespNotOkException(e.code, e.message) from None
         except (
             UnknownFunctionCallArgumentError,
             UnsupportedFunctionError,
             FunctionInvocationError,
         ) as e:
-            raise ValueError(f"Tool type error: please check the tool options and arguments: {str(e)}") from None
+            # Tool-call related errors
+            raise RespParseException(None, f"Tool call argument error: {str(e)}") from None
+        except EmptyResponseException as e:
+            # Keep the original exception so an "empty response" can be distinguished from a network failure
+            raise e
         except Exception as e:
+            # Only other, unanticipated errors are classified as network connection errors
             raise NetworkConnectionError() from e

         if usage_record:
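With this routing, an empty model reply no longer masquerades as a connectivity problem on the caller side. A self-contained sketch of the caller-facing effect, using simplified stand-in exception classes rather than this project's real ones:

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical stand-ins for this module's exception types.
    class EmptyResponseException(Exception): ...
    class RespNotOkException(Exception): ...
    class NetworkConnectionError(Exception): ...

    async def fake_get_response() -> str:
        raise EmptyResponseException("no text and no tool calls")

    async def main() -> None:
        try:
            await fake_get_response()
        except EmptyResponseException:
            logger.info("empty reply: safe to retry or re-prompt")
        except RespNotOkException as e:
            logger.error(f"API rejected the request: {e}")
        except NetworkConnectionError:
            logger.error("transport-level failure: check connectivity")

    asyncio.run(main())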
@@ -590,41 +659,51 @@ class GeminiClient(BaseClient):
         return response

-    def get_audio_transcriptions(
-        self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
+    async def get_audio_transcriptions(
+        self,
+        model_info: ModelInfo,
+        audio_base64: str,
+        max_tokens: Optional[int] = 2048,
+        extra_params: dict[str, Any] | None = None,
     ) -> APIResponse:
         """
         Get an audio transcription
         :param model_info: model information
         :param audio_base64: Base64-encoded string of the audio file
+        :param max_tokens: maximum number of output tokens (default 2048)
         :param extra_params: extra parameters (optional)
         :return: transcription response
         """
+        # Parse and clamp the thinking_budget
+        tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
+
+        # Build the prompt + audio input
+        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
+        contents = [
+            Content(
+                role="user",
+                parts=[
+                    Part.from_text(text=prompt),
+                    Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
+                ],
+            )
+        ]
+
         generation_config_dict = {
-            "max_output_tokens": 2048,
+            "max_output_tokens": max_tokens,
             "response_modalities": ["TEXT"],
             "thinking_config": ThinkingConfig(
                 include_thoughts=True,
-                thinking_budget=(
-                    extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
-                ),
+                thinking_budget=tb,
             ),
             "safety_settings": gemini_safe_settings,
         }
         generate_content_config = GenerateContentConfig(**generation_config_dict)
-        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."

         try:
-            raw_response: GenerateContentResponse = self.client.models.generate_content(
+            raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
                 model=model_info.model_identifier,
-                contents=[
-                    Content(
-                        role="user",
-                        parts=[
-                            Part.from_text(text=prompt),
-                            Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
-                        ],
-                    )
-                ],
+                contents=contents,
                 config=generate_content_config,
             )
             resp, usage_record = _default_normal_response_parser(raw_response)
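Because get_audio_transcriptions is now a coroutine built on the SDK's async (client.aio) surface, callers must await it. A minimal usage sketch; the client and model_info values and the wav-loading details here are illustrative, not from this PR:

    import base64

    # Illustrative usage: "client" is a constructed GeminiClient and
    # "model_info" a ModelInfo for an audio-capable Gemini model.
    async def transcribe(path: str) -> str:
        with open(path, "rb") as f:
            audio_b64 = base64.b64encode(f.read()).decode("ascii")
        resp = await client.get_audio_transcriptions(
            model_info,
            audio_b64,
            max_tokens=2048,
        )
        return resp.content

    # asyncio.run(transcribe("sample.wav"))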
@@ -403,8 +403,8 @@ class OpenaiClient(BaseClient):
         model_info: ModelInfo,
         message_list: list[Message],
         tool_options: list[ToolOption] | None = None,
-        max_tokens: int = 1024,
-        temperature: float = 0.7,
+        max_tokens: Optional[int] = 1024,
+        temperature: Optional[float] = 0.7,
         response_format: RespFormat | None = None,
         stream_response_handler: Optional[
             Callable[