# Source listing metadata: 279 lines, 11 KiB, Python.
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from django.conf import settings
|
||
|
||
from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
|
||
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
|
||
from review_agent.llm import generate_completion
|
||
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
|
||
from review_agent.regulatory_review.services.text_extract import extract_text
|
||
|
||
|
||
# Extra document labels to try for a template field key, in priority order.
# The field's own label is always tried first (see ``_field_aliases``); the
# entries here are the fallbacks searched afterwards.
FIELD_ALIASES: dict[str, list[str]] = {
    "product_name": ["产品名称"],
    "applicant_name": ["注册人名称", "申请人名称", "生产企业名称"],
    "applicant_address": ["注册人住所", "申请人住所", "生产企业住所"],
    "manufacturer_address": ["生产地址", "生产企业地址", "生产场所"],
    "agent_name": ["代理人名称", "生产企业名称", "注册人名称", "申请人名称"],
    "agent_address": ["代理人住所", "生产企业住所", "注册人住所", "申请人住所"],
    "package_specification": ["包装规格", "规格"],
    "main_components": ["主要组成成分", "主要组成", "组成成分"],
    "intended_use": ["预期用途"],
    "storage_condition_and_validity": [
        "产品储存条件及有效期",
        "储存条件及有效期",
        "储存条件",
        "有效期",
    ],
}
|
||
|
||
# Labels that always terminate a "label: value" capture, independent of the
# selected templates. ``_all_field_labels`` appends the template field labels
# and their aliases after these entries.
STATIC_STOP_LABELS: list[str] = [
    "申请人",
    "国家药品监督管理局",
    "填表说明",
    "注",
    "保证书",
    "应附资料",
    "优先通道申请",
    "分类编码",
    "医疗器械唯一标识",
    "注册产品目前是否",
    "临床评价路径",
    "临床试验",
    "其他需要说明的问题",
    "国家药监局器审中心医疗器械",
]
|
||
|
||
|
||
def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
    """Extract the text of every readable file in a summary batch.

    Items are visited in ``file_index`` order. A relative ``storage_path`` is
    resolved against ``settings.MEDIA_ROOT``. Files that do not exist on disk,
    or whose extraction does not report ``"success"`` with non-empty text, are
    left out.

    Returns:
        Mapping of ``file_name`` to extracted text. A later item with the same
        file name replaces an earlier one.
    """
    collected: dict[str, str] = {}
    for entry in summary_batch.items.order_by("file_index"):
        location = Path(entry.storage_path)
        if not location.is_absolute():
            location = Path(settings.MEDIA_ROOT) / entry.storage_path
        if location.exists():
            outcome = extract_text(location)
            if outcome.status == "success" and outcome.text:
                collected[entry.file_name] = outcome.text
    return collected
|
||
|
||
|
||
def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template field values from document texts with label rules.

    Every field of every spec is searched in every document; each non-empty
    hit becomes one record, so the same field key may appear several times
    (once per document that contains it).

    Args:
        texts: Mapping of source file name to extracted document text.
        specs: Template specs whose fields should be looked up.

    Returns:
        ``{"fields": [...], "checklist_items": []}`` where each field entry is
        the attribute dict of an ``ExtractedField`` tagged ``extractor="rule"``.
    """
    definitions = _field_defs(specs)
    stop_labels = _all_field_labels(definitions)
    hits: list[dict[str, Any]] = []
    for source_name, body in texts.items():
        role = detect_source_role(source_name, body)
        # Hits found in an instruction manual ("说明书") are scored higher.
        score = 0.75 if role == "说明书" else 0.65
        for definition in definitions:
            found, proof = _extract_field_value(body, definition, stop_labels)
            if found:
                record = ExtractedField(
                    key=definition["key"],
                    label=definition["label"],
                    value=found,
                    source_file=source_name,
                    source_role=role,
                    evidence=proof,
                    extractor="rule",
                    confidence=score,
                )
                hits.append(vars(record))
    return {"fields": hits, "checklist_items": []}
|
||
|
||
|
||
def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template field values by asking the LLM and validating its reply.

    The model is called once with the system prompt from ``_prompt_text()``
    and the user prompt from ``_build_llm_user_prompt()`` at
    ``temperature=0.0``; the reply is parsed with ``_parse_json_object()``.

    Args:
        texts: Mapping of source file name to extracted document text.
        specs: Template specs; only their field keys are accepted in the reply.

    Returns:
        A dict with ``"fields"`` (normalized field dicts tagged
        ``extractor="llm"``) and ``"checklist_items"`` (taken from the reply,
        ``[]`` when absent or empty). If the completion call or the JSON
        parsing raises, both lists are empty and ``"error_message"`` holds
        ``str(exc)``; those failures are not propagated.
    """
    try:
        raw = generate_completion(
            [
                {"role": "system", "content": _prompt_text()},
                {"role": "user", "content": _build_llm_user_prompt(texts, specs)},
            ],
            temperature=0.0,
        )
        payload = _parse_json_object(raw)
    except Exception as exc:
        # Best-effort boundary: any failure is reported in the result payload
        # instead of being raised to the caller.
        return {"fields": [], "checklist_items": [], "error_message": str(exc)}

    # The reply is untrusted: a non-list "fields" (number, bool, object, ...)
    # must not be iterated, otherwise a TypeError would escape this function.
    candidates = payload.get("fields")
    if not isinstance(candidates, list):
        candidates = []

    allowed_keys = {field["key"] for field in _field_defs(specs)}
    fields = []
    for item in candidates:
        if not isinstance(item, dict):
            continue
        key = item.get("key")
        # Check the type first: an unhashable key (e.g. a list) would raise on
        # the set-membership test.
        if not isinstance(key, str) or key not in allowed_keys:
            continue
        # Strip before testing so whitespace-only values are dropped rather
        # than stored as empty fields.
        value = str(item.get("value") or "").strip()
        if not value:
            continue
        source_file = str(item.get("source_file") or "")
        fields.append(
            {
                "key": key,
                "label": str(item.get("label") or key),
                "value": value,
                "source_file": source_file,
                "source_role": str(item.get("source_role") or detect_source_role(source_file, "")),
                "evidence": str(item.get("evidence") or "").strip(),
                "extractor": "llm",
                "confidence": _float_confidence(item.get("confidence"), default=0.7),
            }
        )
    return {"fields": fields, "checklist_items": payload.get("checklist_items") or []}
|
||
|
||
|
||
def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Run the rule-based and the LLM-based extractors on two worker threads.

    Args:
        texts: Mapping of source file name to extracted document text.
        specs: Template specs selected for this run.

    Returns:
        A dict with ``"regex_results"`` and ``"llm_results"`` (the two
        extractor payloads), ``"selected_templates"`` (the ``code`` of each
        spec) and ``"source_evidence"`` (file name and character count of each
        input text).
    """
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {
            "regex_results": pool.submit(extract_by_rules, texts, specs),
            "llm_results": pool.submit(extract_by_llm, texts, specs),
        }
        # Results are collected in submission order: rules first, then LLM.
        outcome: dict[str, Any] = {name: job.result() for name, job in pending.items()}
    outcome["selected_templates"] = [spec.code for spec in specs]
    outcome["source_evidence"] = [
        {"source_file": source_name, "char_count": len(body)}
        for source_name, body in texts.items()
    ]
    return outcome
|
||
|
||
|
||
def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
    """Write the extraction payload to disk and register it as a batch artifact.

    The payload is serialized as indented UTF-8 JSON (non-ASCII characters are
    kept as-is) into ``field_extract_result.json`` inside the batch's
    ``exports`` sub-directory.

    Returns:
        Whatever ``create_artifact_for_file`` returns for the written file.
    """
    exports_dir = ensure_batch_subdir(batch, "exports")
    output_path = exports_dir / "field_extract_result.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    output_path.write_text(serialized, encoding="utf-8")

    artifact_options = {
        "path": output_path,
        "artifact_type": ApplicationFormFillArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
        "file_format": ApplicationFormFillArtifact.FileFormat.JSON,
        "name": "field_extract_result",
        "metadata": {"artifact": "field_extract_result"},
        "created_by_node": "field_extract",
    }
    return create_artifact_for_file(batch, **artifact_options)
|
||
|
||
|
||
def detect_source_role(file_name: str, text: str = "") -> str:
    """Classify a source document by marker phrases.

    The file name and the first 200 characters of ``text`` are searched. Roles
    are checked in a fixed order and the first one with a matching marker
    wins; ``"其他注册资料"`` is returned when nothing matches.
    """
    haystack = f"{file_name}\n{text[:200]}"
    # Ordered (role, markers) pairs: earlier entries take precedence.
    role_markers = (
        ("说明书", ("说明书",)),
        ("产品技术要求", ("产品技术要求",)),
        ("注册检验报告", ("注册检验", "检测报告")),
        ("性能研究资料", ("性能研究",)),
        ("申请表", ("申请表",)),
    )
    for role, markers in role_markers:
        if any(marker in haystack for marker in markers):
            return role
    return "其他注册资料"
|
||
|
||
|
||
def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
    """Flatten the fields of all specs into ``{"key", "label"}`` dicts.

    Keys and labels are coerced to ``str`` (missing or falsy values become
    ``""``); entries lacking either a key or a label are dropped. Order follows
    the specs and their fields; duplicates are not removed.
    """
    pairs = (
        (str(raw.get("key") or ""), str(raw.get("label") or ""))
        for spec in specs
        for raw in spec.fields
    )
    return [{"key": key, "label": label} for key, label in pairs if key and label]
|
||
|
||
|
||
def _extract_field_value(text: str, field: dict[str, str], labels: list[str]) -> tuple[str, str]:
    """Find the value of one field in a document, trying each alias in turn.

    For every alias the "label: value" form is tried first, then the
    bracketed-heading section form. The first non-empty value wins.

    Args:
        text: Document text to search.
        field: Field definition with ``"key"`` and ``"label"``.
        labels: Labels that terminate a "label: value" capture.

    Returns:
        ``(value, evidence)``, or ``("", "")`` when no alias yields a value.
    """
    aliases = _field_aliases(field)
    stop_labels = labels + aliases
    for alias in aliases:
        found, proof = _extract_colon_label_value(text, alias, stop_labels)
        if not found:
            found, proof = _extract_bracket_section_value(text, alias)
        if found:
            return found, proof
    return "", ""
|
||
|
||
|
||
def _field_aliases(field: dict[str, str]) -> list[str]:
    """Return the labels to search for a field, most specific first.

    The field's own label comes first, followed by the ``FIELD_ALIASES``
    entries for its key. Values are stripped; empty ones and repeats are
    dropped while the original order is kept.
    """
    candidates = [field["label"], *FIELD_ALIASES.get(field["key"], [])]
    cleaned = (str(candidate or "").strip() for candidate in candidates)
    # dict.fromkeys keeps first-seen order, which gives an ordered de-duplication.
    return list(dict.fromkeys(name for name in cleaned if name))
|
||
|
||
|
||
def _all_field_labels(fields: list[dict[str, str]]) -> list[str]:
    """Build the list of labels that terminate a "label: value" capture.

    Starts with a copy of ``STATIC_STOP_LABELS`` and appends every alias of
    every field that is not already present, preserving first-seen order.
    """
    collected: list[str] = list(STATIC_STOP_LABELS)
    known = set(collected)
    for definition in fields:
        for alias in _field_aliases(definition):
            if alias in known:
                continue
            known.add(alias)
            collected.append(alias)
    return collected
|
||
|
||
|
||
def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
    """Alias of ``_extract_colon_label_value``; forwards all arguments unchanged."""
    extracted = _extract_colon_label_value(text, label, labels)
    return extracted
|
||
|
||
|
||
def _extract_colon_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
    """Extract the value written after ``label`` and a colon.

    The first occurrence of ``label`` followed by an ASCII or full-width colon
    is used. The capture runs until a line that starts with one of the other
    ``labels`` (see the stop pattern below) or until the end of the text. It is
    then cut at the first blank line, and each remaining line is stripped.

    Args:
        text: Document text to search; ``None`` or ``""`` yields no match.
        label: Label to look for (matched literally, anywhere in the text).
        labels: Labels that end the capture; ``label`` itself and empty
            entries are ignored.

    Returns:
        ``(value, evidence)`` where evidence is ``"<label>:<value>"`` truncated
        to 300 characters, or ``("", "")`` when the label is absent or its
        value is empty.
    """
    # Accept both ":" and the full-width colon (U+FF1A). The code point is
    # spelled as an escape so it cannot be folded into a plain ":" by Unicode
    # normalization of this source file.
    colon = r"[:\uFF1A]"
    alternatives = "|".join(re.escape(item) for item in labels if item and item != label)
    if alternatives:
        # NOTE(review): "$" is compiled without re.MULTILINE, so a stop label
        # without a colon only terminates the capture when nothing but
        # whitespace follows it up to the end of the text. Confirm whether
        # end-of-line was intended.
        stop_pattern = rf"(?=\n\s*(?:{alternatives})(?:\s*{colon}|\s*$))"
    else:
        stop_pattern = r"(?=\Z)"
    # After the colon only same-line whitespace is skipped, and the capture
    # may be empty. Otherwise an empty field would consume the newline and
    # swallow the next label's line as its own value.
    pattern = re.compile(
        rf"{re.escape(label)}\s*{colon}[^\S\n]*(.*?)(?:{stop_pattern}|\Z)",
        re.S,
    )
    match = pattern.search(text or "")
    if not match:
        return "", ""
    raw = match.group(1).strip()
    # Keep only the first paragraph: drop everything from the first blank line.
    value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
    value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
    if not value:
        return "", ""
    evidence = f"{label}:{value}"[:300]
    return value, evidence
|
||
|
||
|
||
def _extract_bracket_section_value(text: str, label: str) -> tuple[str, str]:
    """Extract the body of a section introduced by a bracketed heading line.

    A heading is a line consisting only of ``label`` wrapped in ``【…】`` or
    ``[…]``. The section body is every following non-blank line, stripped, up
    to the next line that looks like a bracketed heading. If a matching
    heading has an empty body, later occurrences of the heading are tried.

    Returns:
        ``(value, evidence)`` where evidence is ``"【<label>】\\n<value>"``
        truncated to 300 characters, or ``("", "")`` when no section with
        content is found.
    """
    heading = re.compile(rf"^\s*[【\[]\s*{re.escape(label)}\s*[】\]]\s*$")
    stripped_lines = [line.strip() for line in (text or "").splitlines()]
    for position, candidate in enumerate(stripped_lines):
        if heading.match(candidate) is None:
            continue
        body: list[str] = []
        for following in stripped_lines[position + 1 :]:
            if not following:
                continue
            if _looks_like_bracket_heading(following):
                break
            body.append(following)
        section = "\n".join(body).strip()
        if section:
            return section, f"【{label}】\n{section}"[:300]
    return "", ""
|
||
|
||
|
||
def _looks_like_bracket_heading(line: str) -> bool:
    """Tell whether ``line`` is a bracketed heading such as ``【预期用途】``.

    The line must consist of an opening ``【`` or ``[``, 1 to 40 characters,
    and a closing ``】`` or ``]``; surrounding whitespace is allowed.
    """
    heading_shape = r"^\s*[【\[].{1,40}[】\]]\s*$"
    return re.match(heading_shape, line) is not None
|
||
|
||
|
||
def _prompt_text() -> str:
    """Read the field-extraction system prompt from disk.

    The file is ``prompts/field_extract.md`` under the directory one level
    above this module's own directory, decoded as UTF-8.
    """
    package_root = Path(__file__).resolve().parents[1]
    prompt_file = package_root / "prompts" / "field_extract.md"
    return prompt_file.read_text(encoding="utf-8")
|
||
|
||
|
||
def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
    """Serialize the wanted fields and the source documents as one JSON string.

    The result has two keys: ``"fields"`` (key/label of every template field)
    and ``"documents"`` (file name plus the first 4000 characters of each
    text). Non-ASCII characters are written unescaped.
    """
    wanted = []
    for definition in _field_defs(specs):
        wanted.append({"key": definition["key"], "label": definition["label"]})

    excerpts = []
    for source_name, body in texts.items():
        # Only the leading 4000 characters of each document are sent.
        excerpts.append({"source_file": source_name, "text": body[:4000]})

    request = {"fields": wanted, "documents": excerpts}
    return json.dumps(request, ensure_ascii=False)
|
||
|
||
|
||
def _parse_json_object(raw: str) -> dict[str, Any]:
    """Parse the JSON object embedded in a model reply.

    A surrounding Markdown code fence (optionally tagged ``json``) is removed,
    then everything from the first ``{`` to the last ``}`` is decoded.

    Raises:
        json.JSONDecodeError: If no ``{``/``}`` pair is found or the enclosed
            text is not valid JSON.
    """
    candidate = (raw or "").strip()
    if candidate.startswith("```"):
        candidate = candidate.strip("`").strip()
        if candidate.lower().startswith("json"):
            candidate = candidate[4:].strip()

    opening = candidate.find("{")
    closing = candidate.rfind("}")
    has_object = opening != -1 and closing != -1 and closing >= opening
    if not has_object:
        raise json.JSONDecodeError("未找到 JSON 对象", candidate, 0)
    return json.loads(candidate[opening : closing + 1])
|
||
|
||
|
||
def _float_confidence(value, *, default: float) -> float:
    """Convert a confidence value to a finite ``float``, else use ``default``.

    Args:
        value: Anything ``float()`` accepts (number or numeric string); other
            inputs fall back to ``default``.
        default: Value returned when ``value`` cannot be converted, is too
            large for a float, or is NaN / infinite.

    Returns:
        The converted finite float, or ``default``.
    """
    import math  # local import: the module does not import math at file level

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an integer too large for a double.
        return default
    # NaN / infinity would later be written as non-standard JSON tokens.
    return number if math.isfinite(number) else default
|