Support streaming

pull/1264/head
foxplaying 2025-09-24 23:19:53 +08:00 committed by GitHub
parent dcb75d7b94
commit a54bf78945
1 changed file with 24 additions and 3 deletions


@@ -246,12 +246,14 @@ async def _default_stream_response_handler(
     _fc_delta_buffer = io.StringIO()  # final-content buffer, stores the received final content
     _tool_calls_buffer: list[tuple[str, str, dict]] = []  # tool-call buffer, stores the received tool calls
     _usage_record = None  # usage record
+    last_resp: GenerateContentResponse | None = None  # keep the last chunk
 
     def _insure_buffer_closed():
         if _fc_delta_buffer and not _fc_delta_buffer.closed:
             _fc_delta_buffer.close()
 
     async for chunk in resp_stream:
+        last_resp = chunk  # save the latest response
         # check whether an interrupt was requested
         if interrupt_flag and interrupt_flag.is_set():
             # if the interrupt flag is set, raise ReqAbortException
@@ -270,13 +272,32 @@ async def _default_stream_response_handler(
                 (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
                 chunk.usage_metadata.total_token_count or 0,
             )
     try:
-        return _build_stream_api_resp(
+        api_response = _build_stream_api_resp(
             _fc_delta_buffer,
             _tool_calls_buffer,
-        ), _usage_record
+        )
+
+        # check whether the response was truncated by max_tokens
+        if last_resp and last_resp.candidates:
+            c0 = last_resp.candidates[0]
+            reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
+            if reason and "MAX_TOKENS" in str(reason):
+                if api_response.content and api_response.content.strip():
+                    logger.warning(
+                        "⚠ The Gemini response was partially truncated by the max_tokens limit;\n"
+                        "  this may affect the reply content. Consider adjusting the model's max_tokens setting!"
+                    )
+                else:
+                    logger.warning(
+                        "⚠ The Gemini response was truncated by the max_tokens limit;\n"
+                        "  please adjust the model's max_tokens setting!"
+                    )
+
+        return api_response, _usage_record
     except Exception:
         # make sure the buffer is closed
         _insure_buffer_closed()
         raise
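
For reference, a minimal standalone sketch of the truncation check this commit adds. The helper name and the SimpleNamespace stand-ins below are illustrative assumptions, not part of the patched module; the real handler receives google-genai GenerateContentResponse objects instead.

# Sketch of the MAX_TOKENS detection from the patch above, exercised with
# duck-typed stand-ins rather than real google-genai response objects.
from types import SimpleNamespace
from typing import Any

def is_truncated_by_max_tokens(last_resp: Any) -> bool:
    """True when the final streamed chunk reports a MAX_TOKENS finish reason."""
    if not (last_resp and getattr(last_resp, "candidates", None)):
        return False
    c0 = last_resp.candidates[0]
    # str(reason) covers both an enum repr ("FinishReason.MAX_TOKENS") and a
    # plain string value ("MAX_TOKENS"); the patch matches on the substring.
    reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
    return bool(reason and "MAX_TOKENS" in str(reason))

# Hypothetical stand-ins for the last streamed chunk:
truncated = SimpleNamespace(candidates=[SimpleNamespace(finish_reason="FinishReason.MAX_TOKENS")])
normal = SimpleNamespace(candidates=[SimpleNamespace(finish_reason="STOP")])

assert is_truncated_by_max_tokens(truncated)
assert not is_truncated_by_max_tokens(normal)
assert not is_truncated_by_max_tokens(None)

Checking only the last chunk works because the streaming API typically reports finish_reason on the final candidate chunk; saving last_resp during iteration, as the patch does, keeps it available once the stream is exhausted.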