from __future__ import annotations import json import os import re import time import inspect from collections.abc import Callable from typing import Any from django.conf import settings from review_agent.llm import LLMConfigurationError, LLMRequestError, generate_completion FIELD_LABELS = ["产品名称", "型号规格", "预期用途", "管理类别", "分类编码", "注册类型", "临床评价路径"] CompletionFunc = Callable[[list[dict[str, str]]], str] def review_condition_fields( *, text: str, rule_fields: dict[str, str], file_context: str = "", completion_func: Callable[..., str] | None = None, ) -> dict[str, Any]: llm_fields: dict[str, str] = {} status = "skipped" error_message = "" if not _should_call_llm(completion_func): selected_fields, selected_sources = _select_fields(rule_fields, llm_fields) return { "status": status, "error_message": error_message, "rule_fields": _clean_fields(rule_fields), "llm_fields": llm_fields, "selected_fields": selected_fields, "selected_sources": selected_sources, } try: raw = _call_completion_with_retries( completion_func or generate_completion, _condition_messages(text, rule_fields, file_context), ) payload = _parse_json_object(raw) llm_fields = _clean_fields(payload.get("fields") or payload) status = "success" except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc: status = "failed" error_message = str(exc) selected_fields, selected_sources = _select_fields(rule_fields, llm_fields) return { "status": status, "error_message": error_message, "rule_fields": _clean_fields(rule_fields), "llm_fields": llm_fields, "selected_fields": selected_fields, "selected_sources": selected_sources, } def review_workflow_payload( *, stage: str, payload: dict[str, Any], completion_func: Callable[..., str] | None = None, ) -> dict[str, Any]: if not _should_call_llm(completion_func): return { "status": "skipped", "stage": stage, "result": {}, "error_message": "", } try: raw = _call_completion_with_retries( completion_func or generate_completion, _workflow_messages(stage, payload), ) parsed = _parse_json_object(raw) return { "status": "success", "stage": stage, "result": parsed, "error_message": "", } except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError, OSError, TimeoutError) as exc: return { "status": "failed", "stage": stage, "result": {}, "error_message": str(exc), } def _condition_messages(text: str, rule_fields: dict[str, str], file_context: str) -> list[dict[str, str]]: return [ { "role": "system", "content": ( "你是NMPA注册资料字段复核助手。请从附件文本中提取最合理的字段值," "只返回JSON,格式为 {\"fields\": {\"产品名称\": \"...\"}}。" "产品名称应包含完整名称、检测对象和方法学括号;不要把章节标题当产品名称。" ), }, { "role": "user", "content": json.dumps( { "file_context": file_context, "rule_fields": rule_fields, "text": text[:4000], "allowed_fields": FIELD_LABELS, }, ensure_ascii=False, ), }, ] def _workflow_messages(stage: str, payload: dict[str, Any]) -> list[dict[str, str]]: return [ { "role": "system", "content": ( "你是NMPA法规核查复核助手。请复核当前流程节点的规则结果," "指出可能误判、漏判和更合理的建议。只返回JSON。" ), }, { "role": "user", "content": json.dumps({"stage": stage, "payload": payload}, ensure_ascii=False)[:6000], }, ] def _parse_json_object(raw: str) -> dict[str, Any]: value = (raw or "").strip() if value.startswith("```"): value = re.sub(r"^```(?:json)?\s*", "", value) value = re.sub(r"\s*```$", "", value) start = value.find("{") end = value.rfind("}") if start >= 0 and end >= start: value = value[start : end + 1] parsed = json.loads(value) if not isinstance(parsed, dict): raise ValueError("LLM复核结果不是JSON对象。") return parsed def _call_completion_with_retries(completion_func: Callable[..., str], messages: list[dict[str, str]]) -> str: attempts = max(1, int(getattr(settings, "REGULATORY_LLM_REVIEW_MAX_ATTEMPTS", 3) or 3)) delay_seconds = float(getattr(settings, "REGULATORY_LLM_REVIEW_RETRY_DELAY_SECONDS", 0.5) or 0) timeout_seconds = float(getattr(settings, "REGULATORY_LLM_REVIEW_TIMEOUT_SECONDS", 15) or 15) accepts_timeout = _accepts_timeout(completion_func) last_error: Exception | None = None for attempt in range(1, attempts + 1): try: if accepts_timeout: return completion_func(messages, temperature=0.0, timeout=timeout_seconds) return completion_func(messages, temperature=0.0) except (LLMRequestError, OSError, TimeoutError) as exc: last_error = exc if attempt >= attempts: break if delay_seconds > 0: time.sleep(delay_seconds) if last_error: raise last_error raise LLMRequestError("LLM复核调用失败。") def _accepts_timeout(completion_func: Callable[..., str]) -> bool: try: signature = inspect.signature(completion_func) except (TypeError, ValueError): return True return "timeout" in signature.parameters def _should_call_llm(completion_func: Callable[..., str] | None) -> bool: if completion_func is not None: return True if os.environ.get("PYTEST_CURRENT_TEST") and not getattr(settings, "REGULATORY_LLM_REVIEW_ALLOW_TEST_CALLS", False): return False return bool(settings.LLM_API_KEY and settings.LLM_MODEL) def _clean_fields(fields: dict[str, Any]) -> dict[str, str]: clean = {} for label in FIELD_LABELS: value = fields.get(label) if not isinstance(value, str): continue normalized = " ".join(value.strip().split()).replace("(", "(").replace(")", ")") if normalized: clean[label] = normalized return clean def _select_fields(rule_fields: dict[str, str], llm_fields: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]: rule_clean = _clean_fields(rule_fields) selected = {} sources = {} for label in FIELD_LABELS: rule_value = rule_clean.get(label, "") llm_value = llm_fields.get(label, "") value, source = _select_field(label, rule_value, llm_value) if value: selected[label] = value sources[label] = source return selected, sources def _select_field(label: str, rule_value: str, llm_value: str) -> tuple[str, str]: if _invalid_field_value(llm_value): return rule_value, "rule" if rule_value else "" if not rule_value: return llm_value, "llm" if llm_value else "" if not llm_value: return rule_value, "rule" if label == "产品名称" and _better_product_name(llm_value, rule_value): return llm_value, "llm" if len(llm_value) > len(rule_value) * 1.35 and rule_value in llm_value: return llm_value, "llm" return rule_value, "rule" def _better_product_name(candidate: str, current: str) -> bool: if current and current in candidate and len(candidate) > len(current): return True product_keywords = ["试剂盒", "检测试剂", "荧光PCR法", "PCR法", "核酸检测"] return len(candidate) > len(current) and any(keyword in candidate for keyword in product_keywords) def _invalid_field_value(value: str) -> bool: if not value: return True if "�" in value: return True return any(keyword in value for keyword in ["第1章", "第2章", "第3章", "监管信息", "综述资料", "章节目录"])