fix(regulatory): 为LLM复核超时增加重试

This commit is contained in:
2026-06-07 13:03:24 +08:00
parent 9e27c4c684
commit 0f9fb980f2
3 changed files with 95 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import os import os
import re import re
import time
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@@ -36,11 +37,14 @@ def review_condition_fields(
"selected_sources": selected_sources, "selected_sources": selected_sources,
} }
try: try:
raw = (completion_func or generate_completion)(_condition_messages(text, rule_fields, file_context), temperature=0.0) raw = _call_completion_with_retries(
completion_func or generate_completion,
_condition_messages(text, rule_fields, file_context),
)
payload = _parse_json_object(raw) payload = _parse_json_object(raw)
llm_fields = _clean_fields(payload.get("fields") or payload) llm_fields = _clean_fields(payload.get("fields") or payload)
status = "success" status = "success"
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc: except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc:
status = "failed" status = "failed"
error_message = str(exc) error_message = str(exc)
@@ -69,7 +73,10 @@ def review_workflow_payload(
"error_message": "", "error_message": "",
} }
try: try:
raw = (completion_func or generate_completion)(_workflow_messages(stage, payload), temperature=0.0) raw = _call_completion_with_retries(
completion_func or generate_completion,
_workflow_messages(stage, payload),
)
parsed = _parse_json_object(raw) parsed = _parse_json_object(raw)
return { return {
"status": "success", "status": "success",
@@ -77,7 +84,7 @@ def review_workflow_payload(
"result": parsed, "result": parsed,
"error_message": "", "error_message": "",
} }
except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc: except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc:
return { return {
"status": "failed", "status": "failed",
"stage": stage, "stage": stage,
@@ -142,6 +149,24 @@ def _parse_json_object(raw: str) -> dict[str, Any]:
return parsed return parsed
def _call_completion_with_retries(completion_func: Callable[..., str], messages: list[dict[str, str]]) -> str:
attempts = max(1, int(getattr(settings, "REGULATORY_LLM_REVIEW_MAX_ATTEMPTS", 3) or 3))
delay_seconds = float(getattr(settings, "REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS", 0.5) or 0)
last_error: Exception | None = None
for attempt in range(1, attempts + 1):
try:
return completion_func(messages, temperature=0.0)
except (LLMRequestError, OSError, TimeoutError) as exc:
last_error = exc
if attempt >= attempts:
break
if delay_seconds > 0:
time.sleep(delay_seconds)
if last_error:
raise last_error
raise LLMRequestError("LLM复核调用失败。")
def _should_call_llm(completion_func: Callable[..., str] | None) -> bool: def _should_call_llm(completion_func: Callable[..., str] | None) -> bool:
if completion_func is not None: if completion_func is not None:
return True return True

View File

@@ -1,6 +1,6 @@
import json import json
from review_agent.regulatory_review.services.llm_review import review_condition_fields from review_agent.regulatory_review.services.llm_review import review_condition_fields, review_workflow_payload
def test_review_condition_fields_selects_more_complete_llm_product_name(): def test_review_condition_fields_selects_more_complete_llm_product_name():
@@ -55,3 +55,39 @@ def test_review_condition_fields_rejects_garbled_llm_product_name():
assert result["selected_fields"]["产品名称"] == "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒荧光PCR法" assert result["selected_fields"]["产品名称"] == "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒荧光PCR法"
assert result["selected_sources"]["产品名称"] == "rule" assert result["selected_sources"]["产品名称"] == "rule"
def test_review_workflow_payload_handles_timeout_without_raising():
def completion(messages, temperature=0.0):
raise TimeoutError("The read operation timed out")
result = review_workflow_payload(
stage="completeness_check",
payload={"findings": []},
completion_func=completion,
)
assert result["status"] == "failed"
assert result["stage"] == "completeness_check"
assert "timed out" in result["error_message"]
def test_review_workflow_payload_retries_timeout_before_success(settings):
settings.REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS = 0
attempts = {"count": 0}
def completion(messages, temperature=0.0):
attempts["count"] += 1
if attempts["count"] < 3:
raise TimeoutError("The read operation timed out")
return json.dumps({"reviewed": True})
result = review_workflow_payload(
stage="completeness_check",
payload={"findings": []},
completion_func=completion,
)
assert attempts["count"] == 3
assert result["status"] == "success"
assert result["result"]["reviewed"] is True

View File

@@ -118,6 +118,35 @@ def test_start_regulatory_review_workflow_runs_synchronously(django_user_model):
).exists() ).exists()
def test_workflow_continues_when_llm_review_times_out(monkeypatch, settings, django_user_model):
settings.REGULATORY_LLM_REVIEW_ALLOW_TEST_CALLS = True
user = django_user_model.objects.create_user(username="owner", password="pass")
conversation = Conversation.objects.create(user=user, title="会话")
summary = FileSummaryBatch.objects.create(
conversation=conversation,
user=user,
batch_no="FS-OK",
status=FileSummaryBatch.Status.SUCCESS,
)
batch = create_regulatory_review_batch(
conversation=conversation,
user=user,
source_summary_batch=summary,
)
batch.condition_json = {"confirmed": True, "confirmed_conditions": {"product_category": "体外诊断试剂"}}
batch.save(update_fields=["condition_json"])
monkeypatch.setattr(
"review_agent.regulatory_review.services.llm_review.generate_completion",
lambda messages, temperature=0.0: (_ for _ in ()).throw(TimeoutError("The read operation timed out")),
)
start_regulatory_review_workflow(batch, async_run=False)
batch.refresh_from_db()
assert batch.status == RegulatoryReviewBatch.Status.SUCCESS
assert batch.error_message == ""
def test_stream_message_prompts_for_summary_when_missing(monkeypatch, django_user_model): def test_stream_message_prompts_for_summary_when_missing(monkeypatch, django_user_model):
user = django_user_model.objects.create_user(username="owner", password="pass") user = django_user_model.objects.create_user(username="owner", password="pass")
conversation = Conversation.objects.create(user=user, title="会话") conversation = Conversation.objects.create(user=user, title="会话")