# File: DEMO-AGENT/review_agent/application_form_fill/services/field_extract.py
# (279 lines, 11 KiB, Python)
#
# Note from the repository viewer: this file contains ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is intentional, the warning can be safely ignored.
from __future__ import annotations

import json
import math
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any

from django.conf import settings

from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
from review_agent.llm import generate_completion
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
from review_agent.regulatory_review.services.text_extract import extract_text
# Alternative document labels per template field key. `_field_aliases` tries
# the field's own label first and then these entries in the order listed, so
# earlier aliases take precedence.
# NOTE(review): the agent_* fields fall back to manufacturer / registrant /
# applicant labels, so they can pick up the same text as the applicant_*
# fields — confirm this fallback is intended.
FIELD_ALIASES = {
    "product_name": ["产品名称"],
    "applicant_name": ["注册人名称", "申请人名称", "生产企业名称"],
    "applicant_address": ["注册人住所", "申请人住所", "生产企业住所"],
    "manufacturer_address": ["生产地址", "生产企业地址", "生产场所"],
    "agent_name": ["代理人名称", "生产企业名称", "注册人名称", "申请人名称"],
    "agent_address": ["代理人住所", "生产企业住所", "注册人住所", "申请人住所"],
    "package_specification": ["包装规格", "规格"],
    "main_components": ["主要组成成分", "主要组成", "组成成分"],
    "intended_use": ["预期用途"],
    "storage_condition_and_validity": ["产品储存条件及有效期", "储存条件及有效期", "储存条件", "有效期"],
}
# Labels that always seed the stop-label list built by `_all_field_labels`,
# in addition to the labels and aliases of the selected template fields.
# That list is handed to the colon-style extractor, which uses it to decide
# where a captured value ends.
STATIC_STOP_LABELS = [
    "申请人",
    "国家药品监督管理局",
    "填表说明",
    # NOTE(review): empty entry — possibly a character lost in transit. An
    # empty label contributes an empty alternative to any regex alternation
    # built from this list unless the consumer filters it out. Confirm the
    # intended value.
    "",
    "保证书",
    "应附资料",
    "优先通道申请",
    "分类编码",
    "医疗器械唯一标识",
    "注册产品目前是否",
    "临床评价路径",
    "临床试验",
    "其他需要说明的问题",
    "国家药监局器审中心医疗器械",
]
def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
    """Extract the text of every readable file in a summary batch.

    Items are visited in ``file_index`` order. A relative ``storage_path`` is
    resolved against ``settings.MEDIA_ROOT``. Files that do not exist on disk,
    whose extraction status is not ``"success"``, or that yield no text are
    skipped.

    Args:
        summary_batch: Batch whose ``items`` are read.

    Returns:
        Mapping of ``file_name`` to extracted text. When two items share a
        file name, the later one wins.
    """
    collected: dict[str, str] = {}
    for entry in summary_batch.items.order_by("file_index"):
        location = Path(entry.storage_path)
        if not location.is_absolute():
            location = Path(settings.MEDIA_ROOT) / entry.storage_path
        if location.exists():
            outcome = extract_text(location)
            if outcome.status == "success" and outcome.text:
                collected[entry.file_name] = outcome.text
    return collected
def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields from document texts with label-based rules.

    Every document is checked against every field definition of ``specs``;
    each non-empty hit becomes one record, so a field may appear once per
    document.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Selected templates whose fields should be looked up.

    Returns:
        ``{"fields": [...], "checklist_items": []}`` where each field record
        is the ``__dict__`` of an ``ExtractedField`` with ``extractor="rule"``.
        Confidence is 0.75 for documents whose detected role is ``"说明书"``
        and 0.65 otherwise.
    """
    field_defs = _field_defs(specs)
    stop_labels = _all_field_labels(field_defs)
    extracted: list[dict[str, Any]] = []
    for file_name, text in texts.items():
        role = detect_source_role(file_name, text)
        confidence = 0.75 if role == "说明书" else 0.65
        for definition in field_defs:
            value, evidence = _extract_field_value(text, definition, stop_labels)
            if value:
                record = ExtractedField(
                    key=definition["key"],
                    label=definition["label"],
                    value=value,
                    source_file=file_name,
                    source_role=role,
                    evidence=evidence,
                    extractor="rule",
                    confidence=confidence,
                )
                extracted.append(record.__dict__)
    return {"fields": extracted, "checklist_items": []}
def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields by asking the LLM and validating its JSON reply.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Selected templates; only their field keys are accepted.

    Returns:
        ``{"fields": [...], "checklist_items": [...]}``. Each field record has
        ``extractor="llm"`` and a confidence that defaults to 0.7 when the
        model's value is unusable. If the model call or the JSON parsing
        fails, both lists are empty and ``error_message`` carries the
        exception text instead of raising.
    """
    try:
        raw = generate_completion(
            [
                {"role": "system", "content": _prompt_text()},
                {"role": "user", "content": _build_llm_user_prompt(texts, specs)},
            ],
            temperature=0.0,
        )
        payload = _parse_json_object(raw)
    except Exception as exc:
        # Deliberate best-effort boundary: the rule extractor still delivers
        # results when the model is unavailable or answers with garbage.
        return {"fields": [], "checklist_items": [], "error_message": str(exc)}

    # Model output is untrusted: anything that is not a list of objects is
    # treated as "no fields" rather than iterated (a bare number would raise).
    raw_fields = payload.get("fields")
    if not isinstance(raw_fields, list):
        raw_fields = []

    allowed_keys = {field["key"] for field in _field_defs(specs)}
    fields: list[dict[str, Any]] = []
    for item in raw_fields:
        if not isinstance(item, dict):
            continue
        key = item.get("key")
        # The isinstance check must come first: an unhashable key (list or
        # dict) would raise TypeError on the set membership test.
        if not isinstance(key, str) or key not in allowed_keys:
            continue
        value = str(item.get("value") or "").strip()
        if not value:
            # Also drops whitespace-only values, matching the rule extractor,
            # which never emits an empty field.
            continue
        source_file = str(item.get("source_file") or "")
        fields.append(
            {
                "key": key,
                "label": str(item.get("label") or key),
                "value": value,
                "source_file": source_file,
                "source_role": str(item.get("source_role") or detect_source_role(source_file, "")),
                "evidence": str(item.get("evidence") or "").strip(),
                "extractor": "llm",
                "confidence": _float_confidence(item.get("confidence"), default=0.7),
            }
        )
    return {"fields": fields, "checklist_items": payload.get("checklist_items") or []}
def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Run the rule-based and LLM extractors concurrently and bundle the output.

    Args:
        texts: Mapping of source file name to extracted text.
        specs: Selected templates.

    Returns:
        Dict with ``regex_results`` and ``llm_results`` (the two extractor
        payloads), ``selected_templates`` (the ``code`` of each spec) and
        ``source_evidence`` (file name and character count per document).
    """
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {
            "regex_results": pool.submit(extract_by_rules, texts, specs),
            "llm_results": pool.submit(extract_by_llm, texts, specs),
        }
        # Dicts keep insertion order, so the rule result is awaited first.
        bundle: dict[str, Any] = {name: future.result() for name, future in pending.items()}
    bundle["selected_templates"] = [spec.code for spec in specs]
    bundle["source_evidence"] = [
        {"source_file": name, "char_count": len(body)} for name, body in texts.items()
    ]
    return bundle
def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
    """Write the extraction payload as JSON and register it as a batch artifact.

    The file is written to ``field_extract_result.json`` inside the batch's
    ``exports`` sub-directory, UTF-8 encoded, pretty-printed and with
    non-ASCII characters left unescaped.

    Args:
        batch: Batch that owns the artifact.
        payload: JSON-serializable extraction result.

    Returns:
        The artifact returned by ``create_artifact_for_file``.
    """
    output_path = ensure_batch_subdir(batch, "exports") / "field_extract_result.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    output_path.write_text(serialized, encoding="utf-8")
    artifact_options = {
        "path": output_path,
        "artifact_type": ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
        "file_format": ApplicationFormFillArtifact.FileFormat.JSON,
        "name": "field_extract_result",
        "metadata": {"artifact": "field_extract_result"},
        "created_by_node": "field_extract",
    }
    return create_artifact_for_file(batch, **artifact_options)
def detect_source_role(file_name: str, text: str = "") -> str:
    """Classify a document by keywords in its name and opening text.

    Only the file name and the first 200 characters of ``text`` are examined.
    Roles are checked in a fixed priority order and the first role with a
    matching keyword wins.

    Args:
        file_name: Name of the source file.
        text: Extracted document text; may be empty.

    Returns:
        The matched role label, or ``"其他注册资料"`` when no keyword matches.
    """
    haystack = f"{file_name}\n{text[:200]}"
    # Ordered (role, keywords) pairs; order is the match priority.
    role_keywords = (
        ("说明书", ("说明书",)),
        ("产品技术要求", ("产品技术要求",)),
        ("注册检验报告", ("注册检验", "检测报告")),
        ("性能研究资料", ("性能研究",)),
        ("申请表", ("申请表",)),
    )
    for role, keywords in role_keywords:
        if any(keyword in haystack for keyword in keywords):
            return role
    return "其他注册资料"
def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
fields: list[dict[str, str]] = []
for spec in specs:
for field in spec.fields:
key = str(field.get("key") or "")
label = str(field.get("label") or "")
if key and label:
fields.append({"key": key, "label": label})
return fields
def _extract_field_value(text: str, field: dict[str, str], labels: list[str]) -> tuple[str, str]:
    """Find the value of one field by trying each of its aliases in order.

    For every alias the colon-style pattern is tried first and the bracketed
    section heading second; the first non-empty value wins.

    Args:
        text: Document text to search.
        field: Field definition with ``key`` and ``label``.
        labels: Known labels that may terminate a colon-style value.

    Returns:
        ``(value, evidence)``, or ``("", "")`` when nothing matches.
    """
    aliases = _field_aliases(field)
    stop_labels = labels + aliases
    for alias in aliases:
        found = _extract_colon_label_value(text, alias, stop_labels)
        if not found[0]:
            found = _extract_bracket_section_value(text, alias)
        if found[0]:
            return found
    return "", ""
def _field_aliases(field: dict[str, str]) -> list[str]:
    """Return the labels to try for a field: its own label, then its aliases.

    Entries are stripped; empty ones and repeats are removed while the
    first-seen order is kept.
    """
    candidates = [field["label"], *FIELD_ALIASES.get(field["key"], [])]
    cleaned = (str(candidate or "").strip() for candidate in candidates)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(name for name in cleaned if name))
def _all_field_labels(fields: list[dict[str, str]]) -> list[str]:
    """Build the stop-label list: static labels followed by all field aliases.

    The static labels are copied as-is; each field alias is appended only if
    it is not already present, keeping first-seen order.
    """
    merged: list[str] = list(STATIC_STOP_LABELS)
    seen = set(merged)
    for definition in fields:
        for alias in _field_aliases(definition):
            if alias not in seen:
                seen.add(alias)
                merged.append(alias)
    return merged
def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
    """Alias of ``_extract_colon_label_value``; arguments are forwarded unchanged."""
    return _extract_colon_label_value(text=text, label=label, labels=labels)
def _extract_colon_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
escaped_labels = "|".join(re.escape(item) for item in labels if item != label)
stop_pattern = rf"(?=\n\s*(?:{escaped_labels})(?:\s*[:]|\s*$))" if escaped_labels else r"(?=\Z)"
pattern = re.compile(rf"{re.escape(label)}\s*[:]\s*(.+?)(?:{stop_pattern}|\Z)", re.S)
match = pattern.search(text or "")
if not match:
return "", ""
raw = match.group(1).strip()
value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
evidence = f"{label}{value}"[:300]
return value, evidence
def _extract_bracket_section_value(text: str, label: str) -> tuple[str, str]:
    """Extract the body of a section headed by ``label`` in brackets.

    A heading is a line consisting only of the label wrapped in ``【】`` or
    ``[]``. The section body is every following non-blank line, stripped, up
    to the next bracket heading. If a matching heading has an empty body, the
    search continues with the next matching heading.

    Args:
        text: Document text to search; ``None`` or ``""`` yields no match.
        label: Section title to look for.

    Returns:
        ``(value, evidence)`` where evidence is label, newline and value
        capped at 300 characters, or ``("", "")`` when no section is found.
    """
    heading = re.compile(rf"^\s*[【\[]\s*{re.escape(label)}\s*[】\]]\s*$")
    lines = (text or "").splitlines()
    # Enumerating from 1 makes `position` the index of the line after the heading.
    for position, line in enumerate(lines, start=1):
        if heading.match(line.strip()) is None:
            continue
        body: list[str] = []
        for candidate in lines[position:]:
            stripped = candidate.strip()
            if not stripped:
                continue
            if _looks_like_bracket_heading(stripped):
                break
            body.append(stripped)
        value = "\n".join(body).strip()
        if value:
            return value, f"{label}\n{value}"[:300]
    return "", ""
def _looks_like_bracket_heading(line: str) -> bool:
return bool(re.match(r"^\s*[【\[].{1,40}[】\]]\s*$", line))
def _prompt_text() -> str:
    """Read the system prompt from ``prompts/field_extract.md``.

    The path is resolved relative to the parent of the directory that
    contains this module; the file is read as UTF-8.
    """
    package_dir = Path(__file__).resolve().parents[1]
    prompt_file = package_dir / "prompts" / "field_extract.md"
    return prompt_file.read_text(encoding="utf-8")
def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
    """Serialize the wanted fields and the documents into the user prompt.

    Returns:
        A JSON string with ``fields`` (key/label pairs of all template
        fields) and ``documents`` (file name plus the first 4000 characters
        of each text), with non-ASCII characters left unescaped.
    """
    request = {
        # _field_defs already yields dicts holding exactly "key" and "label".
        "fields": _field_defs(specs),
        "documents": [
            {"source_file": name, "text": body[:4000]} for name, body in texts.items()
        ],
    }
    return json.dumps(request, ensure_ascii=False)
def _parse_json_object(raw: str) -> dict[str, Any]:
text = (raw or "").strip()
if text.startswith("```"):
text = text.strip("`").strip()
if text.lower().startswith("json"):
text = text[4:].strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end < start:
raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
return json.loads(text[start : end + 1])
def _float_confidence(value, *, default: float) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default