Files

323 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import re
from pathlib import Path
from docx import Document
from docx.enum.text import WD_COLOR_INDEX
from docx.shared import RGBColor
from django.utils import timezone
from review_agent.regulatory_info_package.schemas import MergedField
# Matches ``{{key}}`` placeholders in template text; group 1 captures the key,
# which may contain only ASCII letters, digits and underscores.
PLACEHOLDER_RE = re.compile(r"\{\{([a-zA-Z0-9_]+)\}\}")
def write_docx_from_template(
    source_path: str | Path,
    output_path: str | Path,
    merged_fields: dict[str, MergedField],
    *,
    template_code: str = "",
    directory_page_numbers: dict[str, str] | None = None,
) -> tuple[int, int, int]:
    """Fill a DOCX template with merged field values and save the result.

    The template at ``source_path`` is opened (a blank document is used when the
    file does not exist). Then, in order:

    1. known sample text is swapped for real values;
    2. template-specific content is rebuilt (``"ch1_5_product_list"`` tables,
       ``"ch1_2_directory"`` page numbers);
    3. every ``{{key}}`` placeholder is substituted.

    The parent directory of ``output_path`` is created if needed.

    Args:
        source_path: Path of the template ``.docx``.
        output_path: Destination path of the filled document.
        merged_fields: Field values keyed by placeholder key.
        template_code: Selects template-specific processing.
        directory_page_numbers: Directory code -> page-number text, used only
            for the ``"ch1_2_directory"`` template.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)``. The highlight count
        includes the number of paragraphs changed by the sample-text pass.
    """
    template = Path(source_path)
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    document = Document(template) if template.exists() else Document()

    placeholder_map = {"{{" + key + "}}": field for key, field in merged_fields.items()}

    literal_changes = _apply_known_template_replacements(document, merged_fields, template_code=template_code)
    if template_code == "ch1_5_product_list":
        _rebuild_product_list_table(document, merged_fields)
    if template_code == "ch1_2_directory":
        _apply_directory_page_numbers(document, directory_page_numbers or {})

    placeholder_highlights, missing_count, llm_only_count = _replace_placeholders(
        document, placeholder_map, merged_fields
    )
    document.save(destination)
    return literal_changes + placeholder_highlights, missing_count, llm_only_count
def _replace_paragraph_text(paragraph, text: str, field: MergedField) -> None:
    """Overwrite a paragraph's text and style it according to ``field``.

    Existing runs are emptied (not removed) and ``text`` is appended as one new
    run. That run is highlighted yellow unless ``field.highlight_reason`` is
    ``"none"``. A ``"conflict"`` reason additionally colours the text red.
    """
    for existing in paragraph.runs:
        existing.text = ""
    new_run = paragraph.add_run(text)
    reason = field.highlight_reason
    if reason == "none":
        return
    new_run.font.highlight_color = WD_COLOR_INDEX.YELLOW
    if reason == "conflict":
        new_run.font.color.rgb = RGBColor(255, 0, 0)
def _apply_directory_page_numbers(document, page_numbers: dict[str, str]) -> None:
    """Fill the "页码" (page number) column of the RPS directory table.

    The first table whose header row has "RPS目录" in column 0 and "页码" in
    column 4 is treated as the directory. For each body row, the stripped text
    of column 0 is looked up in ``page_numbers``. When found, the value is
    written into column 4. Only that first matching table is updated.

    Args:
        document: python-docx ``Document`` being filled.
        page_numbers: Maps directory code (column-0 text) to page-number text.
    """
    for table in document.tables:
        rows = table.rows
        if not rows:
            continue
        header = [cell.text.strip() for cell in rows[0].cells]
        if len(header) < 5 or header[0] != "RPS目录" or header[4] != "页码":
            continue
        for row in rows[1:]:
            # python-docx may rebuild ``row.cells`` on every access; read it once.
            cells = row.cells
            # Skip rows that do not reach the page-number column instead of
            # raising IndexError on them.
            if len(cells) < 5:
                continue
            code = cells[0].text.strip()
            if code in page_numbers:
                cells[4].text = page_numbers[code]
        return
def _replace_placeholders(
    document,
    replacements: dict[str, MergedField],
    merged_fields: dict[str, MergedField],
) -> tuple[int, int, int]:
    """Substitute ``{{key}}`` placeholders in every paragraph of ``document``.

    Each placeholder is looked up in ``replacements`` by its full ``{{key}}``
    text. Unknown keys fall back to ``_default_placeholder_field``. A changed
    paragraph is rewritten as a single run. Its style comes from the first used
    field whose ``highlight_reason`` is not ``"none"``, or else from the first
    used field.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)``: one increment per
        substituted placeholder occurrence whose field has the corresponding
        ``highlight_reason`` (any non-"none", "missing", "llm_only").
    """
    highlighted = missing = llm_only = 0
    for paragraph in _iter_paragraphs(document):
        original = paragraph.text
        # Cheap pre-filter before running the regex.
        if "{{" not in original or "}}" not in original:
            continue
        fields_used: list[MergedField] = []

        def lookup(match: re.Match[str]) -> str:
            chosen = replacements.get(match.group(0)) or _default_placeholder_field(match.group(1), merged_fields)
            fields_used.append(chosen)
            return chosen.value

        updated = PLACEHOLDER_RE.sub(lookup, original)
        if updated == original:
            continue
        style_source = next((f for f in fields_used if f.highlight_reason != "none"), None) or fields_used[0]
        _replace_paragraph_text(paragraph, updated, style_source)
        for used in fields_used:
            reason = used.highlight_reason
            if reason == "none":
                continue
            highlighted += 1
            if reason == "missing":
                missing += 1
            elif reason == "llm_only":
                llm_only += 1
    return highlighted, missing, llm_only
def _iter_paragraphs(document):
    """Yield the body paragraphs of ``document``, then every paragraph in its tables.

    Tables are walked row by row and cell by cell, recursing into tables nested
    inside cells. python-docx reports a merged cell once for each grid position
    it spans. Cells are therefore de-duplicated by their underlying ``<w:tc>``
    element so each paragraph is yielded (and later rewritten) only once.
    """
    yield from document.paragraphs
    seen_cells: set = set()
    for table in document.tables:
        yield from _iter_table_paragraphs(table, seen_cells)


def _iter_table_paragraphs(table, seen_cells: set):
    """Yield paragraphs of ``table`` and its nested tables, skipping cells in ``seen_cells``."""
    for row in table.rows:
        for cell in row.cells:
            element = cell._tc
            # Keeping the element in the set also keeps its proxy object alive,
            # so later repeats of the same merged cell resolve to the same object.
            if element in seen_cells:
                continue
            seen_cells.add(element)
            yield from cell.paragraphs
            for nested in cell.tables:
                yield from _iter_table_paragraphs(nested, seen_cells)
def _apply_known_template_replacements(document, merged_fields: dict[str, MergedField], *, template_code: str = "") -> int:
    """Replace literal sample text left in the templates with real values.

    Date literals such as "xxxx年xx月xx日" (and the fixed sample dates) become
    today's local date, and "2023 年 10 月" becomes the current year and month.
    Unless ``template_code`` starts with ``"ch1_11"``, the sample product and
    applicant names are replaced with the ``product_name`` and
    ``applicant_name`` field values ("/" when the field is missing or empty).

    Returns:
        The number of paragraphs whose text changed.
    """
    product = _field_value(merged_fields, "product_name")
    applicant = _field_value(merged_fields, "applicant_name")
    # The format must end with "日": every date literal replaced below ends with
    # it, so a format without it would drop the day suffix from the output.
    today = timezone.localdate().strftime("%Y年%m月%d日")
    replacements = {
        "xxxx年xx月xx日": today,
        "XXXX年XX月XX日": today,
        "xxxx 年 xx 月 xx 日": today,
        "XXXX 年 XX 月 XX 日": today,
        "2023年09月20日": today,
        # The first 8 characters of ``today`` are "YYYY年MM月".
        "2023 年 10 月": today[:8],
    }
    if not template_code.startswith("ch1_11"):
        # Sample product/applicant text baked into the templates (presumably
        # copied from a reference submission); ch1_11* templates are excluded.
        replacements.update({
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒荧光PCR法": product,
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒": product,
            "呼吸道合胞病毒 、肺炎支产品名称: 原体核酸检测试剂盒(荧": f"产品名称:{product}",
            "光PCR法": "",
            "卡尤迪生物科技宜兴有限公司": applicant,
        })
    return sum(
        _replace_text_in_paragraph(paragraph, replacements, merged_fields)
        for paragraph in _iter_paragraphs(document)
    )
def _default_placeholder_field(key: str, merged_fields: dict[str, MergedField]) -> MergedField:
    """Build the fallback field for a placeholder with no merged value.

    ``declaration_date`` becomes a rule field (via ``_plain_field``) holding
    today's local date as "YYYY年MM月DD日". Any other key becomes a "missing"
    field with value "/" that is flagged for review. Its label is taken from the
    merged field whose ``.key`` equals ``key``, or is ``key`` itself when no
    such field exists.
    """
    if key == "declaration_date":
        # Include "日" so the value is a complete written date.
        return _plain_field(key, "日期", timezone.localdate().strftime("%Y年%m月%d日"))
    label = next((field.label for field in merged_fields.values() if field.key == key), key)
    return MergedField(
        key=key,
        label=label,
        value="/",
        source="missing",
        evidence="模板字段未从说明书中抽取到",
        confidence=0.0,
        highlight_reason="missing",
        needs_review=True,
    )
def _replace_text_in_paragraph(paragraph, replacements: dict[str, str], merged_fields: dict[str, MergedField]) -> int:
    """Replace literal substrings of one paragraph in a single left-to-right pass.

    All keys of ``replacements`` are matched at once. The leftmost match wins,
    and the longest key wins at a given position. Text that was just
    substituted is never scanned again. Sequential ``str.replace`` calls would
    instead let a later rule (e.g. "光PCR法" -> "") rewrite part of a value
    inserted by an earlier rule, such as a product name ending in "(荧光PCR法)".
    Empty keys are ignored.

    The rewritten paragraph is styled from the ``product_name`` field when it
    exists, otherwise from a plain rule-sourced field.

    Returns:
        1 if the paragraph text changed, else 0.
    """
    text = paragraph.text
    present = [old for old in replacements if old and old in text]
    if not present:
        return 0
    # Longest first so a key beats any shorter key that is its prefix.
    present.sort(key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(old) for old in present))
    new_text = pattern.sub(lambda match: replacements[match.group(0)], text)
    if new_text == text:
        return 0
    field = merged_fields.get("product_name") or MergedField(
        key="product_name",
        label="产品名称",
        value=new_text,
        source="rule",
        evidence="",
        confidence=0.0,
    )
    _replace_paragraph_text(paragraph, new_text, field)
    return 1
def _rebuild_product_list_table(document, merged_fields: dict[str, MergedField]) -> None:
    """Fill the ch1_5 product-list template: intro text, notes and component tables.

    * The paragraph introducing the packaging/component table is rewritten to
      name ``product_name``. The "规格A和规格B的区别" paragraph is replaced by
      ``component_notes`` when that field has a value.
    * The table whose header starts with
      ["包装规格", "货号", "组成", "组分", "主要组成成分", "规格/数量"] is cleared
      and refilled. It uses the JSON ``component_table`` field when present.
      Otherwise it gets up to 8 rows built from ``package_specification``
      and the single-value component fields.
    * When ``component_table`` is present, the "组分名称" comparison table is
      rebuilt as well.
    """
    product = _field_value(merged_fields, "product_name")
    package_specification = _field_value(merged_fields, "package_specification")
    component_table = _component_table_payload(merged_fields)
    component_notes = _field_value(merged_fields, "component_notes")
    for paragraph in document.paragraphs:
        if "的包装规格、货号、组分及主要组成成分见下表" in paragraph.text:
            _replace_paragraph_text(
                paragraph,
                f"{product}的包装规格、货号、组分及主要组成成分见下表:",
                merged_fields.get("product_name") or _plain_field("product_name", "产品名称", product),
            )
        if "规格A和规格B的区别" in paragraph.text and component_notes != "/":
            _replace_paragraph_text(
                paragraph,
                component_notes,
                merged_fields.get("component_notes") or _plain_field("component_notes", "主要组成成分备注", component_notes),
            )
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header[:6] == ["包装规格", "货号", "组成", "组分", "主要组成成分", "规格/数量"]:
            target = table
            break
    specs: list[tuple[str, int | None]] = _component_specs(component_table)
    if not specs:
        # Specifications may be separated by the full-width "；" or ASCII ";".
        # Normalise to ";" and split; specs from free text have no column index.
        specs = [
            (item.strip(), None)
            for item in package_specification.replace("；", ";").split(";")
            if item.strip()
        ]
    if target is not None:
        _clear_table_body(target)
        if component_table:
            _fill_product_component_table(target, component_table, specs)
        else:
            if not specs:
                specs = [("/", None)]
            # These values are the same for every row; look them up once.
            composition = _field_value(merged_fields, "composition")
            component_name = _field_value(merged_fields, "component_name")
            main_component = _field_value(merged_fields, "main_component")
            quantity = _field_value(merged_fields, "quantity")
            for spec, _index in specs[:8]:
                cells = target.add_row().cells
                cells[0].text = spec
                cells[1].text = "/"
                cells[2].text = composition
                cells[3].text = component_name
                cells[4].text = main_component
                cells[5].text = quantity
    if component_table:
        _rebuild_component_comparison_table(document, component_table, specs)
def _field_value(merged_fields: dict[str, MergedField], key: str) -> str:
    """Return ``merged_fields[key].value``, or "/" when the field is absent or its value is empty."""
    field = merged_fields.get(key)
    return field.value if field and field.value else "/"
def _plain_field(key: str, label: str, value: str) -> MergedField:
    """Build a rule-sourced ``MergedField`` with empty evidence and zero confidence."""
    attributes = {
        "key": key,
        "label": label,
        "value": value,
        "source": "rule",
        "evidence": "",
        "confidence": 0.0,
    }
    return MergedField(**attributes)
def _component_table_payload(merged_fields: dict[str, MergedField]) -> dict:
    """Decode the JSON ``component_table`` field into ``{"header": [...], "rows": [...]}``.

    Returns ``{}`` in any of these cases:

    * the field is absent, empty or "/";
    * the value is not valid JSON or not a JSON object;
    * its "header" or "rows" entry is not a list.

    Missing or falsy "header"/"rows" entries become empty lists.
    """
    field = merged_fields.get("component_table")
    raw = field.value if field else None
    if not raw or raw == "/":
        return {}
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if not isinstance(payload, dict):
        return {}
    header = payload.get("header") or []
    rows = payload.get("rows") or []
    if isinstance(header, list) and isinstance(rows, list):
        return {"header": header, "rows": rows}
    return {}
def _component_specs(component_table: dict) -> list[tuple[str, int]]:
    """List the per-specification columns of a component table.

    Header columns from index 2 onward are treated as specifications; columns
    0 and 1 are skipped. Blank or None headers are ignored. A leading
    "规格(" / "规格（" and trailing ASCII or full-width closing parentheses are
    stripped from each label, so "规格（A）" and "规格(A)" both yield "A".

    Returns:
        ``(label, column_index)`` pairs in header order.
    """
    header = component_table.get("header") or []
    specs: list[tuple[str, int]] = []
    for index, value in enumerate(header[2:], start=2):
        label = str(value or "").strip()
        if not label:
            continue
        # Templates mix full-width (Chinese) and ASCII parentheses.
        label = label.replace("规格（", "").replace("规格(", "").rstrip("）)")
        specs.append((label, index))
    return specs
def _clear_table_body(table) -> None:
    """Delete every row of ``table`` except the first (header) row."""
    body_rows = list(table.rows)[1:]
    # Remove from the bottom up, matching the original deletion order.
    for row in reversed(body_rows):
        table._tbl.remove(row._tr)
def _fill_product_component_table(table, component_table: dict, specs: list[tuple[str, int | None]]) -> None:
    """Append one body row per (specification, component row) pair.

    Column layout of each appended row:

    * 0 = spec label;
    * 1 and 2 = "/";
    * 3 = ``row[0]``;
    * 4 = ``row[1]``;
    * 5 = ``row[spec_index]`` (the "规格/数量" column).

    When a spec has no column in the component table (``spec_index`` is None,
    as for specs parsed from the package-specification text), column 5 is "/".
    It no longer falls back to ``row[0]``, which would repeat the component name.
    """
    rows = component_table.get("rows") or []
    for spec_label, spec_index in specs:
        for row in rows:
            cells = table.add_row().cells
            cells[0].text = spec_label
            cells[1].text = "/"
            cells[2].text = "/"
            cells[3].text = _row_value(row, 0)
            cells[4].text = _row_value(row, 1)
            cells[5].text = "/" if spec_index is None else _row_value(row, spec_index)
def _rebuild_component_comparison_table(document, component_table: dict, specs: list[tuple[str, int | None]]) -> None:
    """Rebuild the component comparison table (header cell 0 == "组分名称").

    The first such table is used; nothing happens if none exists. The table is
    rebuilt as follows:

    * its body rows are removed;
    * the header is rewritten as "组分名称" followed by as many spec labels as
      fit, with any remaining header cells set to "备注";
    * one row is appended per component row, holding the component name, then
      the value for each shown spec, then "/" in any leftover column.
    """
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header and header[0] == "组分名称":
            target = table
            break
    if target is None:
        return
    _clear_table_body(target)
    header_cells = target.rows[0].cells
    labels = ["组分名称", *[spec for spec, _index in specs[: len(header_cells) - 1]]]
    while len(labels) < len(header_cells):
        labels.append("备注")
    for index, label in enumerate(labels[: len(header_cells)]):
        header_cells[index].text = label
    for row in component_table.get("rows") or []:
        cells = target.add_row().cells
        cells[0].text = _row_value(row, 0)
        shown_specs = specs[: len(cells) - 1]
        for cell_index, (_spec_label, spec_index) in enumerate(shown_specs, start=1):
            # Specs parsed from free-text package specifications carry no column
            # index (None); write "/" for them instead of using None as an index.
            cells[cell_index].text = "/" if spec_index is None else _row_value(row, spec_index)
        for cell_index in range(len(shown_specs) + 1, len(cells)):
            cells[cell_index].text = "/"
def _row_value(row, index: int | None) -> str:
    """Return ``row[index]`` as stripped text, or "/" when there is no usable value.

    "/" is returned in these cases:

    * ``row`` is not a list;
    * ``index`` is None, negative or out of range;
    * the cell is None or blank after stripping.

    Other falsy cell values, such as a numeric 0 from JSON, are rendered as
    text ("0") rather than treated as blank.
    """
    if not isinstance(row, list) or index is None or not 0 <= index < len(row):
        return "/"
    raw = row[index]
    text = "" if raw is None else str(raw).strip()
    return text or "/"