Files
DEMO-AGENT/review_agent/regulatory_info_package/services/docx_document.py

250 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from pathlib import Path
from docx import Document
from docx.enum.text import WD_COLOR_INDEX
from docx.shared import RGBColor
from django.utils import timezone
from review_agent.regulatory_info_package.schemas import MergedField
# Matches a template placeholder such as ``{{product_name}}``; group 1 captures
# the key, which may contain only ASCII letters, digits and underscores.
PLACEHOLDER_RE = re.compile(r"\{\{([a-zA-Z0-9_]+)\}\}")
def write_docx_from_template(
    source_path: str | Path,
    output_path: str | Path,
    merged_fields: dict[str, MergedField],
    *,
    template_code: str = "",
) -> tuple[int, int, int]:
    """Render a pre-filled DOCX from a template and return highlight statistics.

    The template at ``source_path`` is opened (a blank document is used when the
    file does not exist). A pre-fill block listing every merged field is inserted
    at the top. Known template literals and ``{{key}}`` placeholders are replaced.
    A summary table of all fields is appended after a page break. The result is
    saved to ``output_path``, whose parent directories are created as needed.

    Args:
        source_path: Path of the DOCX template.
        output_path: Destination path of the rendered DOCX.
        merged_fields: Extracted fields keyed by field key.
        template_code: Template identifier; ``"ch1_5_product_list"`` additionally
            rebuilds the product-list table.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)`` combining:

        * the per-field tallies of the pre-fill block;
        * the number of paragraphs changed by known-literal replacements (added
          to ``highlight_count``);
        * the per-placeholder tallies of the placeholder replacement.

        The appended summary table is not counted: it carries no highlight
        formatting, and every field has already been counted once by the
        pre-fill block.
    """
    source = Path(source_path)
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    document = Document(source) if source.exists() else Document()

    # Literal placeholder text ("{{key}}") -> field.
    replacements = {f"{{{{{key}}}}}": field for key, field in merged_fields.items()}

    highlight_count, missing_count, llm_only_count = _insert_prefill_block(document, merged_fields)
    highlight_count += _apply_known_template_replacements(document, merged_fields)
    if template_code == "ch1_5_product_list":
        _rebuild_product_list_table(document, merged_fields)

    placeholder_highlights, placeholder_missing, placeholder_llm_only = _replace_placeholders(
        document, replacements, merged_fields
    )
    highlight_count += placeholder_highlights
    missing_count += placeholder_missing
    llm_only_count += placeholder_llm_only

    _append_field_summary_table(document, merged_fields)

    document.save(output)
    return highlight_count, missing_count, llm_only_count


def _append_field_summary_table(document, merged_fields: dict[str, MergedField]) -> None:
    """Append a page break, a bold heading and a 4-column table listing every field.

    Columns: label, value, source, and "是" when the field needs review
    (blank otherwise).
    """
    document.add_page_break()
    heading = document.add_paragraph()
    heading.add_run("预生成字段").bold = True
    table = document.add_table(rows=1, cols=4)
    for cell, title in zip(table.rows[0].cells, ("字段", "值", "来源", "待确认")):
        cell.text = title
    for field in merged_fields.values():
        cells = table.add_row().cells
        cells[0].text = field.label
        cells[1].text = field.value
        cells[2].text = field.source
        cells[3].text = "是" if field.needs_review else ""
def _insert_prefill_block(document, merged_fields: dict[str, MergedField]) -> tuple[int, int, int]:
    """Insert a bold pre-fill notice and one ``label:value`` line per field at the top.

    The notice announces that the *following* lines were pre-filled. The field
    lines are therefore placed directly after it (and before the document's
    original first paragraph), in ``merged_fields`` iteration order.

    Styling:
        * A field whose ``highlight_reason`` is not ``"none"`` is highlighted yellow.
        * A ``"conflict"`` field is additionally coloured red.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)`` over ``merged_fields``.
    """
    # An empty document gets a blank anchor paragraph to insert in front of.
    first = document.paragraphs[0] if document.paragraphs else document.add_paragraph()
    marker = first.insert_paragraph_before("【预生成版】以下字段由系统根据说明书预填,黄色或红色标记项请人工复核。")
    marker.runs[0].bold = True
    highlight_count = 0
    missing_count = 0
    llm_only_count = 0
    for field in merged_fields.values():
        # Anchor on ``first`` rather than ``marker``. Inserting before the marker
        # would place every field line above the notice that introduces it.
        paragraph = first.insert_paragraph_before("")
        run = paragraph.add_run(f"{field.label}:{field.value}")
        if field.highlight_reason != "none":
            run.font.highlight_color = WD_COLOR_INDEX.YELLOW
            highlight_count += 1
        if field.highlight_reason == "conflict":
            run.font.color.rgb = RGBColor(255, 0, 0)
        if field.highlight_reason == "missing":
            missing_count += 1
        if field.highlight_reason == "llm_only":
            llm_only_count += 1
    return highlight_count, missing_count, llm_only_count
def _replace_paragraph_text(paragraph, text: str, field: MergedField) -> None:
    """Overwrite a paragraph's visible text with ``text`` held in one new run.

    Existing runs are emptied (not removed) and a fresh run containing ``text``
    is appended.

    Styling of the new run:
        * No styling when ``field.highlight_reason`` is ``"none"``.
        * Otherwise highlighted yellow.
        * Additionally coloured red for ``"conflict"``.
    """
    for old_run in paragraph.runs:
        old_run.text = ""
    new_run = paragraph.add_run(text)
    reason = field.highlight_reason
    if reason == "none":
        return
    new_run.font.highlight_color = WD_COLOR_INDEX.YELLOW
    if reason == "conflict":
        new_run.font.color.rgb = RGBColor(255, 0, 0)
def _replace_placeholders(
    document,
    replacements: dict[str, MergedField],
    merged_fields: dict[str, MergedField],
) -> tuple[int, int, int]:
    """Fill ``{{key}}`` placeholders in every paragraph from ``_iter_paragraphs``.

    ``replacements`` maps literal placeholder text (``"{{key}}"``) to its field.
    Unknown placeholders fall back to ``_default_placeholder_field``.

    A placeholder may be split across several runs, so matching works on the
    paragraph's full text, and a changed paragraph is rewritten with
    ``_replace_paragraph_text``. The paragraph is styled after the most severe
    field it contained, in this order:

        1. the first ``"conflict"`` field;
        2. else the first other highlighted field;
        3. else the first field.

    This keeps a conflict's red colour from being masked by an earlier
    yellow-only placeholder in the same paragraph.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)``, counted once per
        substituted placeholder occurrence.
    """
    highlight_count = 0
    missing_count = 0
    llm_only_count = 0
    for paragraph in _iter_paragraphs(document):
        text = paragraph.text
        # Cheap pre-filter so the regex only runs on candidate paragraphs.
        if "{{" not in text or "}}" not in text:
            continue
        used_fields: list[MergedField] = []

        def replace(match: re.Match[str]) -> str:
            # Called synchronously by ``sub`` below, so ``used_fields`` is always
            # the current paragraph's list.
            key = match.group(1)
            placeholder = match.group(0)
            field = replacements.get(placeholder) or _default_placeholder_field(key, merged_fields)
            used_fields.append(field)
            return field.value

        new_text = PLACEHOLDER_RE.sub(replace, text)
        if new_text == text:
            continue
        # Rank: conflict (2) > any other highlight (1) > none (0).
        # max() returns the first of equally ranked fields, so ties keep document order.
        field_for_style = max(
            used_fields,
            key=lambda field: 2 if field.highlight_reason == "conflict" else int(field.highlight_reason != "none"),
        )
        _replace_paragraph_text(paragraph, new_text, field_for_style)
        for field in used_fields:
            if field.highlight_reason != "none":
                highlight_count += 1
            if field.highlight_reason == "missing":
                missing_count += 1
            if field.highlight_reason == "llm_only":
                llm_only_count += 1
    return highlight_count, missing_count, llm_only_count
def _iter_paragraphs(document):
    """Yield every paragraph of the document body and of its tables.

    Body paragraphs come first, then table-cell paragraphs. Table cells are
    walked recursively, so paragraphs inside nested tables are included; within
    a cell, the cell's own paragraphs come before those of its nested tables.

    NOTE(review): python-docx's ``row.cells`` may report a merged cell once per
    grid column/row it spans, so the same paragraph can be yielded more than
    once. Callers are expected to tolerate that (confirm for the installed
    version).
    """
    yield from document.paragraphs
    for table in document.tables:
        yield from _iter_table_paragraphs(table)


def _iter_table_paragraphs(table):
    """Yield the paragraphs of every cell of ``table``, descending into nested tables."""
    for row in table.rows:
        for cell in row.cells:
            yield from cell.paragraphs
            for nested in cell.tables:
                yield from _iter_table_paragraphs(nested)
def _apply_known_template_replacements(document, merged_fields: dict[str, MergedField]) -> int:
    """Replace sample text hard-coded in known templates with extracted values.

    Some templates contain a concrete sample product (name, applicant, dates)
    instead of placeholders. Those literals are swapped for:

        * the extracted product name;
        * the extracted applicant name;
        * today's date from django's ``timezone.localdate()``.

    The swap is applied to every body and table paragraph.

    Returns:
        Number of paragraphs whose text changed.
    """
    product = _field_value(merged_fields, "product_name")
    applicant = _field_value(merged_fields, "applicant_name")
    current = timezone.localdate()
    # Built directly (not via strftime with non-ASCII format characters).
    # The full date keeps the trailing "日", matching the literal it replaces.
    year_month = f"{current.year:04d}年{current.month:02d}月"
    today = f"{year_month}{current.day:02d}日"
    replacements = {
        # Longer literals are listed before shorter literals they contain
        # (the full product name precedes its prefix).
        "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒荧光PCR法": product,
        "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒": product,
        # NOTE(review): presumably matches a template line where a "产品名称:" label
        # is interleaved with a wrapped product name; verify against the template.
        "呼吸道合胞病毒 、肺炎支产品名称: 原体核酸检测试剂盒(荧": f"产品名称:{product}",
        "光PCR法": "",
        "卡尤迪生物科技宜兴有限公司": applicant,
        "2023年09月20日": today,
        "2023 年 10 月": year_month,
    }
    return sum(
        _replace_text_in_paragraph(paragraph, replacements, merged_fields)
        for paragraph in _iter_paragraphs(document)
    )
def _default_placeholder_field(key: str, merged_fields: dict[str, MergedField]) -> MergedField:
    """Return the field used for a ``{{key}}`` placeholder that has no merged value.

    ``declaration_date`` resolves to today's date from django's
    ``timezone.localdate()``, written as ``YYYY年MM月DD日``.

    Any other key yields a ``"missing"`` field with value ``"/"`` that needs
    review. Its label is taken from the first merged field whose ``key``
    attribute equals ``key``, falling back to ``key`` itself.
    """
    if key == "declaration_date":
        today = timezone.localdate()
        # Full Chinese date, including the trailing "日" (day) character.
        return _plain_field(key, "日期", f"{today.year:04d}年{today.month:02d}月{today.day:02d}日")
    label = next((field.label for field in merged_fields.values() if field.key == key), key)
    return MergedField(
        key=key,
        label=label,
        value="/",
        source="missing",
        evidence="模板字段未从说明书中抽取到",
        confidence=0.0,
        highlight_reason="missing",
        needs_review=True,
    )
def _replace_text_in_paragraph(paragraph, replacements: dict[str, str], merged_fields: dict[str, MergedField]) -> int:
    """Apply literal ``old -> new`` replacements to one paragraph in a single pass.

    All keys are matched against the paragraph's original text at once, and at
    any position the longest matching key wins. As a result, a replacement
    value is never re-scanned by another key. For instance, an extracted product
    name containing "光PCR法" is not truncated by a ``"光PCR法" -> ""`` rule.
    Empty keys are ignored.

    A changed paragraph is rewritten via ``_replace_paragraph_text`` and styled
    after the ``product_name`` field. When that field is absent, a plain
    rule-sourced field is used instead.

    Returns:
        1 if the paragraph changed, else 0.
    """
    text = paragraph.text
    # Longest first: the regex alternation tries alternatives in order, so a key
    # that is a substring of another (e.g. "光PCR法") cannot pre-empt it.
    # sorted() is stable, so equal-length keys keep their dict order.
    keys = sorted((old for old in replacements if old and old in text), key=len, reverse=True)
    if not keys:
        return 0
    pattern = re.compile("|".join(re.escape(old) for old in keys))
    new_text = pattern.sub(lambda match: replacements[match.group(0)], text)
    if new_text == text:
        return 0
    # NOTE(review): styling always follows product_name, even when the replaced
    # literal was the applicant name or a date.
    field = merged_fields.get("product_name") or _plain_field("product_name", "产品名称", new_text)
    _replace_paragraph_text(paragraph, new_text, field)
    return 1
def _rebuild_product_list_table(document, merged_fields: dict[str, MergedField]) -> None:
    """Refresh the product-list intro sentence and rebuild its table.

    Steps:
        1. Every body paragraph containing the intro phrase is rewritten to name
           the extracted product.
        2. The first table whose header row starts with the six expected columns
           is located. If there is none, the function returns after step 1.
        3. All rows except the header are removed from that table.
        4. One row is added per package specification. ``package_specification``
           is split on ASCII ";" or full-width "；", and at most 8 rows are
           written ("/" when there are none). Each row repeats the composition,
           component, main-component and quantity values; the item-number
           column is "/".
    """
    product = _field_value(merged_fields, "product_name")
    product_field = merged_fields.get("product_name") or _plain_field("product_name", "产品名称", product)
    for paragraph in document.paragraphs:
        if "的包装规格、货号、组分及主要组成成分见下表" in paragraph.text:
            _replace_paragraph_text(
                paragraph,
                f"{product}的包装规格、货号、组分及主要组成成分见下表:",
                product_field,
            )

    expected_header = ["包装规格", "货号", "组成", "组分", "主要组成成分", "规格/数量"]
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header[:6] == expected_header:
            target = table
            break
    if target is None:
        return

    # Drop every data row, keeping only the header (python-docx has no public
    # row-removal API, hence the private _tbl/_tr access).
    while len(target.rows) > 1:
        target._tbl.remove(target.rows[-1]._tr)

    package_specification = _field_value(merged_fields, "package_specification")
    # Normalise the full-width semicolon (U+FF1B) to ASCII before splitting.
    # Note: str.replace("", ...) would insert ";" between every character.
    normalized = package_specification.replace("\uff1b", ";")
    specs = [item.strip() for item in normalized.split(";") if item.strip()] or ["/"]
    detail_values = [
        _field_value(merged_fields, key)
        for key in ("composition", "component_name", "main_component", "quantity")
    ]
    for spec in specs[:8]:
        cells = target.add_row().cells
        cells[0].text = spec
        cells[1].text = "/"
        for cell, value in zip(cells[2:6], detail_values):
            cell.text = value
def _field_value(merged_fields: dict[str, MergedField], key: str) -> str:
    """Return ``merged_fields[key].value``, or ``"/"`` when the field is absent or its value is empty."""
    field = merged_fields.get(key)
    return field.value if field and field.value else "/"
def _plain_field(key: str, label: str, value: str) -> MergedField:
    """Build a rule-sourced ``MergedField`` with empty evidence and zero confidence.

    Attributes not passed here (e.g. ``highlight_reason``, ``needs_review``)
    take ``MergedField``'s own defaults.
    """
    attributes = {
        "key": key,
        "label": label,
        "value": value,
        "source": "rule",
        "evidence": "",
        "confidence": 0.0,
    }
    return MergedField(**attributes)