fix(regulatory): 为LLM复核超时增加重试

This commit is contained in:
2026-06-07 13:03:24 +08:00
parent 9e27c4c684
commit 0f9fb980f2
3 changed files with 95 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import os
import re
import time
from collections.abc import Callable
from typing import Any
@@ -36,11 +37,14 @@ def review_condition_fields(
"selected_sources": selected_sources,
}
try:
raw = (completion_func or generate_completion)(_condition_messages(text, rule_fields, file_context), temperature=0.0)
raw = _call_completion_with_retries(
completion_func or generate_completion,
_condition_messages(text, rule_fields, file_context),
)
payload = _parse_json_object(raw)
llm_fields = _clean_fields(payload.get("fields") or payload)
status = "success"
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc:
status = "failed"
error_message = str(exc)
@@ -69,7 +73,10 @@ def review_workflow_payload(
"error_message": "",
}
try:
raw = (completion_func or generate_completion)(_workflow_messages(stage, payload), temperature=0.0)
raw = _call_completion_with_retries(
completion_func or generate_completion,
_workflow_messages(stage, payload),
)
parsed = _parse_json_object(raw)
return {
"status": "success",
@@ -77,7 +84,7 @@ def review_workflow_payload(
"result": parsed,
"error_message": "",
}
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc:
return {
"status": "failed",
"stage": stage,
@@ -142,6 +149,24 @@ def _parse_json_object(raw: str) -> dict[str, Any]:
return parsed
def _call_completion_with_retries(completion_func: Callable[..., str], messages: list[dict[str, str]]) -> str:
attempts = max(1, int(getattr(settings, "REGULATORY_LLM_REVIEW_MAX_ATTEMPTS", 3) or 3))
delay_seconds = float(getattr(settings, "REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS", 0.5) or 0)
last_error: Exception | None = None
for attempt in range(1, attempts + 1):
try:
return completion_func(messages, temperature=0.0)
except (LLMRequestError, OSError, TimeoutError) as exc:
last_error = exc
if attempt >= attempts:
break
if delay_seconds > 0:
time.sleep(delay_seconds)
if last_error:
raise last_error
raise LLMRequestError("LLM复核调用失败。")
def _should_call_llm(completion_func: Callable[..., str] | None) -> bool:
if completion_func is not None:
return True