feat(application-form-fill): 实现字段抽取与冲突合并

This commit is contained in:
2026-06-07 18:31:34 +08:00
parent 72890783b3
commit a48f778e09
5 changed files with 498 additions and 0 deletions

View File

@@ -0,0 +1,187 @@
from __future__ import annotations
import json
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
from django.conf import settings
from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
from review_agent.llm import generate_completion
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
from review_agent.regulatory_review.services.text_extract import extract_text
def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
    """Extract the text of every readable file in a summary batch.

    Items are visited in ``file_index`` order. A relative ``storage_path`` is
    resolved against ``settings.MEDIA_ROOT``. Items whose file does not exist,
    or whose extraction does not report ``"success"`` with non-empty text, are
    left out.

    Returns:
        Mapping of ``file_name`` to extracted text. A later item with the same
        ``file_name`` overwrites an earlier one.
    """
    collected: dict[str, str] = {}
    for entry in summary_batch.items.order_by("file_index"):
        location = Path(entry.storage_path)
        if not location.is_absolute():
            location = Path(settings.MEDIA_ROOT) / entry.storage_path
        if location.exists():
            outcome = extract_text(location)
            if outcome.status == "success" and outcome.text:
                collected[entry.file_name] = outcome.text
    return collected
def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields from document texts with label-based regexes.

    Every (document, field) pair is probed with ``_extract_label_value``; each
    non-empty hit becomes one ``ExtractedField`` record, stored as its
    ``__dict__``.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Template specs whose field definitions drive the extraction.

    Returns:
        ``{"fields": [...], "checklist_items": []}``; the rule extractor never
        produces checklist items.
    """
    definitions = _field_defs(specs)
    known_labels = [entry["label"] for entry in definitions if entry.get("label")]
    extracted: list[dict[str, Any]] = []
    for file_name, text in texts.items():
        role = detect_source_role(file_name, text)
        # The confidence depends only on the document role, not on the field.
        role_confidence = 0.75 if role == "说明书" else 0.65
        for entry in definitions:
            value, evidence = _extract_label_value(text, entry["label"], known_labels)
            if value:
                record = ExtractedField(
                    key=entry["key"],
                    label=entry["label"],
                    value=value,
                    source_file=file_name,
                    source_role=role,
                    evidence=evidence,
                    extractor="rule",
                    confidence=role_confidence,
                )
                extracted.append(record.__dict__)
    return {"fields": extracted, "checklist_items": []}
def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields by asking the LLM and validating its JSON reply.

    This is a best-effort extractor: any failure while loading the prompt,
    calling the model or parsing its reply is reported through
    ``error_message`` instead of being raised.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Template specs; only keys defined by them are accepted.

    Returns:
        ``{"fields": [...], "checklist_items": [...]}``, plus ``error_message``
        (with both lists empty) when the completion or parsing failed.
    """
    try:
        raw = generate_completion(
            [
                {"role": "system", "content": _prompt_text()},
                {"role": "user", "content": _build_llm_user_prompt(texts, specs)},
            ],
            temperature=0.0,
        )
        payload = _parse_json_object(raw)
    except Exception as exc:  # deliberate boundary: the LLM path must never crash the batch
        return {"fields": [], "checklist_items": [], "error_message": str(exc)}

    allowed_keys = {field["key"] for field in _field_defs(specs)}

    # The model controls the payload shape, so anything that is not a list is
    # treated as "no fields" rather than iterated blindly.
    raw_fields = payload.get("fields")
    if not isinstance(raw_fields, list):
        raw_fields = []

    fields: list[dict[str, Any]] = []
    for item in raw_fields:
        if not isinstance(item, dict):
            continue
        key = item.get("key")
        # The isinstance check also protects the set lookup from unhashable keys.
        if not isinstance(key, str) or key not in allowed_keys:
            continue
        # Normalise before testing emptiness so whitespace-only values are
        # dropped, matching the rule extractor which never emits empty values.
        value = str(item.get("value") or "").strip()
        if not value:
            continue
        source_file = str(item.get("source_file") or "")
        fields.append(
            {
                "key": key,
                "label": str(item.get("label") or key),
                "value": value,
                "source_file": source_file,
                "source_role": str(item.get("source_role") or detect_source_role(source_file, "")),
                "evidence": str(item.get("evidence") or "").strip(),
                "extractor": "llm",
                "confidence": _float_confidence(item.get("confidence"), default=0.7),
            }
        )

    # Keep the return shape aligned with extract_by_rules: always a list.
    checklist_items = payload.get("checklist_items")
    if not isinstance(checklist_items, list):
        checklist_items = []
    return {"fields": fields, "checklist_items": checklist_items}
def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Run the rule-based and LLM extractors on a two-worker thread pool.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Template specs selected for this run.

    Returns:
        A dict with ``regex_results``, ``llm_results``, ``selected_templates``
        (the spec codes) and ``source_evidence`` (file name plus character
        count for every input text).
    """
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {
            "regex_results": pool.submit(extract_by_rules, texts, specs),
            "llm_results": pool.submit(extract_by_llm, texts, specs),
        }
        # Resolved in insertion order: rule results first, then LLM results.
        combined: dict[str, Any] = {name: future.result() for name, future in pending.items()}
    combined["selected_templates"] = [spec.code for spec in specs]
    combined["source_evidence"] = [
        {"source_file": source_file, "char_count": len(content)}
        for source_file, content in texts.items()
    ]
    return combined
def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
    """Write the extraction payload as JSON and register it as a batch artifact.

    The payload is serialised as indented, non-ASCII-escaped UTF-8 JSON to
    ``field_extract_result.json`` inside the batch's ``exports`` subdirectory.

    Returns:
        The artifact returned by ``create_artifact_for_file`` for that file.
    """
    output_path = ensure_batch_subdir(batch, "exports") / "field_extract_result.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    output_path.write_text(serialized, encoding="utf-8")
    artifact_options = {
        "path": output_path,
        "artifact_type": ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
        "file_format": ApplicationFormFillArtifact.FileFormat.JSON,
        "name": "field_extract_result",
        "metadata": {"artifact": "field_extract_result"},
        "created_by_node": "field_extract",
    }
    return create_artifact_for_file(batch, **artifact_options)
def detect_source_role(file_name: str, text: str = "") -> str:
    """Classify a document by keywords in its name and first 200 characters.

    The keyword groups are checked in a fixed priority order and the first
    group with a hit decides the role.

    Args:
        file_name: Name of the source file.
        text: Document text; only the first 200 characters are inspected.

    Returns:
        The matched role name, or ``"其他注册资料"`` when no keyword is found.
    """
    haystack = f"{file_name}\n{text[:200]}"
    # Order matters: earlier entries win when several keywords are present.
    role_keywords = (
        ("说明书", ("说明书",)),
        ("产品技术要求", ("产品技术要求",)),
        ("注册检验报告", ("注册检验", "检测报告")),
        ("性能研究资料", ("性能研究",)),
        ("申请表", ("申请表",)),
    )
    for role, keywords in role_keywords:
        if any(keyword in haystack for keyword in keywords):
            return role
    return "其他注册资料"
def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
fields: list[dict[str, str]] = []
for spec in specs:
for field in spec.fields:
key = str(field.get("key") or "")
label = str(field.get("label") or "")
if key and label:
fields.append({"key": key, "label": label})
return fields
def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
escaped_labels = "|".join(re.escape(item) for item in labels if item != label)
stop_pattern = rf"(?=\n\s*(?:{escaped_labels})\s*[:])" if escaped_labels else r"(?=\Z)"
pattern = re.compile(rf"{re.escape(label)}\s*[:]\s*(.+?)(?:{stop_pattern}|\Z)", re.S)
match = pattern.search(text or "")
if not match:
return "", ""
raw = match.group(1).strip()
value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
evidence = f"{label}{value}"[:300]
return value, evidence
def _prompt_text() -> str:
    """Read the system prompt from ``prompts/field_extract.md``.

    The file is located relative to this module: the ``prompts`` directory
    sits next to this module's parent directory. It is read as UTF-8 on every
    call.
    """
    package_root = Path(__file__).resolve().parents[1]
    prompt_file = package_root.joinpath("prompts", "field_extract.md")
    with prompt_file.open(encoding="utf-8") as handle:
        return handle.read()
def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
    """Serialise the requested fields and document excerpts as the user prompt.

    Returns:
        A JSON string (non-ASCII characters kept as-is) with ``fields`` (key
        and label per field definition) and ``documents`` (file name plus the
        first 4000 characters of each text).
    """
    requested = [
        {"key": definition["key"], "label": definition["label"]}
        for definition in _field_defs(specs)
    ]
    excerpts = []
    for source_file, content in texts.items():
        excerpts.append({"source_file": source_file, "text": content[:4000]})
    request = {"fields": requested, "documents": excerpts}
    return json.dumps(request, ensure_ascii=False)
def _parse_json_object(raw: str) -> dict[str, Any]:
text = (raw or "").strip()
if text.startswith("```"):
text = text.strip("`").strip()
if text.lower().startswith("json"):
text = text[4:].strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end < start:
raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
return json.loads(text[start : end + 1])
def _float_confidence(value, *, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default