diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index e8cf44f7..b83c3b8f 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -13,6 +13,7 @@ from google.genai.types import ( ContentUnion, ThinkingConfig, Tool, + GoogleSearch, GenerateContentConfig, EmbedContentResponse, EmbedContentConfig, @@ -176,19 +177,21 @@ def _process_delta( delta: GenerateContentResponse, fc_delta_buffer: io.StringIO, tool_calls_buffer: list[tuple[str, str, dict[str, Any]]], + resp: APIResponse | None = None, ): if not hasattr(delta, "candidates") or not delta.candidates: raise RespParseException(delta, "响应解析失败,缺失candidates字段") - if delta.text: - fc_delta_buffer.write(delta.text) - # 处理 thought(Gemini 的特殊字段) for c in getattr(delta, "candidates", []): if c.content and getattr(c.content, "parts", None): for p in c.content.parts: if getattr(p, "thought", False) and getattr(p, "text", None): - # 把 thought 写入 buffer,避免 resp.content 永远为空 + # 保存到 reasoning_content + if resp is not None: + resp.reasoning_content = (resp.reasoning_content or "") + p.text + elif getattr(p, "text", None): + # 正常输出写入 buffer fc_delta_buffer.write(p.text) if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 @@ -213,9 +216,11 @@ def _build_stream_api_resp( _fc_delta_buffer: io.StringIO, _tool_calls_buffer: list[tuple[str, str, dict]], last_resp: GenerateContentResponse | None = None, # 传入 last_resp + resp: APIResponse | None = None, ) -> APIResponse: # sourcery skip: simplify-len-comparison, use-assigned-variable - resp = APIResponse() + if resp is None: + resp = APIResponse() if _fc_delta_buffer.tell() > 0: # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 @@ -240,11 +245,15 @@ def _build_stream_api_resp( # 检查是否因为 max_tokens 截断 reason = None if last_resp and getattr(last_resp, "candidates", None): - c0 = last_resp.candidates[0] - reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None) - + for c in last_resp.candidates: + fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None) + if fr: + reason = str(fr) + break + if str(reason).endswith("MAX_TOKENS"): - if resp.content and resp.content.strip(): + has_visible_output = bool(resp.content and resp.content.strip()) + if has_visible_output: logger.warning( "⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n" " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" @@ -253,7 +262,8 @@ def _build_stream_api_resp( logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") if not resp.content and not resp.tool_calls: - raise EmptyResponseException() + if not getattr(resp, "reasoning_content", None): + raise EmptyResponseException() return resp @@ -271,7 +281,8 @@ async def _default_stream_response_handler( _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 _usage_record = None # 使用情况记录 last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk - + resp = APIResponse() + def _insure_buffer_closed(): if _fc_delta_buffer and not _fc_delta_buffer.closed: _fc_delta_buffer.close() @@ -287,6 +298,7 @@ async def _default_stream_response_handler( chunk, _fc_delta_buffer, _tool_calls_buffer, + resp=resp, ) if chunk.usage_metadata: @@ -302,6 +314,7 @@ async def _default_stream_response_handler( _fc_delta_buffer, _tool_calls_buffer, last_resp=last_resp, + resp=resp, ), _usage_record except Exception: # 确保缓冲区被关闭 @@ -526,6 +539,15 @@ class GeminiClient(BaseClient): tools = _convert_tool_options(tool_options) if tool_options else None # 解析并裁剪 thinking_budget tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) + # 检测是否为带 -search 的模型 + enable_google_search = False + model_identifier = model_info.model_identifier + if model_identifier.endswith("-search"): + enable_google_search = True + # 去掉后缀并更新模型ID + model_identifier = model_identifier.removesuffix("-search") + model_info.model_identifier = model_identifier + logger.info(f"模型已启用 GoogleSearch 功能:{model_identifier}") # 将response_format转换为Gemini API所需的格式 generation_config_dict = { @@ -548,6 +570,17 @@ class GeminiClient(BaseClient): elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): generation_config_dict["response_mime_type"] = "application/json" generation_config_dict["response_schema"] = response_format.to_dict() + # 自动启用 GoogleSearch grounding_tool + if enable_google_search: + grounding_tool = Tool(google_search=GoogleSearch()) + if "tools" in generation_config_dict: + existing = generation_config_dict["tools"] + if isinstance(existing, list): + existing.append(grounding_tool) + else: + generation_config_dict["tools"] = [existing, grounding_tool] + else: + generation_config_dict["tools"] = [grounding_tool] generation_config = GenerateContentConfig(**generation_config_dict) diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 148ec8cb..dd92b9e8 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -199,6 +199,7 @@ def _build_stream_api_resp( _fc_delta_buffer: io.StringIO, _rc_delta_buffer: io.StringIO, _tool_calls_buffer: list[tuple[str, str, io.StringIO]], + finish_reason: str | None = None, ) -> APIResponse: resp = APIResponse() @@ -236,6 +237,16 @@ def _build_stream_api_resp( resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + # 检查 max_tokens 截断 + if finish_reason == "length": + if resp.content and resp.content.strip(): + logger.warning( + "⚠ OpenAI 响应因达到 max_tokens 限制被部分截断,\n" + " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" + ) + else: + logger.warning("⚠ OpenAI 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") + if not resp.content and not resp.tool_calls: raise EmptyResponseException() @@ -258,6 +269,7 @@ async def _default_stream_response_handler( _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 _usage_record = None # 使用情况记录 + finish_reason: str | None = None # 记录最后的 finish_reason def _insure_buffer_closed(): # 确保缓冲区被关闭 @@ -285,6 +297,9 @@ async def _default_stream_response_handler( continue # 跳过本帧,避免访问 choices[0] delta = event.choices[0].delta # 获取当前块的delta内容 + if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason: + finish_reason = event.choices[0].finish_reason + if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore # 标记:有独立的推理内容块 _has_rc_attr_flag = True @@ -311,6 +326,7 @@ async def _default_stream_response_handler( _fc_delta_buffer, _rc_delta_buffer, _tool_calls_buffer, + finish_reason=finish_reason, ), _usage_record except Exception: # 确保缓冲区被关闭 @@ -381,6 +397,23 @@ def _default_normal_response_parser( # 将原始响应存储在原始数据中 api_response.raw_data = resp + # 检查 max_tokens 截断 + try: + choice0 = resp.choices[0] + reason = getattr(choice0, "finish_reason", None) + if reason and reason == "length": + has_real_output = bool(api_response.content and api_response.content.strip()) + if has_real_output: + logger.warning( + "⚠ OpenAI 响应因达到 max_tokens 限制被部分截断,\n" + " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" + ) + else: + logger.warning("⚠ OpenAI 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") + return api_response, _usage_record + except Exception as e: + logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") + if not api_response.content and not api_response.tool_calls: raise EmptyResponseException()