From a48f778e09019fec19ba569d7af0d2c3810de48d Mon Sep 17 00:00:00 2001 From: bruce Date: Sun, 7 Jun 2026 18:31:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(application-form-fill):=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E5=AD=97=E6=AE=B5=E6=8A=BD=E5=8F=96=E4=B8=8E=E5=86=B2?= =?UTF-8?q?=E7=AA=81=E5=90=88=E5=B9=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../prompts/field_extract.md | 23 +++ .../services/field_extract.py | 187 ++++++++++++++++++ .../services/field_merge.py | 88 +++++++++ ...est_application_form_fill_field_extract.py | 121 ++++++++++++ .../test_application_form_fill_field_merge.py | 79 ++++++++ 5 files changed, 498 insertions(+) create mode 100644 review_agent/application_form_fill/prompts/field_extract.md create mode 100644 review_agent/application_form_fill/services/field_extract.py create mode 100644 review_agent/application_form_fill/services/field_merge.py create mode 100644 tests/test_application_form_fill_field_extract.py create mode 100644 tests/test_application_form_fill_field_merge.py diff --git a/review_agent/application_form_fill/prompts/field_extract.md b/review_agent/application_form_fill/prompts/field_extract.md new file mode 100644 index 0000000..6ff1461 --- /dev/null +++ b/review_agent/application_form_fill/prompts/field_extract.md @@ -0,0 +1,23 @@ +你是医疗器械体外诊断试剂申报资料字段抽取助手。 + +请只输出 JSON 对象,不要输出 Markdown。结构如下: + +{ + "fields": [ + { + "key": "product_name", + "label": "产品名称", + "value": "字段值", + "source_file": "来源文件名", + "source_role": "说明书", + "evidence": "原文证据", + "confidence": 0.8 + } + ], + "checklist_items": [] +} + +要求: +- 只抽取输入模板字段中出现的信息。 +- 字段值必须来自资料原文,不要编造。 +- 找不到时不要输出该字段。 diff --git a/review_agent/application_form_fill/services/field_extract.py b/review_agent/application_form_fill/services/field_extract.py new file mode 100644 index 0000000..4c72f10 --- /dev/null +++ b/review_agent/application_form_fill/services/field_extract.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import json +import re +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +from django.conf import settings + +from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec +from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir +from review_agent.llm import generate_completion +from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch +from review_agent.regulatory_review.services.text_extract import extract_text + + +def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]: + texts: dict[str, str] = {} + for item in summary_batch.items.order_by("file_index"): + path = Path(item.storage_path) + if not path.is_absolute(): + path = Path(settings.MEDIA_ROOT) / item.storage_path + if not path.exists(): + continue + result = extract_text(path) + if result.status == "success" and result.text: + texts[item.file_name] = result.text + return texts + + +def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: + fields: list[dict[str, Any]] = [] + field_defs = _field_defs(specs) + labels = [field["label"] for field in field_defs if field.get("label")] + for file_name, text in texts.items(): + source_role = detect_source_role(file_name, text) + for field in field_defs: + value, evidence = _extract_label_value(text, field["label"], labels) + if not value: + continue + fields.append( + ExtractedField( + key=field["key"], + label=field["label"], + value=value, + source_file=file_name, + source_role=source_role, + evidence=evidence, + extractor="rule", + confidence=0.75 if source_role == "说明书" else 0.65, + ).__dict__ + ) + return {"fields": fields, "checklist_items": []} + + +def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: + try: + raw = generate_completion( + [ + {"role": "system", "content": _prompt_text()}, + {"role": "user", "content": _build_llm_user_prompt(texts, specs)}, + ], + temperature=0.0, + ) + payload = _parse_json_object(raw) + except Exception as exc: + return {"fields": [], "checklist_items": [], "error_message": str(exc)} + + fields = [] + allowed_keys = {field["key"] for field in _field_defs(specs)} + for item in payload.get("fields") or []: + if not isinstance(item, dict) or item.get("key") not in allowed_keys or not item.get("value"): + continue + fields.append( + { + "key": str(item.get("key") or ""), + "label": str(item.get("label") or item.get("key") or ""), + "value": str(item.get("value") or "").strip(), + "source_file": str(item.get("source_file") or ""), + "source_role": str(item.get("source_role") or detect_source_role(str(item.get("source_file") or ""), "")), + "evidence": str(item.get("evidence") or "").strip(), + "extractor": "llm", + "confidence": _float_confidence(item.get("confidence"), default=0.7), + } + ) + return {"fields": fields, "checklist_items": payload.get("checklist_items") or []} + + +def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]: + with ThreadPoolExecutor(max_workers=2) as executor: + rule_future = executor.submit(extract_by_rules, texts, specs) + llm_future = executor.submit(extract_by_llm, texts, specs) + regex_results = rule_future.result() + llm_results = llm_future.result() + return { + "regex_results": regex_results, + "llm_results": llm_results, + "selected_templates": [spec.code for spec in specs], + "source_evidence": [{"source_file": name, "char_count": len(text)} for name, text in texts.items()], + } + + +def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact: + target_dir = ensure_batch_subdir(batch, "exports") + path = target_dir / "field_extract_result.json" + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + return create_artifact_for_file( + batch, + path=path, + artifact_type=ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT, + file_format=ApplicationFormFillArtifact.FileFormat.JSON, + name="field_extract_result", + metadata={"artifact": "field_extract_result"}, + created_by_node="field_extract", + ) + + +def detect_source_role(file_name: str, text: str = "") -> str: + target = f"{file_name}\n{text[:200]}" + if "说明书" in target: + return "说明书" + if "产品技术要求" in target: + return "产品技术要求" + if "注册检验" in target or "检测报告" in target: + return "注册检验报告" + if "性能研究" in target: + return "性能研究资料" + if "申请表" in target: + return "申请表" + return "其他注册资料" + + +def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]: + fields: list[dict[str, str]] = [] + for spec in specs: + for field in spec.fields: + key = str(field.get("key") or "") + label = str(field.get("label") or "") + if key and label: + fields.append({"key": key, "label": label}) + return fields + + +def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]: + escaped_labels = "|".join(re.escape(item) for item in labels if item != label) + stop_pattern = rf"(?=\n\s*(?:{escaped_labels})\s*[::])" if escaped_labels else r"(?=\Z)" + pattern = re.compile(rf"{re.escape(label)}\s*[::]\s*(.+?)(?:{stop_pattern}|\Z)", re.S) + match = pattern.search(text or "") + if not match: + return "", "" + raw = match.group(1).strip() + value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip() + value = "\n".join(line.strip() for line in value.splitlines() if line.strip()) + evidence = f"{label}:{value}"[:300] + return value, evidence + + +def _prompt_text() -> str: + path = Path(__file__).resolve().parents[1] / "prompts" / "field_extract.md" + return path.read_text(encoding="utf-8") + + +def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str: + fields = [{"key": field["key"], "label": field["label"]} for field in _field_defs(specs)] + documents = [{"source_file": name, "text": text[:4000]} for name, text in texts.items()] + return json.dumps({"fields": fields, "documents": documents}, ensure_ascii=False) + + +def _parse_json_object(raw: str) -> dict[str, Any]: + text = (raw or "").strip() + if text.startswith("```"): + text = text.strip("`").strip() + if text.lower().startswith("json"): + text = text[4:].strip() + start = text.find("{") + end = text.rfind("}") + if start == -1 or end == -1 or end < start: + raise json.JSONDecodeError("未找到 JSON 对象", text, 0) + return json.loads(text[start : end + 1]) + + +def _float_confidence(value, *, default: float) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default diff --git a/review_agent/application_form_fill/services/field_merge.py b/review_agent/application_form_fill/services/field_merge.py new file mode 100644 index 0000000..b6c858a --- /dev/null +++ b/review_agent/application_form_fill/services/field_merge.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import re +from typing import Any + +from review_agent.application_form_fill.schemas import MergedField + + +SOURCE_PRIORITY = { + "说明书": 1, + "产品技术要求": 2, + "注册检验报告": 3, + "检测报告": 3, + "性能研究资料": 4, + "其他注册资料": 5, +} + + +def normalize_field_value(value: str) -> str: + return re.sub(r"\s+", "", str(value or "")).strip().lower() + + +def rank_source(source_role: str, source_file: str = "") -> int: + target = f"{source_role}\n{source_file}" + for keyword, rank in SOURCE_PRIORITY.items(): + if keyword in target: + return rank + return 9 + + +def merge_fields(regex_results: dict[str, Any], llm_results: dict[str, Any]) -> tuple[dict[str, MergedField], list[dict]]: + grouped: dict[str, list[dict[str, Any]]] = {} + for item in list(regex_results.get("fields") or []) + list(llm_results.get("fields") or []): + key = str(item.get("key") or "") + value = str(item.get("value") or "").strip() + if not key or not value: + continue + grouped.setdefault(key, []).append(item) + + merged: dict[str, MergedField] = {} + conflicts: list[dict] = [] + for key, candidates in grouped.items(): + selected = sorted( + candidates, + key=lambda item: ( + rank_source(str(item.get("source_role") or ""), str(item.get("source_file") or "")), + -float(item.get("confidence") or 0), + ), + )[0] + distinct = _distinct_values(candidates) + has_conflict = len(distinct) > 1 + conflict_values = [ + { + "value": item.get("value"), + "source_file": item.get("source_file", ""), + "source_role": item.get("source_role", ""), + "evidence": item.get("evidence", ""), + } + for item in candidates + if normalize_field_value(str(item.get("value") or "")) != normalize_field_value(str(selected.get("value") or "")) + ] + merged_field = MergedField( + key=key, + label=str(selected.get("label") or key), + value=str(selected.get("value") or ""), + source_file=str(selected.get("source_file") or ""), + evidence=str(selected.get("evidence") or ""), + confidence=float(selected.get("confidence") or 0), + has_conflict=has_conflict, + conflict_values=conflict_values, + ) + merged[key] = merged_field + if has_conflict: + conflicts.append( + { + "field_key": key, + "field_label": merged_field.label, + "selected_value": merged_field.value, + "selected_source": merged_field.source_file, + "conflict_values": conflict_values, + "handling": "说明书优先,模板内黄底红字高亮" if rank_source(merged_field.source_file, merged_field.source_file) == 1 else "按来源优先级采用最高优先级字段", + } + ) + return merged, conflicts + + +def _distinct_values(candidates: list[dict[str, Any]]) -> set[str]: + return {normalize_field_value(str(item.get("value") or "")) for item in candidates if item.get("value")} diff --git a/tests/test_application_form_fill_field_extract.py b/tests/test_application_form_fill_field_extract.py new file mode 100644 index 0000000..08c7b44 --- /dev/null +++ b/tests/test_application_form_fill_field_extract.py @@ -0,0 +1,121 @@ +import json + +import pytest + +from review_agent.application_form_fill.services.field_extract import ( + extract_by_llm, + extract_by_rules, + run_parallel_extract, + save_field_extract_result, +) +from review_agent.application_form_fill.services.template_config import load_template_config +from review_agent.application_form_fill.services.template_select import select_templates +from review_agent.models import ( + ApplicationFormFillArtifact, + ApplicationFormFillBatch, + Conversation, + FileSummaryBatch, +) + + +pytestmark = pytest.mark.django_db + + +def _registration_specs(): + config = load_template_config() + specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册") + return specs + + +def test_rule_extracts_registration_certificate_fields(): + texts = { + "产品说明书.txt": "\n".join( + [ + "产品名称:甲胎蛋白检测试剂盒", + "包装规格:20人份/盒", + "预期用途:用于体外定量检测人血清中甲胎蛋白含量", + "产品储存条件及有效期:2-8℃保存,有效期12个月", + ] + ) + } + + result = extract_by_rules(texts, _registration_specs()) + + values = {field["key"]: field for field in result["fields"]} + assert values["product_name"]["value"] == "甲胎蛋白检测试剂盒" + assert values["intended_use"]["source_role"] == "说明书" + assert "2-8℃保存" in values["storage_condition_and_validity"]["value"] + assert values["package_specification"]["extractor"] == "rule" + + +def test_llm_extract_parses_structured_json(monkeypatch): + monkeypatch.setattr( + "review_agent.application_form_fill.services.field_extract.generate_completion", + lambda messages, temperature=0.0: json.dumps( + { + "fields": [ + { + "key": "product_name", + "label": "产品名称", + "value": "甲胎蛋白检测试剂盒", + "source_file": "说明书.txt", + "source_role": "说明书", + "evidence": "产品名称:甲胎蛋白检测试剂盒", + "confidence": 0.9, + } + ], + "checklist_items": [], + }, + ensure_ascii=False, + ), + ) + + result = extract_by_llm({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs()) + + assert result["fields"][0]["extractor"] == "llm" + assert result["fields"][0]["value"] == "甲胎蛋白检测试剂盒" + + +def test_llm_extract_failure_returns_empty_result(monkeypatch): + monkeypatch.setattr( + "review_agent.application_form_fill.services.field_extract.generate_completion", + lambda messages, temperature=0.0: (_ for _ in ()).throw(TimeoutError("timeout")), + ) + + result = extract_by_llm({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs()) + + assert result["fields"] == [] + assert "timeout" in result["error_message"] + + +def test_parallel_extract_preserves_rule_result_when_llm_fails(monkeypatch): + monkeypatch.setattr( + "review_agent.application_form_fill.services.field_extract.generate_completion", + lambda messages, temperature=0.0: (_ for _ in ()).throw(TimeoutError("timeout")), + ) + + payload = run_parallel_extract({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs()) + + assert payload["regex_results"]["fields"] + assert payload["llm_results"]["fields"] == [] + assert payload["selected_templates"] == ["registration_certificate"] + + +def test_save_field_extract_result_creates_json_artifact(settings, tmp_path, django_user_model): + settings.MEDIA_ROOT = tmp_path + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-FIELD") + batch = ApplicationFormFillBatch.objects.create( + conversation=conversation, + user=user, + source_summary_batch=summary, + batch_no="AFF-FIELD", + work_dir=str(tmp_path / "aff" / "AFF-FIELD"), + ) + + artifact = save_field_extract_result(batch, {"regex_results": {"fields": []}, "llm_results": {"fields": []}}) + + assert artifact.artifact_type == ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT + assert artifact.file_format == ApplicationFormFillArtifact.FileFormat.JSON + assert artifact.content_hash diff --git a/tests/test_application_form_fill_field_merge.py b/tests/test_application_form_fill_field_merge.py new file mode 100644 index 0000000..a449ad6 --- /dev/null +++ b/tests/test_application_form_fill_field_merge.py @@ -0,0 +1,79 @@ +import pytest + +from review_agent.application_form_fill.services.field_merge import merge_fields, normalize_field_value, rank_source + + +def test_normalize_field_value_removes_whitespace(): + assert normalize_field_value(" 2-8℃ 保存 \n 有效期12个月 ") == "2-8℃保存有效期12个月" + + +def test_rank_source_prefers_instructions(): + assert rank_source("说明书") < rank_source("产品技术要求") + + +def test_merge_fields_prefers_instructions_and_marks_conflict(): + regex_results = { + "fields": [ + { + "key": "storage_condition_and_validity", + "label": "产品储存条件及有效期", + "value": "2-8℃保存,有效期12个月", + "source_file": "说明书.txt", + "source_role": "说明书", + "evidence": "产品储存条件及有效期:2-8℃保存,有效期12个月", + "confidence": 0.75, + }, + { + "key": "storage_condition_and_validity", + "label": "产品储存条件及有效期", + "value": "-20℃保存", + "source_file": "产品技术要求.txt", + "source_role": "产品技术要求", + "evidence": "产品储存条件及有效期:-20℃保存", + "confidence": 0.8, + }, + ] + } + + merged, conflicts = merge_fields(regex_results, {"fields": []}) + + field = merged["storage_condition_and_validity"] + assert field.value == "2-8℃保存,有效期12个月" + assert field.has_conflict is True + assert conflicts[0]["selected_value"] == "2-8℃保存,有效期12个月" + assert conflicts[0]["conflict_values"][0]["value"] == "-20℃保存" + + +def test_merge_fields_combines_consistent_values_without_conflict(): + regex_results = { + "fields": [ + { + "key": "product_name", + "label": "产品名称", + "value": "甲胎蛋白检测试剂盒", + "source_file": "说明书.txt", + "source_role": "说明书", + "evidence": "产品名称:甲胎蛋白检测试剂盒", + "confidence": 0.75, + } + ] + } + llm_results = { + "fields": [ + { + "key": "product_name", + "label": "产品名称", + "value": "甲胎蛋白 检测试剂盒", + "source_file": "产品技术要求.txt", + "source_role": "产品技术要求", + "evidence": "产品名称:甲胎蛋白 检测试剂盒", + "confidence": 0.9, + } + ] + } + + merged, conflicts = merge_fields(regex_results, llm_results) + + assert merged["product_name"].value == "甲胎蛋白检测试剂盒" + assert merged["product_name"].has_conflict is False + assert conflicts == []