Files
DEMO-AGENT/review_agent/regulatory_review/services/llm_review.py

244 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
import re
import time
import inspect
from collections.abc import Callable
from typing import Any
from django.conf import settings
from review_agent.llm import LLMConfigurationError, LLMRequestError, generate_completion
# Labels of the registration fields this module works with. They are the only
# keys kept when field dicts are cleaned, the iteration order used when picking
# between rule and LLM values, and the "allowed_fields" list sent to the LLM.
FIELD_LABELS = ["产品名称", "型号规格", "预期用途", "管理类别", "分类编码", "注册类型", "临床评价路径"]
# Alias for a callable that takes a chat message list and returns the raw
# completion text. NOTE(review): not referenced by the code visible in this
# file (the functions annotate with Callable[..., str] instead) — confirm
# whether other modules import it before removing.
CompletionFunc = Callable[[list[dict[str, str]]], str]
def review_condition_fields(
    *,
    text: str,
    rule_fields: dict[str, str],
    file_context: str = "",
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Cross-check rule-extracted fields against an LLM extraction.

    Args:
        text: Attachment text handed to the LLM (truncated when the prompt is
            built).
        rule_fields: Field values produced by the rule-based extractor.
        file_context: Optional description of the source file, forwarded to
            the LLM unchanged.
        completion_func: Optional completion callable; when omitted the
            project's ``generate_completion`` is used, subject to the
            configuration check in ``_should_call_llm``.

    Returns:
        A dict with keys ``status`` (``"skipped"``, ``"success"`` or
        ``"failed"``), ``error_message``, ``rule_fields`` (cleaned),
        ``llm_fields`` (cleaned, empty unless the call succeeded),
        ``selected_fields`` and ``selected_sources``.

    LLM and parsing failures never propagate: they are reported through
    ``status``/``error_message`` and the rule values are used as-is.
    """
    llm_fields: dict[str, str] = {}
    status = "skipped"
    error_message = ""

    if _should_call_llm(completion_func):
        try:
            raw = _call_completion_with_retries(
                completion_func or generate_completion,
                _condition_messages(text, rule_fields, file_context),
            )
            payload = _parse_json_object(raw)
            # The model is asked for {"fields": {...}} but may answer with the
            # field mapping at the top level. Only trust "fields" when it is a
            # non-empty dict; a list or string there used to raise an uncaught
            # AttributeError further down.
            candidate = payload.get("fields")
            if not isinstance(candidate, dict) or not candidate:
                candidate = payload
            llm_fields = _clean_fields(candidate)
            status = "success"
        except (
            LLMConfigurationError,
            LLMRequestError,
            json.JSONDecodeError,
            TypeError,
            ValueError,
            OSError,
            TimeoutError,
        ) as exc:
            status = "failed"
            error_message = str(exc)

    # Single exit: on "skipped"/"failed" llm_fields is empty, so the selection
    # falls back to the rule values.
    selected_fields, selected_sources = _select_fields(rule_fields, llm_fields)
    return {
        "status": status,
        "error_message": error_message,
        "rule_fields": _clean_fields(rule_fields),
        "llm_fields": llm_fields,
        "selected_fields": selected_fields,
        "selected_sources": selected_sources,
    }
def review_workflow_payload(
    *,
    stage: str,
    payload: dict[str, Any],
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Ask the LLM to review the rule result of one workflow stage.

    Returns a dict with ``status`` (``"skipped"``, ``"success"`` or
    ``"failed"``), the ``stage`` passed in, ``result`` (the parsed JSON object
    on success, otherwise ``{}``) and ``error_message`` (empty unless the call
    failed). The listed LLM/parsing exceptions are converted into a
    ``"failed"`` outcome rather than raised.
    """
    outcome: dict[str, Any] = {
        "status": "skipped",
        "stage": stage,
        "result": {},
        "error_message": "",
    }
    if not _should_call_llm(completion_func):
        return outcome

    try:
        raw = _call_completion_with_retries(
            completion_func or generate_completion,
            _workflow_messages(stage, payload),
        )
        outcome["result"] = _parse_json_object(raw)
        outcome["status"] = "success"
    except (
        LLMConfigurationError,
        LLMRequestError,
        json.JSONDecodeError,
        TypeError,
        ValueError,
        OSError,
        TimeoutError,
    ) as exc:
        outcome["status"] = "failed"
        outcome["error_message"] = str(exc)
    return outcome
def _condition_messages(text: str, rule_fields: dict[str, str], file_context: str) -> list[dict[str, str]]:
    """Build the system/user chat messages for the field-review request.

    The user message is a JSON document carrying the file context, the
    rule-extracted fields, the first 4000 characters of ``text`` and the list
    of field labels the model may return.
    """
    system_prompt = (
        "你是NMPA注册资料字段复核助手。请从附件文本中提取最合理的字段值"
        "只返回JSON格式为 {\"fields\": {\"产品名称\": \"...\"}}。"
        "产品名称应包含完整名称、检测对象和方法学括号;不要把章节标题当产品名称。"
    )
    request_body = {
        "file_context": file_context,
        "rule_fields": rule_fields,
        # Only the head of the document is sent, to bound the prompt size.
        "text": text[:4000],
        "allowed_fields": FIELD_LABELS,
    }
    user_prompt = json.dumps(request_body, ensure_ascii=False)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
def _workflow_messages(stage: str, payload: dict[str, Any]) -> list[dict[str, str]]:
return [
{
"role": "system",
"content": (
"你是NMPA法规核查复核助手。请复核当前流程节点的规则结果"
"指出可能误判、漏判和更合理的建议。只返回JSON。"
),
},
{
"role": "user",
"content": json.dumps({"stage": stage, "payload": payload}, ensure_ascii=False)[:6000],
},
]
def _parse_json_object(raw: str) -> dict[str, Any]:
value = (raw or "").strip()
if value.startswith("```"):
value = re.sub(r"^```(?:json)?\s*", "", value)
value = re.sub(r"\s*```$", "", value)
start = value.find("{")
end = value.rfind("}")
if start >= 0 and end >= start:
value = value[start : end + 1]
parsed = json.loads(value)
if not isinstance(parsed, dict):
raise ValueError("LLM复核结果不是JSON对象。")
return parsed
def _call_completion_with_retries(completion_func: Callable[..., str], messages: list[dict[str, str]]) -> str:
    """Invoke ``completion_func`` and retry on transport-level failures.

    Reads three optional Django settings:
    ``REGULATORY_LLM_REVIEW_MAX_ATTEMPTS`` (default 3, at least 1),
    ``REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS`` (default 0.5) and
    ``REGULATORY_LLM_REVIEW_TIMEOUT_SECONDS`` (default 15).

    The callable is always given ``temperature=0.0``; ``timeout`` is added
    only when ``_accepts_timeout`` reports it as supported. Only
    ``LLMRequestError``, ``OSError`` and ``TimeoutError`` trigger a retry; any
    other exception propagates immediately.

    Raises:
        The last retryable error once every attempt has failed.
    """
    max_attempts = max(1, int(getattr(settings, "REGULATORY_LLM_REVIEW_MAX_ATTEMPTS", 3) or 3))
    retry_delay = float(getattr(settings, "REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS", 0.5) or 0)
    timeout = float(getattr(settings, "REGULATORY_LLM_REVIEW_TIMEOUT_SECONDS", 15) or 15)

    call_kwargs: dict[str, float] = {"temperature": 0.0}
    if _accepts_timeout(completion_func):
        call_kwargs["timeout"] = timeout

    failure: Exception | None = None
    remaining = max_attempts
    while remaining > 0:
        remaining -= 1
        try:
            return completion_func(messages, **call_kwargs)
        except (LLMRequestError, OSError, TimeoutError) as exc:
            failure = exc
        # Pause only between attempts, never after the final one.
        if remaining and retry_delay > 0:
            time.sleep(retry_delay)

    if failure:
        raise failure
    raise LLMRequestError("LLM复核调用失败。")
def _accepts_timeout(completion_func: Callable[..., str]) -> bool:
try:
signature = inspect.signature(completion_func)
except (TypeError, ValueError):
return True
return "timeout" in signature.parameters
def _should_call_llm(completion_func: Callable[..., str] | None) -> bool:
if completion_func is not None:
return True
if os.environ.get("PYTEST_CURRENT_TEST") and not getattr(settings, "REGULATORY_LLM_REVIEW_ALLOW_TEST_CALLS", False):
return False
return bool(settings.LLM_API_KEY and settings.LLM_MODEL)
def _clean_fields(fields: dict[str, Any]) -> dict[str, str]:
    """Return the known fields of ``fields`` as normalised, non-empty strings.

    Only labels listed in ``FIELD_LABELS`` are kept, and only when their value
    is a string. Runs of whitespace collapse to a single space and ASCII
    parentheses are converted to their full-width forms so that rule and LLM
    values compare consistently. An argument without a ``get`` method (for
    example a list returned by the model) yields an empty dict.

    NOTE(review): the previous code replaced ``(``/``)`` with empty strings,
    deleting the methodology brackets that the system prompt explicitly asks
    for. The replacement targets appear to have been full-width brackets that
    were lost in transit; confirm against the repository copy.
    """
    getter = getattr(fields, "get", None)
    if not callable(getter):
        return {}

    clean: dict[str, str] = {}
    for label in FIELD_LABELS:
        value = getter(label)
        if not isinstance(value, str):
            continue
        # Escapes are used so the full-width brackets (U+FF08 / U+FF09) stay
        # unambiguous and survive tools that rewrite confusable characters.
        normalized = " ".join(value.split()).replace("(", "\uff08").replace(")", "\uff09")
        if normalized:
            clean[label] = normalized
    return clean
def _select_fields(rule_fields: dict[str, str], llm_fields: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]:
    """Pick, per field label, between the rule value and the LLM value.

    Returns two dicts keyed by label: the chosen values and the source of each
    chosen value as reported by ``_select_field``. Labels for which no value
    was chosen are omitted from both. ``rule_fields`` is cleaned here;
    ``llm_fields`` is used as given.
    """
    cleaned_rules = _clean_fields(rule_fields)
    chosen: dict[str, str] = {}
    origins: dict[str, str] = {}
    for label in FIELD_LABELS:
        value, source = _select_field(
            label,
            cleaned_rules.get(label, ""),
            llm_fields.get(label, ""),
        )
        if not value:
            continue
        chosen[label] = value
        origins[label] = source
    return chosen, origins
def _select_field(label: str, rule_value: str, llm_value: str) -> tuple[str, str]:
    """Choose between the rule value and the LLM value for one field.

    Returns ``(value, source)`` where ``source`` is ``"rule"``, ``"llm"`` or
    ``""`` when there is no value. The rule value wins unless it is missing,
    or the LLM value is judged a better product name, or the LLM value
    contains the rule value and is more than 1.35 times as long.
    """
    if _invalid_field_value(llm_value):
        return rule_value, ("rule" if rule_value else "")
    if not rule_value:
        return llm_value, ("llm" if llm_value else "")
    if not llm_value:
        return rule_value, "rule"

    prefer_llm = (
        label == "产品名称" and _better_product_name(llm_value, rule_value)
    ) or (
        # A noticeably longer value containing the rule value is treated as a
        # fuller version of the same text.
        len(llm_value) > len(rule_value) * 1.35 and rule_value in llm_value
    )
    if prefer_llm:
        return llm_value, "llm"
    return rule_value, "rule"
def _better_product_name(candidate: str, current: str) -> bool:
if current and current in candidate and len(candidate) > len(current):
return True
product_keywords = ["试剂盒", "检测试剂", "荧光PCR法", "PCR法", "核酸检测"]
return len(candidate) > len(current) and any(keyword in candidate for keyword in product_keywords)
def _invalid_field_value(value: str) -> bool:
if not value:
return True
if "<EFBFBD>" in value:
return True
return any(keyword in value for keyword in ["第1章", "第2章", "第3章", "监管信息", "综述资料", "章节目录"])