"""Field extraction for application-form filling.

Runs a rule-based (regex) extractor and an LLM extractor side by side over
the text of the documents in a file-summary batch, and persists the combined
result as a JSON artifact.
"""

from __future__ import annotations

import json
import math
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any

from django.conf import settings

from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
from review_agent.llm import generate_completion
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
from review_agent.regulatory_review.services.text_extract import extract_text

# Separator accepted between a label and its value: ASCII colon or the
# full-width colon (U+FF1A). Written as a regex escape so the pattern does not
# depend on how this source file is encoded or normalised.
_LABEL_SEPARATOR = r"[:\uFF1A]"

# Number of leading characters of each document sent to the LLM.
_LLM_DOCUMENT_CHAR_LIMIT = 4000


def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
    """Extract text from every readable file of *summary_batch*.

    Items are visited in ``file_index`` order. Relative storage paths are
    resolved against ``settings.MEDIA_ROOT``. Files that do not exist, or whose
    extraction does not report ``"success"`` with non-empty text, are skipped.

    Returns:
        Mapping of ``file_name`` to extracted text. When two items share a
        file name the later one overwrites the earlier one.
    """
    texts: dict[str, str] = {}
    for item in summary_batch.items.order_by("file_index"):
        path = Path(item.storage_path)
        if not path.is_absolute():
            path = Path(settings.MEDIA_ROOT) / item.storage_path
        if not path.exists():
            continue
        result = extract_text(path)
        if result.status == "success" and result.text:
            texts[item.file_name] = result.text
    return texts


def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields from *texts* with ``label: value`` regex rules.

    Every field defined by *specs* is looked up in every document; each hit
    becomes one entry, so the same key may appear once per source file.

    Returns:
        ``{"fields": [...], "checklist_items": []}`` where each field is the
        ``__dict__`` of an ``ExtractedField`` with ``extractor="rule"``.
    """
    fields: list[dict[str, Any]] = []
    field_defs = _field_defs(specs)
    labels = [field["label"] for field in field_defs if field.get("label")]
    for file_name, text in texts.items():
        source_role = detect_source_role(file_name, text)
        for field in field_defs:
            value, evidence = _extract_label_value(text, field["label"], labels)
            if not value:
                continue
            fields.append(
                ExtractedField(
                    key=field["key"],
                    label=field["label"],
                    value=value,
                    source_file=file_name,
                    source_role=source_role,
                    evidence=evidence,
                    extractor="rule",
                    # Values found in the instructions-for-use document are
                    # scored slightly higher than those from other sources.
                    confidence=0.75 if source_role == "说明书" else 0.65,
                ).__dict__
            )
    return {"fields": fields, "checklist_items": []}


def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields from *texts* by prompting the LLM.

    The model is called with ``temperature=0.0`` and is expected to answer
    with a JSON object holding ``fields`` and optionally ``checklist_items``.
    Any failure while calling the model or parsing its answer is reported via
    ``error_message`` instead of being raised.

    Entries of ``fields`` are kept only when they are objects whose ``key`` is
    one of the keys defined by *specs* and whose ``value`` is non-blank.

    Returns:
        ``{"fields": [...], "checklist_items": [...]}``, plus
        ``"error_message"`` when the call or the parsing failed.
    """
    try:
        raw = generate_completion(
            [
                {"role": "system", "content": _prompt_text()},
                {"role": "user", "content": _build_llm_user_prompt(texts, specs)},
            ],
            temperature=0.0,
        )
        payload = _parse_json_object(raw)
    except Exception as exc:
        return {"fields": [], "checklist_items": [], "error_message": str(exc)}

    # The model output is untrusted: "fields" may be missing or not a list.
    raw_fields = payload.get("fields")
    if not isinstance(raw_fields, list):
        raw_fields = []

    allowed_keys = {field["key"] for field in _field_defs(specs)}
    fields: list[dict[str, Any]] = []
    for item in raw_fields:
        if not isinstance(item, dict):
            continue
        key = item.get("key")
        # Non-string keys can never be allowed; checking the type first also
        # avoids a TypeError on unhashable values such as lists.
        if not isinstance(key, str) or key not in allowed_keys:
            continue
        raw_value = item.get("value")
        if not raw_value:
            continue
        value = str(raw_value).strip()
        if not value:
            # Whitespace-only answers carry no information.
            continue
        source_file = str(item.get("source_file") or "")
        fields.append(
            {
                "key": key,
                "label": str(item.get("label") or key),
                "value": value,
                "source_file": source_file,
                "source_role": str(item.get("source_role") or detect_source_role(source_file, "")),
                "evidence": str(item.get("evidence") or "").strip(),
                "extractor": "llm",
                "confidence": _float_confidence(item.get("confidence"), default=0.7),
            }
        )
    return {"fields": fields, "checklist_items": payload.get("checklist_items") or []}


def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Run the rule-based and LLM extractors concurrently and bundle the results.

    Returns:
        Dict with ``regex_results``, ``llm_results``, ``selected_templates``
        (the ``code`` of every spec) and ``source_evidence`` (file name and
        character count of every document).
    """
    with ThreadPoolExecutor(max_workers=2) as executor:
        rule_future = executor.submit(extract_by_rules, texts, specs)
        llm_future = executor.submit(extract_by_llm, texts, specs)
        regex_results = rule_future.result()
        llm_results = llm_future.result()
    return {
        "regex_results": regex_results,
        "llm_results": llm_results,
        "selected_templates": [spec.code for spec in specs],
        "source_evidence": [{"source_file": name, "char_count": len(text)} for name, text in texts.items()],
    }


def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
    """Write *payload* to ``exports/field_extract_result.json`` and register it.

    The JSON is written as UTF-8 with non-ASCII characters kept verbatim and a
    two-space indent, then recorded as a ``FIELD_EXTRACT_RESULT`` artifact of
    *batch*.

    Returns:
        The artifact returned by ``create_artifact_for_file``.
    """
    target_dir = ensure_batch_subdir(batch, "exports")
    path = target_dir / "field_extract_result.json"
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    return create_artifact_for_file(
        batch,
        path=path,
        artifact_type=ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
        file_format=ApplicationFormFillArtifact.FileFormat.JSON,
        name="field_extract_result",
        metadata={"artifact": "field_extract_result"},
        created_by_node="field_extract",
    )


def detect_source_role(file_name: str, text: str = "") -> str:
    """Classify a document by keywords in its name and first 200 characters.

    Keywords are tested in a fixed order and the first hit wins; documents
    matching none of them are classified as ``"其他注册资料"``.
    """
    target = f"{file_name}\n{text[:200]}"
    if "说明书" in target:
        return "说明书"
    if "产品技术要求" in target:
        return "产品技术要求"
    if "注册检验" in target or "检测报告" in target:
        return "注册检验报告"
    if "性能研究" in target:
        return "性能研究资料"
    if "申请表" in target:
        return "申请表"
    return "其他注册资料"


def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
    """Flatten *specs* into ``{"key", "label"}`` dicts.

    Fields lacking a key or a label are dropped. Fields are not de-duplicated
    across specs.
    """
    fields: list[dict[str, str]] = []
    for spec in specs:
        for field in spec.fields:
            key = str(field.get("key") or "")
            label = str(field.get("label") or "")
            if key and label:
                fields.append({"key": key, "label": label})
    return fields


def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
    """Find the value following ``label:`` in *text*.

    The value runs from the separator up to the next line that starts with one
    of the other *labels* followed by a separator, or to the end of the text.
    It is then cut at the first blank line and its lines are stripped.

    Args:
        text: Document text to search; ``None`` or empty yields no match.
        label: Label whose value is wanted.
        labels: All known labels; every one except *label* ends the value.

    Returns:
        ``(value, evidence)`` where evidence is ``"label:value"`` truncated to
        300 characters, or ``("", "")`` when the label is absent or its value
        is blank.
    """
    escaped_labels = "|".join(re.escape(item) for item in labels if item != label)
    if escaped_labels:
        stop_pattern = rf"(?=\n\s*(?:{escaped_labels})\s*{_LABEL_SEPARATOR})"
    else:
        stop_pattern = r"(?=\Z)"
    # Only same-line whitespace is consumed after the separator and the
    # capture may be empty, so the stop lookahead can fire immediately. This
    # keeps a blank field ("A:\nB: foo") from swallowing the next label's line.
    pattern = re.compile(
        rf"{re.escape(label)}\s*{_LABEL_SEPARATOR}[^\S\n]*(.*?)(?:{stop_pattern}|\Z)",
        re.S,
    )
    match = pattern.search(text or "")
    if not match:
        return "", ""
    raw = match.group(1).strip()
    # Everything after the first blank line belongs to a following paragraph.
    value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
    value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
    if not value:
        return "", ""
    evidence = f"{label}:{value}"[:300]
    return value, evidence


def _prompt_text() -> str:
    """Return the system prompt stored in ``prompts/field_extract.md``.

    The file is located relative to this module, one package level up.
    """
    path = Path(__file__).resolve().parents[1] / "prompts" / "field_extract.md"
    return path.read_text(encoding="utf-8")


def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
    """Serialise the wanted fields and the truncated documents as a JSON string.

    Each document is cut to its first ``_LLM_DOCUMENT_CHAR_LIMIT`` characters.
    """
    documents = [{"source_file": name, "text": text[:_LLM_DOCUMENT_CHAR_LIMIT]} for name, text in texts.items()]
    return json.dumps({"fields": _field_defs(specs), "documents": documents}, ensure_ascii=False)


def _parse_json_object(raw: str) -> dict[str, Any]:
    """Parse the JSON object embedded in an LLM answer.

    Surrounding Markdown code fences and a leading ``json`` language tag are
    tolerated; the span from the first ``{`` to the last ``}`` is decoded.

    Raises:
        json.JSONDecodeError: If no ``{...}`` span exists or it is not valid
            JSON.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[4:].strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
    return json.loads(text[start : end + 1])


def _float_confidence(value, *, default: float) -> float:
    """Convert *value* to a finite float, falling back to *default*.

    The fallback is used when *value* cannot be converted, overflows a float,
    or is NaN/infinite. Non-finite numbers would otherwise be written by
    ``json.dumps`` as ``NaN``/``Infinity``, which is not valid JSON.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default