mirror of https://github.com/Mai-with-u/MaiBot.git
fix: 捕获OpenAI流式错误并触发重试
parent
4d5456ed4b
commit
cb40ff6a0e
|
|
@ -7,13 +7,7 @@ from collections.abc import Iterable
|
||||||
from typing import Callable, Any, Coroutine, Optional
|
from typing import Callable, Any, Coroutine, Optional
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from openai import (
|
from openai import AsyncOpenAI, APIConnectionError, APIError, APIStatusError, NOT_GIVEN, AsyncStream
|
||||||
AsyncOpenAI,
|
|
||||||
APIConnectionError,
|
|
||||||
APIStatusError,
|
|
||||||
NOT_GIVEN,
|
|
||||||
AsyncStream,
|
|
||||||
)
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletion,
|
ChatCompletion,
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
|
|
@ -39,6 +33,23 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
logger = get_logger("OpenAI客户端")
|
logger = get_logger("OpenAI客户端")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_status_code_from_api_error(error: APIError) -> int:
|
||||||
|
"""
|
||||||
|
尝试从APIError对象中提取HTTP状态码,无法确定时回退为500。
|
||||||
|
"""
|
||||||
|
status_code = getattr(error, "status_code", None)
|
||||||
|
if status_code is None:
|
||||||
|
status_code = getattr(error, "status", None)
|
||||||
|
if status_code is None:
|
||||||
|
response = getattr(error, "response", None)
|
||||||
|
if response is not None:
|
||||||
|
status_code = getattr(response, "status_code", None) or getattr(response, "status", None)
|
||||||
|
try:
|
||||||
|
return int(status_code)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 500
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
|
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
|
||||||
"""
|
"""
|
||||||
转换消息格式 - 将消息转换为OpenAI API所需的格式
|
转换消息格式 - 将消息转换为OpenAI API所需的格式
|
||||||
|
|
@ -281,6 +292,7 @@ async def _default_stream_response_handler(
|
||||||
if buffer and not buffer.closed:
|
if buffer and not buffer.closed:
|
||||||
buffer.close()
|
buffer.close()
|
||||||
|
|
||||||
|
try:
|
||||||
async for event in resp_stream:
|
async for event in resp_stream:
|
||||||
if interrupt_flag and interrupt_flag.is_set():
|
if interrupt_flag and interrupt_flag.is_set():
|
||||||
# 如果中断量被设置,则抛出ReqAbortException
|
# 如果中断量被设置,则抛出ReqAbortException
|
||||||
|
|
@ -320,6 +332,12 @@ async def _default_stream_response_handler(
|
||||||
event.usage.completion_tokens or 0,
|
event.usage.completion_tokens or 0,
|
||||||
event.usage.total_tokens or 0,
|
event.usage.total_tokens or 0,
|
||||||
)
|
)
|
||||||
|
except APIError as e:
|
||||||
|
_insure_buffer_closed()
|
||||||
|
status_code = _extract_status_code_from_api_error(e)
|
||||||
|
message = getattr(e, "message", None) or str(e)
|
||||||
|
logger.warning(f"OpenAI流式响应异常: {message}")
|
||||||
|
raise RespNotOkException(status_code, message) from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return _build_stream_api_resp(
|
return _build_stream_api_resp(
|
||||||
|
|
@ -533,6 +551,10 @@ class OpenaiClient(BaseClient):
|
||||||
except APIStatusError as e:
|
except APIStatusError as e:
|
||||||
# 重封装APIError为RespNotOkException
|
# 重封装APIError为RespNotOkException
|
||||||
raise RespNotOkException(e.status_code, e.message) from e
|
raise RespNotOkException(e.status_code, e.message) from e
|
||||||
|
except APIError as e:
|
||||||
|
status_code = _extract_status_code_from_api_error(e)
|
||||||
|
message = getattr(e, "message", None) or str(e)
|
||||||
|
raise RespNotOkException(status_code, message) from e
|
||||||
|
|
||||||
if usage_record:
|
if usage_record:
|
||||||
resp.usage = UsageRecord(
|
resp.usage = UsageRecord(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue