Merge pull request #1269 from foxplaying/patch-2

Gemini: fix unexpected thought output, add Search support, and warn on truncation
SengokuCola 2025-10-26 23:07:24 +08:00 committed by GitHub
commit 2b09edbe77
2 changed files with 77 additions and 11 deletions
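
The thought fix below separates Gemini's thought parts from normal text parts while streaming, instead of writing delta.text wholesale into the content buffer. A minimal self-contained sketch of that split, using mock objects rather than the real SDK types (the buffer names and chunk shape here are illustrative assumptions):

import io
from types import SimpleNamespace

# Hypothetical streamed chunk: one thought part and one answer part.
chunk = SimpleNamespace(candidates=[SimpleNamespace(content=SimpleNamespace(parts=[
    SimpleNamespace(thought=True, text="Let me think step by step..."),
    SimpleNamespace(thought=False, text="The answer is 42."),
]))])

content, reasoning = io.StringIO(), io.StringIO()
for candidate in chunk.candidates:
    for part in candidate.content.parts:
        if not getattr(part, "text", None):
            continue
        # Thought parts go to the reasoning channel, everything else to content.
        target = reasoning if getattr(part, "thought", False) else content
        target.write(part.text)

assert content.getvalue() == "The answer is 42."
assert reasoning.getvalue() == "Let me think step by step..."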

@@ -13,6 +13,7 @@ from google.genai.types import (
    ContentUnion,
    ThinkingConfig,
    Tool,
    GoogleSearch,
    GenerateContentConfig,
    EmbedContentResponse,
    EmbedContentConfig,
@@ -176,19 +177,21 @@ def _process_delta(
    delta: GenerateContentResponse,
    fc_delta_buffer: io.StringIO,
    tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
    resp: APIResponse | None = None,
):
    if not hasattr(delta, "candidates") or not delta.candidates:
        raise RespParseException(delta, "Response parsing failed: missing 'candidates' field")
    if delta.text:
        fc_delta_buffer.write(delta.text)
    # Handle thought parts (a Gemini-specific field)
    for c in getattr(delta, "candidates", []):
        if c.content and getattr(c.content, "parts", None):
            for p in c.content.parts:
                if getattr(p, "thought", False) and getattr(p, "text", None):
                    # Keep thought text out of the content buffer so it no longer
                    # leaks into the normal output; save it to reasoning_content
                    if resp is not None:
                        resp.reasoning_content = (resp.reasoning_content or "") + p.text
                elif getattr(p, "text", None):
                    # Normal output goes into the content buffer
                    fc_delta_buffer.write(p.text)
    if delta.function_calls:  # Why no hasattr check: this attribute always exists, even when empty
@@ -213,9 +216,11 @@ def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, dict]],
    last_resp: GenerateContentResponse | None = None,  # the last received chunk
    resp: APIResponse | None = None,
) -> APIResponse:
    # sourcery skip: simplify-len-comparison, use-assigned-variable
    resp = APIResponse()
    if resp is None:
        resp = APIResponse()
    if _fc_delta_buffer.tell() > 0:
        # If the content buffer is not empty, write it into the APIResponse object
@@ -240,11 +245,15 @@ def _build_stream_api_resp(
    # Check whether the response was truncated by max_tokens
    reason = None
    if last_resp and getattr(last_resp, "candidates", None):
        c0 = last_resp.candidates[0]
        reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
        for c in last_resp.candidates:
            fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None)
            if fr:
                reason = str(fr)
                break
    if str(reason).endswith("MAX_TOKENS"):
        if resp.content and resp.content.strip():
        has_visible_output = bool(resp.content and resp.content.strip())
        if has_visible_output:
            logger.warning(
                "⚠ The Gemini response was partially truncated by the max_tokens limit,\n"
                "  this may affect the reply content; consider adjusting the model's max_tokens setting!"
            )
@@ -253,7 +262,8 @@ def _build_stream_api_resp(
        logger.warning("⚠ The Gemini response was truncated by the max_tokens limit,\n  please adjust the model's max_tokens setting!")
    if not resp.content and not resp.tool_calls:
        raise EmptyResponseException()
        if not getattr(resp, "reasoning_content", None):
            raise EmptyResponseException()
    return resp
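
Note the str(reason).endswith("MAX_TOKENS") comparison above: finish_reason in the google-genai SDK is a FinishReason enum whose string form, depending on the SDK version, renders as either "MAX_TOKENS" or "FinishReason.MAX_TOKENS", and the suffix check matches both. A tiny sketch, assuming the enum is importable as shown:

from google.genai.types import FinishReason

# Both possible string renderings of the enum end with "MAX_TOKENS",
# so the suffix check works across SDK versions and for raw JSON strings.
reason = FinishReason.MAX_TOKENS
assert str(reason).endswith("MAX_TOKENS")
assert "MAX_TOKENS".endswith("MAX_TOKENS")  # raw JSON value also matches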
@@ -271,7 +281,8 @@ async def _default_stream_response_handler(
    _tool_calls_buffer: list[tuple[str, str, dict]] = []  # tool-call buffer, stores the received tool calls
    _usage_record = None  # usage record
    last_resp: GenerateContentResponse | None = None  # keep the last chunk
    resp = APIResponse()
    def _insure_buffer_closed():
        if _fc_delta_buffer and not _fc_delta_buffer.closed:
            _fc_delta_buffer.close()
@@ -287,6 +298,7 @@ async def _default_stream_response_handler(
                chunk,
                _fc_delta_buffer,
                _tool_calls_buffer,
                resp=resp,
            )
            if chunk.usage_metadata:
@@ -302,6 +314,7 @@ async def _default_stream_response_handler(
            _fc_delta_buffer,
            _tool_calls_buffer,
            last_resp=last_resp,
            resp=resp,
        ), _usage_record
    except Exception:
        # Make sure the buffers are closed
@@ -526,6 +539,15 @@ class GeminiClient(BaseClient):
        tools = _convert_tool_options(tool_options) if tool_options else None
        # Parse and clamp the thinking_budget
        tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
        # Detect model ids carrying a "-search" suffix
        enable_google_search = False
        model_identifier = model_info.model_identifier
        if model_identifier.endswith("-search"):
            enable_google_search = True
            # Strip the suffix and update the model id
            model_identifier = model_identifier.removesuffix("-search")
            model_info.model_identifier = model_identifier
            logger.info(f"GoogleSearch enabled for model: {model_identifier}")
        # Convert response_format into the format required by the Gemini API
        generation_config_dict = {
@@ -548,6 +570,17 @@ class GeminiClient(BaseClient):
        elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
            generation_config_dict["response_mime_type"] = "application/json"
            generation_config_dict["response_schema"] = response_format.to_dict()
        # Automatically attach the GoogleSearch grounding tool
        if enable_google_search:
            grounding_tool = Tool(google_search=GoogleSearch())
            if "tools" in generation_config_dict:
                existing = generation_config_dict["tools"]
                if isinstance(existing, list):
                    existing.append(grounding_tool)
                else:
                    generation_config_dict["tools"] = [existing, grounding_tool]
            else:
                generation_config_dict["tools"] = [grounding_tool]
        generation_config = GenerateContentConfig(**generation_config_dict)
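
For reference, the grounding setup that the "-search" suffix toggles can also be exercised directly against the google-genai SDK. A minimal sketch, assuming a valid API key in the environment (the model id and prompt are placeholders):

from google import genai
from google.genai.types import GenerateContentConfig, GoogleSearch, Tool

client = genai.Client()  # picks up the API key from the environment

# The same wiring the client code above performs when it sees "-search":
config = GenerateContentConfig(tools=[Tool(google_search=GoogleSearch())])

response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model id
    contents="What major tech news broke this week?",
    config=config,
)
print(response.text)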

@@ -199,6 +199,7 @@ def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _rc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, io.StringIO]],
    finish_reason: str | None = None,
) -> APIResponse:
    resp = APIResponse()
@@ -236,6 +237,16 @@ def _build_stream_api_resp(
        resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
    # Check for max_tokens truncation
    if finish_reason == "length":
        if resp.content and resp.content.strip():
            logger.warning(
                "⚠ The OpenAI response was partially truncated by the max_tokens limit,\n"
                "  this may affect the reply content; consider adjusting the model's max_tokens setting!"
            )
        else:
            logger.warning("⚠ The OpenAI response was truncated by the max_tokens limit,\n  please adjust the model's max_tokens setting!")
    if not resp.content and not resp.tool_calls:
        raise EmptyResponseException()
@@ -258,6 +269,7 @@ async def _default_stream_response_handler(
    _fc_delta_buffer = io.StringIO()  # content buffer, stores the received normal content
    _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = []  # tool-call buffer, stores the received tool calls
    _usage_record = None  # usage record
    finish_reason: str | None = None  # record the last finish_reason seen
    def _insure_buffer_closed():
        # Make sure the buffers are closed
@@ -285,6 +297,9 @@ async def _default_stream_response_handler(
                continue  # skip this frame to avoid touching choices[0]
            delta = event.choices[0].delta  # get the delta content of the current chunk
            if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
                finish_reason = event.choices[0].finish_reason
            if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
                # Flag: the stream carries a dedicated reasoning-content block
                _has_rc_attr_flag = True
@@ -311,6 +326,7 @@ async def _default_stream_response_handler(
            _fc_delta_buffer,
            _rc_delta_buffer,
            _tool_calls_buffer,
            finish_reason=finish_reason,
        ), _usage_record
    except Exception:
        # Make sure the buffers are closed
@@ -381,6 +397,23 @@ def _default_normal_response_parser(
    # Store the raw response in raw_data
    api_response.raw_data = resp
    # Check for max_tokens truncation
    try:
        choice0 = resp.choices[0]
        reason = getattr(choice0, "finish_reason", None)
        if reason and reason == "length":
            has_real_output = bool(api_response.content and api_response.content.strip())
            if has_real_output:
                logger.warning(
                    "⚠ The OpenAI response was partially truncated by the max_tokens limit,\n"
                    "  this may affect the reply content; consider adjusting the model's max_tokens setting!"
                )
            else:
                logger.warning("⚠ The OpenAI response was truncated by the max_tokens limit,\n  please adjust the model's max_tokens setting!")
        return api_response, _usage_record
    except Exception as e:
        logger.debug(f"Exception while checking for MAX_TOKENS truncation: {e}")
    if not api_response.content and not api_response.tool_calls:
        raise EmptyResponseException()
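
For reference, the finish_reason == "length" signal consumed above is emitted by the OpenAI SDK itself on the final chunk of a stream. A minimal standalone sketch of capturing it, assuming OPENAI_API_KEY is set (the model id and prompt are placeholders):

from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model id
    messages=[{"role": "user", "content": "Tell me a very long story."}],
    max_tokens=16,  # deliberately tiny to force a "length" finish_reason
    stream=True,
)

finish_reason = None
for event in stream:
    if event.choices and event.choices[0].finish_reason:
        finish_reason = event.choices[0].finish_reason  # "stop", "length", ...

if finish_reason == "length":
    print("Response was truncated by max_tokens; raise the limit.")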