feat(application-form-fill): 实现字段抽取与冲突合并
This commit is contained in:
23
review_agent/application_form_fill/prompts/field_extract.md
Normal file
23
review_agent/application_form_fill/prompts/field_extract.md
Normal file
@@ -0,0 +1,23 @@
|
||||
你是医疗器械体外诊断试剂申报资料字段抽取助手。
|
||||
|
||||
请只输出 JSON 对象,不要输出 Markdown。结构如下:
|
||||
|
||||
{
|
||||
"fields": [
|
||||
{
|
||||
"key": "product_name",
|
||||
"label": "产品名称",
|
||||
"value": "字段值",
|
||||
"source_file": "来源文件名",
|
||||
"source_role": "说明书",
|
||||
"evidence": "原文证据",
|
||||
"confidence": 0.8
|
||||
}
|
||||
],
|
||||
"checklist_items": []
|
||||
}
|
||||
|
||||
要求:
|
||||
- 只抽取输入模板字段中出现的信息。
|
||||
- 字段值必须来自资料原文,不要编造。
|
||||
- 找不到时不要输出该字段。
|
||||
187
review_agent/application_form_fill/services/field_extract.py
Normal file
187
review_agent/application_form_fill/services/field_extract.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
|
||||
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
|
||||
from review_agent.llm import generate_completion
|
||||
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
|
||||
from review_agent.regulatory_review.services.text_extract import extract_text
|
||||
|
||||
|
||||
def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
|
||||
texts: dict[str, str] = {}
|
||||
for item in summary_batch.items.order_by("file_index"):
|
||||
path = Path(item.storage_path)
|
||||
if not path.is_absolute():
|
||||
path = Path(settings.MEDIA_ROOT) / item.storage_path
|
||||
if not path.exists():
|
||||
continue
|
||||
result = extract_text(path)
|
||||
if result.status == "success" and result.text:
|
||||
texts[item.file_name] = result.text
|
||||
return texts
|
||||
|
||||
|
||||
def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
|
||||
fields: list[dict[str, Any]] = []
|
||||
field_defs = _field_defs(specs)
|
||||
labels = [field["label"] for field in field_defs if field.get("label")]
|
||||
for file_name, text in texts.items():
|
||||
source_role = detect_source_role(file_name, text)
|
||||
for field in field_defs:
|
||||
value, evidence = _extract_label_value(text, field["label"], labels)
|
||||
if not value:
|
||||
continue
|
||||
fields.append(
|
||||
ExtractedField(
|
||||
key=field["key"],
|
||||
label=field["label"],
|
||||
value=value,
|
||||
source_file=file_name,
|
||||
source_role=source_role,
|
||||
evidence=evidence,
|
||||
extractor="rule",
|
||||
confidence=0.75 if source_role == "说明书" else 0.65,
|
||||
).__dict__
|
||||
)
|
||||
return {"fields": fields, "checklist_items": []}
|
||||
|
||||
|
||||
def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
|
||||
try:
|
||||
raw = generate_completion(
|
||||
[
|
||||
{"role": "system", "content": _prompt_text()},
|
||||
{"role": "user", "content": _build_llm_user_prompt(texts, specs)},
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
payload = _parse_json_object(raw)
|
||||
except Exception as exc:
|
||||
return {"fields": [], "checklist_items": [], "error_message": str(exc)}
|
||||
|
||||
fields = []
|
||||
allowed_keys = {field["key"] for field in _field_defs(specs)}
|
||||
for item in payload.get("fields") or []:
|
||||
if not isinstance(item, dict) or item.get("key") not in allowed_keys or not item.get("value"):
|
||||
continue
|
||||
fields.append(
|
||||
{
|
||||
"key": str(item.get("key") or ""),
|
||||
"label": str(item.get("label") or item.get("key") or ""),
|
||||
"value": str(item.get("value") or "").strip(),
|
||||
"source_file": str(item.get("source_file") or ""),
|
||||
"source_role": str(item.get("source_role") or detect_source_role(str(item.get("source_file") or ""), "")),
|
||||
"evidence": str(item.get("evidence") or "").strip(),
|
||||
"extractor": "llm",
|
||||
"confidence": _float_confidence(item.get("confidence"), default=0.7),
|
||||
}
|
||||
)
|
||||
return {"fields": fields, "checklist_items": payload.get("checklist_items") or []}
|
||||
|
||||
|
||||
def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
rule_future = executor.submit(extract_by_rules, texts, specs)
|
||||
llm_future = executor.submit(extract_by_llm, texts, specs)
|
||||
regex_results = rule_future.result()
|
||||
llm_results = llm_future.result()
|
||||
return {
|
||||
"regex_results": regex_results,
|
||||
"llm_results": llm_results,
|
||||
"selected_templates": [spec.code for spec in specs],
|
||||
"source_evidence": [{"source_file": name, "char_count": len(text)} for name, text in texts.items()],
|
||||
}
|
||||
|
||||
|
||||
def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
|
||||
target_dir = ensure_batch_subdir(batch, "exports")
|
||||
path = target_dir / "field_extract_result.json"
|
||||
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
return create_artifact_for_file(
|
||||
batch,
|
||||
path=path,
|
||||
artifact_type=ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
|
||||
file_format=ApplicationFormFillArtifact.FileFormat.JSON,
|
||||
name="field_extract_result",
|
||||
metadata={"artifact": "field_extract_result"},
|
||||
created_by_node="field_extract",
|
||||
)
|
||||
|
||||
|
||||
def detect_source_role(file_name: str, text: str = "") -> str:
|
||||
target = f"{file_name}\n{text[:200]}"
|
||||
if "说明书" in target:
|
||||
return "说明书"
|
||||
if "产品技术要求" in target:
|
||||
return "产品技术要求"
|
||||
if "注册检验" in target or "检测报告" in target:
|
||||
return "注册检验报告"
|
||||
if "性能研究" in target:
|
||||
return "性能研究资料"
|
||||
if "申请表" in target:
|
||||
return "申请表"
|
||||
return "其他注册资料"
|
||||
|
||||
|
||||
def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
|
||||
fields: list[dict[str, str]] = []
|
||||
for spec in specs:
|
||||
for field in spec.fields:
|
||||
key = str(field.get("key") or "")
|
||||
label = str(field.get("label") or "")
|
||||
if key and label:
|
||||
fields.append({"key": key, "label": label})
|
||||
return fields
|
||||
|
||||
|
||||
def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
|
||||
escaped_labels = "|".join(re.escape(item) for item in labels if item != label)
|
||||
stop_pattern = rf"(?=\n\s*(?:{escaped_labels})\s*[::])" if escaped_labels else r"(?=\Z)"
|
||||
pattern = re.compile(rf"{re.escape(label)}\s*[::]\s*(.+?)(?:{stop_pattern}|\Z)", re.S)
|
||||
match = pattern.search(text or "")
|
||||
if not match:
|
||||
return "", ""
|
||||
raw = match.group(1).strip()
|
||||
value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
|
||||
value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
|
||||
evidence = f"{label}:{value}"[:300]
|
||||
return value, evidence
|
||||
|
||||
|
||||
def _prompt_text() -> str:
|
||||
path = Path(__file__).resolve().parents[1] / "prompts" / "field_extract.md"
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
|
||||
fields = [{"key": field["key"], "label": field["label"]} for field in _field_defs(specs)]
|
||||
documents = [{"source_file": name, "text": text[:4000]} for name, text in texts.items()]
|
||||
return json.dumps({"fields": fields, "documents": documents}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _parse_json_object(raw: str) -> dict[str, Any]:
|
||||
text = (raw or "").strip()
|
||||
if text.startswith("```"):
|
||||
text = text.strip("`").strip()
|
||||
if text.lower().startswith("json"):
|
||||
text = text[4:].strip()
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start == -1 or end == -1 or end < start:
|
||||
raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
|
||||
return json.loads(text[start : end + 1])
|
||||
|
||||
|
||||
def _float_confidence(value, *, default: float) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
88
review_agent/application_form_fill/services/field_merge.py
Normal file
88
review_agent/application_form_fill/services/field_merge.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from review_agent.application_form_fill.schemas import MergedField
|
||||
|
||||
|
||||
SOURCE_PRIORITY = {
|
||||
"说明书": 1,
|
||||
"产品技术要求": 2,
|
||||
"注册检验报告": 3,
|
||||
"检测报告": 3,
|
||||
"性能研究资料": 4,
|
||||
"其他注册资料": 5,
|
||||
}
|
||||
|
||||
|
||||
def normalize_field_value(value: str) -> str:
|
||||
return re.sub(r"\s+", "", str(value or "")).strip().lower()
|
||||
|
||||
|
||||
def rank_source(source_role: str, source_file: str = "") -> int:
|
||||
target = f"{source_role}\n{source_file}"
|
||||
for keyword, rank in SOURCE_PRIORITY.items():
|
||||
if keyword in target:
|
||||
return rank
|
||||
return 9
|
||||
|
||||
|
||||
def merge_fields(regex_results: dict[str, Any], llm_results: dict[str, Any]) -> tuple[dict[str, MergedField], list[dict]]:
|
||||
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||
for item in list(regex_results.get("fields") or []) + list(llm_results.get("fields") or []):
|
||||
key = str(item.get("key") or "")
|
||||
value = str(item.get("value") or "").strip()
|
||||
if not key or not value:
|
||||
continue
|
||||
grouped.setdefault(key, []).append(item)
|
||||
|
||||
merged: dict[str, MergedField] = {}
|
||||
conflicts: list[dict] = []
|
||||
for key, candidates in grouped.items():
|
||||
selected = sorted(
|
||||
candidates,
|
||||
key=lambda item: (
|
||||
rank_source(str(item.get("source_role") or ""), str(item.get("source_file") or "")),
|
||||
-float(item.get("confidence") or 0),
|
||||
),
|
||||
)[0]
|
||||
distinct = _distinct_values(candidates)
|
||||
has_conflict = len(distinct) > 1
|
||||
conflict_values = [
|
||||
{
|
||||
"value": item.get("value"),
|
||||
"source_file": item.get("source_file", ""),
|
||||
"source_role": item.get("source_role", ""),
|
||||
"evidence": item.get("evidence", ""),
|
||||
}
|
||||
for item in candidates
|
||||
if normalize_field_value(str(item.get("value") or "")) != normalize_field_value(str(selected.get("value") or ""))
|
||||
]
|
||||
merged_field = MergedField(
|
||||
key=key,
|
||||
label=str(selected.get("label") or key),
|
||||
value=str(selected.get("value") or ""),
|
||||
source_file=str(selected.get("source_file") or ""),
|
||||
evidence=str(selected.get("evidence") or ""),
|
||||
confidence=float(selected.get("confidence") or 0),
|
||||
has_conflict=has_conflict,
|
||||
conflict_values=conflict_values,
|
||||
)
|
||||
merged[key] = merged_field
|
||||
if has_conflict:
|
||||
conflicts.append(
|
||||
{
|
||||
"field_key": key,
|
||||
"field_label": merged_field.label,
|
||||
"selected_value": merged_field.value,
|
||||
"selected_source": merged_field.source_file,
|
||||
"conflict_values": conflict_values,
|
||||
"handling": "说明书优先,模板内黄底红字高亮" if rank_source(merged_field.source_file, merged_field.source_file) == 1 else "按来源优先级采用最高优先级字段",
|
||||
}
|
||||
)
|
||||
return merged, conflicts
|
||||
|
||||
|
||||
def _distinct_values(candidates: list[dict[str, Any]]) -> set[str]:
|
||||
return {normalize_field_value(str(item.get("value") or "")) for item in candidates if item.get("value")}
|
||||
121
tests/test_application_form_fill_field_extract.py
Normal file
121
tests/test_application_form_fill_field_extract.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from review_agent.application_form_fill.services.field_extract import (
|
||||
extract_by_llm,
|
||||
extract_by_rules,
|
||||
run_parallel_extract,
|
||||
save_field_extract_result,
|
||||
)
|
||||
from review_agent.application_form_fill.services.template_config import load_template_config
|
||||
from review_agent.application_form_fill.services.template_select import select_templates
|
||||
from review_agent.models import (
|
||||
ApplicationFormFillArtifact,
|
||||
ApplicationFormFillBatch,
|
||||
Conversation,
|
||||
FileSummaryBatch,
|
||||
)
|
||||
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def _registration_specs():
|
||||
config = load_template_config()
|
||||
specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册")
|
||||
return specs
|
||||
|
||||
|
||||
def test_rule_extracts_registration_certificate_fields():
|
||||
texts = {
|
||||
"产品说明书.txt": "\n".join(
|
||||
[
|
||||
"产品名称:甲胎蛋白检测试剂盒",
|
||||
"包装规格:20人份/盒",
|
||||
"预期用途:用于体外定量检测人血清中甲胎蛋白含量",
|
||||
"产品储存条件及有效期:2-8℃保存,有效期12个月",
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
result = extract_by_rules(texts, _registration_specs())
|
||||
|
||||
values = {field["key"]: field for field in result["fields"]}
|
||||
assert values["product_name"]["value"] == "甲胎蛋白检测试剂盒"
|
||||
assert values["intended_use"]["source_role"] == "说明书"
|
||||
assert "2-8℃保存" in values["storage_condition_and_validity"]["value"]
|
||||
assert values["package_specification"]["extractor"] == "rule"
|
||||
|
||||
|
||||
def test_llm_extract_parses_structured_json(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"review_agent.application_form_fill.services.field_extract.generate_completion",
|
||||
lambda messages, temperature=0.0: json.dumps(
|
||||
{
|
||||
"fields": [
|
||||
{
|
||||
"key": "product_name",
|
||||
"label": "产品名称",
|
||||
"value": "甲胎蛋白检测试剂盒",
|
||||
"source_file": "说明书.txt",
|
||||
"source_role": "说明书",
|
||||
"evidence": "产品名称:甲胎蛋白检测试剂盒",
|
||||
"confidence": 0.9,
|
||||
}
|
||||
],
|
||||
"checklist_items": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
|
||||
result = extract_by_llm({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs())
|
||||
|
||||
assert result["fields"][0]["extractor"] == "llm"
|
||||
assert result["fields"][0]["value"] == "甲胎蛋白检测试剂盒"
|
||||
|
||||
|
||||
def test_llm_extract_failure_returns_empty_result(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"review_agent.application_form_fill.services.field_extract.generate_completion",
|
||||
lambda messages, temperature=0.0: (_ for _ in ()).throw(TimeoutError("timeout")),
|
||||
)
|
||||
|
||||
result = extract_by_llm({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs())
|
||||
|
||||
assert result["fields"] == []
|
||||
assert "timeout" in result["error_message"]
|
||||
|
||||
|
||||
def test_parallel_extract_preserves_rule_result_when_llm_fails(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"review_agent.application_form_fill.services.field_extract.generate_completion",
|
||||
lambda messages, temperature=0.0: (_ for _ in ()).throw(TimeoutError("timeout")),
|
||||
)
|
||||
|
||||
payload = run_parallel_extract({"说明书.txt": "产品名称:甲胎蛋白检测试剂盒"}, _registration_specs())
|
||||
|
||||
assert payload["regex_results"]["fields"]
|
||||
assert payload["llm_results"]["fields"] == []
|
||||
assert payload["selected_templates"] == ["registration_certificate"]
|
||||
|
||||
|
||||
def test_save_field_extract_result_creates_json_artifact(settings, tmp_path, django_user_model):
|
||||
settings.MEDIA_ROOT = tmp_path
|
||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||
conversation = Conversation.objects.create(user=user, title="会话")
|
||||
summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-FIELD")
|
||||
batch = ApplicationFormFillBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
batch_no="AFF-FIELD",
|
||||
work_dir=str(tmp_path / "aff" / "AFF-FIELD"),
|
||||
)
|
||||
|
||||
artifact = save_field_extract_result(batch, {"regex_results": {"fields": []}, "llm_results": {"fields": []}})
|
||||
|
||||
assert artifact.artifact_type == ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT
|
||||
assert artifact.file_format == ApplicationFormFillArtifact.FileFormat.JSON
|
||||
assert artifact.content_hash
|
||||
79
tests/test_application_form_fill_field_merge.py
Normal file
79
tests/test_application_form_fill_field_merge.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
|
||||
from review_agent.application_form_fill.services.field_merge import merge_fields, normalize_field_value, rank_source
|
||||
|
||||
|
||||
def test_normalize_field_value_removes_whitespace():
|
||||
assert normalize_field_value(" 2-8℃ 保存 \n 有效期12个月 ") == "2-8℃保存有效期12个月"
|
||||
|
||||
|
||||
def test_rank_source_prefers_instructions():
|
||||
assert rank_source("说明书") < rank_source("产品技术要求")
|
||||
|
||||
|
||||
def test_merge_fields_prefers_instructions_and_marks_conflict():
|
||||
regex_results = {
|
||||
"fields": [
|
||||
{
|
||||
"key": "storage_condition_and_validity",
|
||||
"label": "产品储存条件及有效期",
|
||||
"value": "2-8℃保存,有效期12个月",
|
||||
"source_file": "说明书.txt",
|
||||
"source_role": "说明书",
|
||||
"evidence": "产品储存条件及有效期:2-8℃保存,有效期12个月",
|
||||
"confidence": 0.75,
|
||||
},
|
||||
{
|
||||
"key": "storage_condition_and_validity",
|
||||
"label": "产品储存条件及有效期",
|
||||
"value": "-20℃保存",
|
||||
"source_file": "产品技术要求.txt",
|
||||
"source_role": "产品技术要求",
|
||||
"evidence": "产品储存条件及有效期:-20℃保存",
|
||||
"confidence": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
merged, conflicts = merge_fields(regex_results, {"fields": []})
|
||||
|
||||
field = merged["storage_condition_and_validity"]
|
||||
assert field.value == "2-8℃保存,有效期12个月"
|
||||
assert field.has_conflict is True
|
||||
assert conflicts[0]["selected_value"] == "2-8℃保存,有效期12个月"
|
||||
assert conflicts[0]["conflict_values"][0]["value"] == "-20℃保存"
|
||||
|
||||
|
||||
def test_merge_fields_combines_consistent_values_without_conflict():
|
||||
regex_results = {
|
||||
"fields": [
|
||||
{
|
||||
"key": "product_name",
|
||||
"label": "产品名称",
|
||||
"value": "甲胎蛋白检测试剂盒",
|
||||
"source_file": "说明书.txt",
|
||||
"source_role": "说明书",
|
||||
"evidence": "产品名称:甲胎蛋白检测试剂盒",
|
||||
"confidence": 0.75,
|
||||
}
|
||||
]
|
||||
}
|
||||
llm_results = {
|
||||
"fields": [
|
||||
{
|
||||
"key": "product_name",
|
||||
"label": "产品名称",
|
||||
"value": "甲胎蛋白 检测试剂盒",
|
||||
"source_file": "产品技术要求.txt",
|
||||
"source_role": "产品技术要求",
|
||||
"evidence": "产品名称:甲胎蛋白 检测试剂盒",
|
||||
"confidence": 0.9,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
merged, conflicts = merge_fields(regex_results, llm_results)
|
||||
|
||||
assert merged["product_name"].value == "甲胎蛋白检测试剂盒"
|
||||
assert merged["product_name"].has_conflict is False
|
||||
assert conflicts == []
|
||||
Reference in New Issue
Block a user