feat(application-form-fill): 实现字段抽取与冲突合并
This commit is contained in:
187
review_agent/application_form_fill/services/field_extract.py
Normal file
187
review_agent/application_form_fill/services/field_extract.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from review_agent.application_form_fill.schemas import ExtractedField, TemplateSpec
|
||||
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
|
||||
from review_agent.llm import generate_completion
|
||||
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch, FileSummaryBatch
|
||||
from review_agent.regulatory_review.services.text_extract import extract_text
|
||||
|
||||
|
||||
def collect_document_texts(summary_batch: FileSummaryBatch) -> dict[str, str]:
    """Gather the extracted text of each file attached to ``summary_batch``.

    Items are visited in ``file_index`` order. A relative ``storage_path`` is
    resolved against ``settings.MEDIA_ROOT``. An item is left out when its
    file does not exist, or when ``extract_text`` does not report status
    ``"success"`` together with non-empty text.

    Returns:
        Mapping of ``file_name`` to extracted text. When two items share a
        ``file_name`` the later one replaces the earlier one.
    """
    collected: dict[str, str] = {}
    for entry in summary_batch.items.order_by("file_index"):
        location = Path(entry.storage_path)
        if not location.is_absolute():
            location = Path(settings.MEDIA_ROOT) / entry.storage_path
        if location.exists():
            outcome = extract_text(location)
            if outcome.status == "success" and outcome.text:
                collected[entry.file_name] = outcome.text
    return collected
|
||||
|
||||
|
||||
def extract_by_rules(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields from document texts using label matching.

    Every field defined by ``specs`` is looked up in every document via
    ``_extract_label_value``; each non-empty hit becomes one result entry, so
    the same key may appear once per document.

    Args:
        texts: Mapping of source file name to its extracted text.
        specs: Template specs whose field definitions drive the lookup.

    Returns:
        ``{"fields": [...], "checklist_items": []}`` where each field entry is
        the attribute dict of an ``ExtractedField`` with ``extractor="rule"``.
        Confidence is 0.75 for documents whose detected role is "说明书" and
        0.65 otherwise.
    """
    definitions = _field_defs(specs)
    all_labels = [definition["label"] for definition in definitions if definition.get("label")]
    extracted: list[dict[str, Any]] = []

    for file_name, text in texts.items():
        role = detect_source_role(file_name, text)
        # The score depends only on the document role, so compute it once per file.
        score = 0.75 if role == "说明书" else 0.65
        for definition in definitions:
            value, evidence = _extract_label_value(text, definition["label"], all_labels)
            if value:
                record = ExtractedField(
                    key=definition["key"],
                    label=definition["label"],
                    value=value,
                    source_file=file_name,
                    source_role=role,
                    evidence=evidence,
                    extractor="rule",
                    confidence=score,
                )
                extracted.append(vars(record))

    return {"fields": extracted, "checklist_items": []}
|
||||
|
||||
|
||||
def extract_by_llm(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Extract template fields by asking the LLM and sanitising its JSON reply.

    The system prompt comes from ``_prompt_text()`` and the user prompt from
    ``_build_llm_user_prompt``. The reply is parsed with
    ``_parse_json_object`` and each returned field is normalised to strings.

    Args:
        texts: Mapping of source file name to its extracted text.
        specs: Template specs; only keys defined by them are accepted.

    Returns:
        ``{"fields": [...], "checklist_items": [...]}``. When the completion
        call or the JSON parsing fails, both lists are empty and an
        ``"error_message"`` entry carries ``str(exc)``. This function is
        best-effort and does not raise for those failures.
    """
    try:
        raw = generate_completion(
            [
                {"role": "system", "content": _prompt_text()},
                {"role": "user", "content": _build_llm_user_prompt(texts, specs)},
            ],
            temperature=0.0,
        )
        payload = _parse_json_object(raw)
    except Exception as exc:
        # Deliberate best-effort boundary: an LLM failure is reported in the
        # payload rather than propagated to the caller.
        return {"fields": [], "checklist_items": [], "error_message": str(exc)}

    allowed_keys = {field["key"] for field in _field_defs(specs)}

    # The model output is untrusted: anything other than a list of objects is
    # treated as "no fields" instead of being iterated blindly.
    raw_fields = payload.get("fields")
    if not isinstance(raw_fields, list):
        raw_fields = []

    fields: list[dict[str, Any]] = []
    for item in raw_fields:
        if not isinstance(item, dict):
            continue
        key = item.get("key")
        # Allowed keys are always strings; checking the type first also avoids
        # a TypeError when the model returns an unhashable key (list/dict).
        if not isinstance(key, str) or key not in allowed_keys:
            continue
        raw_value = item.get("value")
        if not raw_value:
            continue
        value = str(raw_value).strip()
        if not value:
            # Whitespace-only values carry no information; skip them.
            continue
        source_file = str(item.get("source_file") or "")
        fields.append(
            {
                "key": key,
                "label": str(item.get("label") or key),
                "value": value,
                "source_file": source_file,
                "source_role": str(item.get("source_role") or detect_source_role(source_file, "")),
                "evidence": str(item.get("evidence") or "").strip(),
                "extractor": "llm",
                "confidence": _float_confidence(item.get("confidence"), default=0.7),
            }
        )
    return {"fields": fields, "checklist_items": payload.get("checklist_items") or []}
|
||||
|
||||
|
||||
def run_parallel_extract(texts: dict[str, str], specs: list[TemplateSpec]) -> dict[str, Any]:
    """Run the rule-based and LLM extractors concurrently and bundle the output.

    Both extractors are submitted to a two-worker thread pool and both results
    are awaited before returning.

    Returns:
        A dict with ``regex_results`` (rule extractor output), ``llm_results``
        (LLM extractor output), ``selected_templates`` (the ``code`` of each
        spec) and ``source_evidence`` (file name plus character count for each
        input text).
    """
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending_rules = pool.submit(extract_by_rules, texts, specs)
        pending_llm = pool.submit(extract_by_llm, texts, specs)
        rule_output = pending_rules.result()
        llm_output = pending_llm.result()

    template_codes = [spec.code for spec in specs]
    evidence = []
    for name, text in texts.items():
        evidence.append({"source_file": name, "char_count": len(text)})

    return {
        "regex_results": rule_output,
        "llm_results": llm_output,
        "selected_templates": template_codes,
        "source_evidence": evidence,
    }
|
||||
|
||||
|
||||
def save_field_extract_result(batch: ApplicationFormFillBatch, payload: dict[str, Any]) -> ApplicationFormFillArtifact:
    """Write ``payload`` as JSON into the batch's "exports" dir and register it.

    The file is named ``field_extract_result.json`` and is written as UTF-8,
    indented by two spaces, with non-ASCII characters kept as-is.

    Returns:
        The artifact record returned by ``create_artifact_for_file``.
    """
    exports_dir = ensure_batch_subdir(batch, "exports")
    destination = exports_dir / "field_extract_result.json"
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")

    artifact_model = ApplicationFormFillArtifact
    return create_artifact_for_file(
        batch,
        path=destination,
        artifact_type=artifact_model.ArtifactType.FIELD_EXTRACT_RESULT,
        file_format=artifact_model.FileFormat.JSON,
        name="field_extract_result",
        metadata={"artifact": "field_extract_result"},
        created_by_node="field_extract",
    )
|
||||
|
||||
|
||||
def detect_source_role(file_name: str, text: str = "") -> str:
    """Classify a document by keywords in its name and leading text.

    Only the file name and the first 200 characters of ``text`` are searched.
    Rules are checked in order and the first one with a matching keyword
    decides the role.

    Returns:
        One of "说明书", "产品技术要求", "注册检验报告", "性能研究资料",
        "申请表", or "其他注册资料" when no keyword matches.
    """
    haystack = f"{file_name}\n{text[:200]}"
    # Ordered (keywords, role) pairs; earlier entries take precedence.
    rules = (
        (("说明书",), "说明书"),
        (("产品技术要求",), "产品技术要求"),
        (("注册检验", "检测报告"), "注册检验报告"),
        (("性能研究",), "性能研究资料"),
        (("申请表",), "申请表"),
    )
    for keywords, role in rules:
        if any(keyword in haystack for keyword in keywords):
            return role
    return "其他注册资料"
|
||||
|
||||
|
||||
def _field_defs(specs: list[TemplateSpec]) -> list[dict[str, str]]:
|
||||
fields: list[dict[str, str]] = []
|
||||
for spec in specs:
|
||||
for field in spec.fields:
|
||||
key = str(field.get("key") or "")
|
||||
label = str(field.get("label") or "")
|
||||
if key and label:
|
||||
fields.append({"key": key, "label": label})
|
||||
return fields
|
||||
|
||||
|
||||
def _extract_label_value(text: str, label: str, labels: list[str]) -> tuple[str, str]:
|
||||
escaped_labels = "|".join(re.escape(item) for item in labels if item != label)
|
||||
stop_pattern = rf"(?=\n\s*(?:{escaped_labels})\s*[::])" if escaped_labels else r"(?=\Z)"
|
||||
pattern = re.compile(rf"{re.escape(label)}\s*[::]\s*(.+?)(?:{stop_pattern}|\Z)", re.S)
|
||||
match = pattern.search(text or "")
|
||||
if not match:
|
||||
return "", ""
|
||||
raw = match.group(1).strip()
|
||||
value = re.sub(r"\n{2,}.*\Z", "", raw, flags=re.S).strip()
|
||||
value = "\n".join(line.strip() for line in value.splitlines() if line.strip())
|
||||
evidence = f"{label}:{value}"[:300]
|
||||
return value, evidence
|
||||
|
||||
|
||||
def _prompt_text() -> str:
    """Return the contents of ``prompts/field_extract.md`` as UTF-8 text.

    The file is located relative to this module: two directory levels above
    the module file, inside the ``prompts`` folder.
    """
    package_dir = Path(__file__).resolve().parents[1]
    prompt_file = package_dir / "prompts" / "field_extract.md"
    with prompt_file.open(encoding="utf-8") as handle:
        return handle.read()
|
||||
|
||||
|
||||
def _build_llm_user_prompt(texts: dict[str, str], specs: list[TemplateSpec]) -> str:
    """Serialise the wanted fields and the documents into the LLM user prompt.

    Returns:
        A JSON string (non-ASCII kept as-is) with two entries: ``"fields"``,
        the key/label pairs from ``_field_defs(specs)``, and ``"documents"``,
        one ``{"source_file", "text"}`` object per input text. Each text is
        truncated to its first 4000 characters.
    """
    wanted = []
    for definition in _field_defs(specs):
        wanted.append({"key": definition["key"], "label": definition["label"]})

    excerpts = []
    for name, text in texts.items():
        excerpts.append({"source_file": name, "text": text[:4000]})

    request = {"fields": wanted, "documents": excerpts}
    return json.dumps(request, ensure_ascii=False)
|
||||
|
||||
|
||||
def _parse_json_object(raw: str) -> dict[str, Any]:
|
||||
text = (raw or "").strip()
|
||||
if text.startswith("```"):
|
||||
text = text.strip("`").strip()
|
||||
if text.lower().startswith("json"):
|
||||
text = text[4:].strip()
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start == -1 or end == -1 or end < start:
|
||||
raise json.JSONDecodeError("未找到 JSON 对象", text, 0)
|
||||
return json.loads(text[start : end + 1])
|
||||
|
||||
|
||||
def _float_confidence(value, *, default: float) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
Reference in New Issue
Block a user