from __future__ import annotations import json import re from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any from django.conf import settings from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir from review_agent.llm import generate_completion from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch from review_agent.regulatory_review.services.text_extract import extract_text FIELD_ALIASES = { "product_name": ["产品名称"], "applicant_name": ["注册人名称", "申请人名称", "生产企业名称"], "applicant_address": ["注册人住所", "申请人住所", "生产企业住所"], "manufacturer_address": ["生产地址", "生产企业地址", "生产场所"], "agent_name": ["代理人名称", "生产企业名称", "注册人名称", "申请人名称"], "agent_address": ["代理人住所", "生产企业住所", "注册人住所", "申请人住所"], "package_specification": ["包装规格", "规格"], "main_components": ["主要组成成分", "主要组成", "组成成分"], "intended_use": ["预期用途"], "storage_condition_and_validity": ["产品储存条件及有效期", "储存条件及有效期", "储存条件", "有效期"], } STATIC_STOP_LABELS = [ "申请人", "国家药品监督管理局", "填表说明", "注", "保证书", "应附资料", "优先通道申请", "分类编码", "医疗器械唯一标识", "注册产品目前是否", "临床评价路径", "临床试验", "其他需要说明的问题", "国家药监局器审中心医疗器械", ] def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]: texts: dict[str, str] = {} for item in summary_batch.items.order_by("file_index"): path = Path(item.storage_path) if not path.is_absolute(): path = Path(settings.MEDIA_ROOT) / item.storage_path if not path.exists(): continue result = extract_text(path) if result.status == "success" and result.text: texts[item.file_name] = result.text return texts def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: fields: list[dict[str, Any]] = [] field_defs = _field_defs(specs) labels = _all_field_labels(field_defs) for file_name, text in texts.items(): source_role = detect_source_role(file_name, text) for field in field_defs: value, evidence = _extract_field_value(text, field, labels) if not value: continue fields.append( ExtractedField( key=field["key"], label=field["label"], value=value, source_file=file_name, source_role=source_role, evidence=evidence, extractor="rule", confidence=0.75 if source_role == "说明书" else 0.65, ).__dict__ ) return {"fields": fields, "checklist_items": []} def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: try: raw = generate_completion( [ {"role": "system", "content": _prompt_text()}, {"role": "user", "content": _build_llm_user_prompt(texts, specs)}, ], temperature=0.0, ) payload = _parse_json_object(raw) except Exception as exc: return {"fields": [], "checklist_items": [], "error_message": str(exc)} fields = [] allowed_keys = {field["key"] for field in _field_defs(specs)} for item in payload.get("fields") or []: if not isinstance(item, dict) or item.get("key") not in allowed_keys or not item.get("value"): continue fields.append( { "key": str(item.get("key") or ""), "label": str(item.get("label") or item.get("key") or ""), "value": str(item.get("value") or "").strip(), "source_file": str(item.get("source_file") or ""), "source_role": str(item.get("source_role") or detect_source_role(str(item.get("source_file") or ""), "")), "evidence": str(item.get("evidence") or "").strip(), "extractor": "llm", "confidence": _float_confidence(item.get("confidence"), default=0.7), } ) return {"fields": fields, "checklist_items": payload.get("checklist_items") or []} def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: with ThreadPoolExecutor(max_workers=2) as executor: rule_future = executor.submit(extract_by_rules, texts, specs) llm_future = executor.submit(extract_by_llm, texts, specs) regex_results = rule_future.result() llm_results = llm_future.result() return { "regex_results": regex_results, "llm_results": llm_results, "selected_templates": [spec.code for spec in specs], "source_evidence": [{"source_file": name, "char_count": len(text)} for name, text in texts.items()], } def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact: target_dir = ensure_batch_subdir(batch, "exports") path = target_dir / "field_extract_result.json" path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") return create_artifact_for_file( batch, path=path, artifact_type=ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT, file_format=ApplicationFormFillArtifact.FileFormat.JSON, name="field_extract_result", metadata={"artifact": "field_extract_result"}, created_by_node="field_extract", ) def detect_source_role(file_name: str, text: str = "") -> str: target = f"{file_name}\n{text[:200]}" if "说明书" in target: return "说明书" if "产品技术要求" in target: return "产品技术要求" if "注册检验" in target or "检测报告" in target: return "注册检验报告" if "性能研究" in target: return "性能研究资料" if "申请表" in target: return "申请表" return "其他注册资料" def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]: fields: list[dict[str, str]] = [] for spec in specs: for field in spec.fields: key = str(field.get("key") or "") label = str(field.get("label") or "") if key and label: fields.append({"key": key, "label": label}) return fields def _extract_field_value(text: str, field: dict[str, str], labels: list[str]) -> tuple[str, str]: aliases = _field_aliases(field) for label in aliases: value, evidence = _extract_colon_label_value(text, label, labels + aliases) if value: return value, evidence value, evidence = _extract_bracket_section_value(text, label) if value: return value, evidence return "", "" def _field_aliases(field: dict[str, str]) -> list[str]: aliases = [field["label"]] aliases.extend(FIELD_ALIASES.get(field["key"], [])) result: list[str] = [] for alias in aliases: normalized = str(alias or "").strip() if normalized and normalized not in result: result.append(normalized) return result def _all_field_labels(fields: list[dict[str, str]]) -> list[str]: labels: list[str] = list(STATIC_STOP_LABELS) for field in fields: for label in _field_aliases(field): if label not in labels: labels.append(label) return labels def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]: return _extract_colon_label_value(text, label, labels) def _extract_colon_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]: escaped_labels = "|".join(re.escape(item) for item in labels if item != label) stop_pattern = rf"(?=\n\s*(?:{escaped_labels})(?:\s*[::]|\s*$))" if escaped_labels else r"(?=\Z)" pattern = re.compile(rf"{re.escape(label)}\s*[::]\s*(.+?)(?:{stop_pattern}|\Z)", re.S) match = pattern.search(text or "") if not match: return "", "" raw = match.group(1).strip() value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip() value = "\n".join(line.strip() for line in value.splitlines() if line.strip()) evidence = f"{label}:{value}"[:300] return value, evidence def _extract_bracket_section_value(text: str, label: str) -> tuple[str, str]: heading_pattern = rf"^\s*[【\[]\s*{re.escape(label)}\s*[】\]]\s*$" lines = (text or "").splitlines() for index, line in enumerate(lines): if not re.match(heading_pattern, line.strip()): continue value_parts: list[str] = [] for next_line in lines[index + 1 :]: normalized = next_line.strip() if not normalized: continue if _looks_like_bracket_heading(normalized): break value_parts.append(normalized) value = "\n".join(value_parts).strip() if value: return value, f"【{label}】\n{value}"[:300] return "", "" def _looks_like_bracket_heading(line: str) -> bool: return bool(re.match(r"^\s*[【\[].{1,40}[】\]]\s*$", line)) def _prompt_text() -> str: path = Path(__file__).resolve().parents[1] / "prompts" / "field_extract.md" return path.read_text(encoding="utf-8") def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str: fields = [{"key": field["key"], "label": field["label"]} for field in _field_defs(specs)] documents = [{"source_file": name, "text": text[:4000]} for name, text in texts.items()] return json.dumps({"fields": fields, "documents": documents}, ensure_ascii=False) def _parse_json_object(raw: str) -> dict[str, Any]: text = (raw or "").strip() if text.startswith("```"): text = text.strip("`").strip() if text.lower().startswith("json"): text = text[4:].strip() start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end < start: raise json.JSONDecodeError("未找到 JSON 对象", text, 0) return json.loads(text[start : end + 1]) def _float_confidence(value, *, default: float) -> float: try: return float(value) except (TypeError, ValueError): return default