323 lines
12 KiB
Python
323 lines
12 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from docx import Document
|
||
from docx.enum.text import WD_COLOR_INDEX
|
||
from docx.shared import RGBColor
|
||
from django.utils import timezone
|
||
|
||
from review_agent.regulatory_info_package.schemas import MergedField
|
||
|
||
|
||
# Matches a ``{{key}}`` template placeholder; group 1 captures the key, which may
# only contain ASCII letters, digits and underscores.
PLACEHOLDER_RE = re.compile(r"\{\{([a-zA-Z0-9_]+)\}\}")
|
||
|
||
|
||
def write_docx_from_template(
    source_path: str | Path,
    output_path: str | Path,
    merged_fields: dict[str, MergedField],
    *,
    template_code: str = "",
    directory_page_numbers: dict[str, str] | None = None,
) -> tuple[int, int, int]:
    """Render a .docx from a template and save it to *output_path*.

    The template at *source_path* is loaded when it exists; otherwise a blank
    document is used. The parent directory of *output_path* is created if
    needed. Processing order:

    1. sample text hard-coded in the templates is replaced with merged values;
    2. template-specific rebuilding for ``ch1_5_product_list`` (product and
       component tables) or ``ch1_2_directory`` (page numbers);
    3. every ``{{key}}`` placeholder is filled from *merged_fields*.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)``. ``highlight_count``
        is the number of paragraphs changed in step 1 plus the number of
        highlighted placeholders in step 3; the other two count placeholders
        whose ``highlight_reason`` is ``"missing"`` / ``"llm_only"``.
    """
    template_file = Path(source_path)
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    document = Document(template_file) if template_file.exists() else Document()
    # Lookup table keyed by the literal placeholder text, e.g. "{{product_name}}".
    placeholder_map = {f"{{{{{key}}}}}": field for key, field in merged_fields.items()}

    known_text_hits = _apply_known_template_replacements(document, merged_fields, template_code=template_code)
    if template_code == "ch1_5_product_list":
        _rebuild_product_list_table(document, merged_fields)
    elif template_code == "ch1_2_directory":
        _apply_directory_page_numbers(document, directory_page_numbers or {})

    placeholder_hits, missing_count, llm_only_count = _replace_placeholders(document, placeholder_map, merged_fields)

    document.save(destination)
    return known_text_hits + placeholder_hits, missing_count, llm_only_count
|
||
|
||
|
||
def _replace_paragraph_text(paragraph, text: str, field: MergedField) -> None:
    """Replace the visible text of *paragraph* with *text*, marked according to *field*.

    Existing runs are blanked (not removed), and *text* goes into a single new
    run appended at the end. That run is created without copying the old runs'
    character formatting. It gets a yellow highlight unless
    ``field.highlight_reason`` is ``"none"``, plus a red font colour when the
    reason is ``"conflict"``.
    """
    for old_run in paragraph.runs:
        old_run.text = ""
    fresh_run = paragraph.add_run(text)

    reason = field.highlight_reason
    if reason == "none":
        return
    fresh_run.font.highlight_color = WD_COLOR_INDEX.YELLOW
    if reason == "conflict":
        fresh_run.font.color.rgb = RGBColor(255, 0, 0)
|
||
|
||
|
||
def _apply_directory_page_numbers(document, page_numbers: dict[str, str]) -> None:
    """Write page numbers into the first "RPS目录" directory table of *document*.

    The target table is the first one whose header row has at least five cells,
    with "RPS目录" in column 0 and "页码" in column 4. For each body row whose
    stripped column-0 text is a key of *page_numbers*, column 4 is overwritten
    with the mapped value. Only the first matching table is updated.

    Body rows with fewer than five cells (e.g. rows using gridBefore/gridAfter)
    are skipped instead of raising ``IndexError`` and aborting the export.
    """
    if not page_numbers:
        return  # nothing to write
    for table in document.tables:
        if not table.rows:
            continue
        header = [cell.text.strip() for cell in table.rows[0].cells]
        if len(header) < 5 or header[0] != "RPS目录" or header[4] != "页码":
            continue
        for row in table.rows[1:]:
            cells = row.cells  # python-docx rebuilds this list on every access
            if len(cells) < 5:
                continue
            code = cells[0].text.strip()
            if code in page_numbers:
                cells[4].text = str(page_numbers[code])
        return
|
||
|
||
|
||
def _replace_placeholders(
    document,
    replacements: dict[str, MergedField],
    merged_fields: dict[str, MergedField],
) -> tuple[int, int, int]:
    """Fill every ``{{key}}`` placeholder found in the paragraphs of *document*.

    Each placeholder is looked up in *replacements*, which is keyed by the full
    ``{{key}}`` text. Keys without an entry fall back to
    ``_default_placeholder_field``. A paragraph that changes is rewritten as a
    single run, styled after its first highlighted field, or after its first
    field if none is highlighted.

    Returns:
        ``(highlighted, missing, llm_only)``, counted per substituted placeholder
        occurrence whose ``highlight_reason`` is not ``"none"`` / is
        ``"missing"`` / is ``"llm_only"``.
    """
    tally = {"highlight": 0, "missing": 0, "llm_only": 0}

    for paragraph in _iter_paragraphs(document):
        original = paragraph.text
        # Cheap pre-filter: a placeholder needs both delimiters.
        if "{{" not in original or "}}" not in original:
            continue

        consumed: list[MergedField] = []

        def resolve(match: re.Match[str]) -> str:
            found = replacements.get(match.group(0)) or _default_placeholder_field(match.group(1), merged_fields)
            consumed.append(found)
            return found.value

        rewritten = PLACEHOLDER_RE.sub(resolve, original)
        if rewritten == original:
            continue

        highlighted = [item for item in consumed if item.highlight_reason != "none"]
        style_source = highlighted[0] if highlighted else consumed[0]
        _replace_paragraph_text(paragraph, rewritten, style_source)

        for item in consumed:
            reason = item.highlight_reason
            if reason == "none":
                continue
            tally["highlight"] += 1
            if reason in ("missing", "llm_only"):
                tally[reason] += 1

    return tally["highlight"], tally["missing"], tally["llm_only"]
|
||
|
||
|
||
def _iter_paragraphs(document):
    """Yield the body paragraphs of *document*, then every paragraph inside its tables.

    python-docx's ``row.cells`` returns a merged cell once for every grid
    position it spans. Each underlying cell is therefore tracked and yielded
    only once. Tables nested inside cells are walked recursively, so
    placeholders placed in nested tables are reached as well.
    """
    yield from document.paragraphs

    # Holding the <w:tc> elements keeps their lxml proxies alive, so the same
    # underlying cell always maps to the same object in this set.
    seen_cells = set()

    def walk(tables):
        for table in tables:
            for row in table.rows:
                for cell in row.cells:
                    if cell._tc in seen_cells:
                        continue
                    seen_cells.add(cell._tc)
                    yield from cell.paragraphs
                    yield from walk(cell.tables)

    yield from walk(document.tables)
|
||
|
||
|
||
def _apply_known_template_replacements(document, merged_fields: dict[str, MergedField], *, template_code: str = "") -> int:
    """Replace sample text hard-coded in the templates with merged values.

    Dates:
        Placeholder dates (``xxxx年xx月xx日`` and variants) and a few sample dates
        are replaced by today's local date (``timezone.localdate()``).

    Product and applicant:
        Unless *template_code* starts with ``ch1_11``, the sample product name
        and sample applicant company are also replaced, using the merged
        ``product_name`` / ``applicant_name`` values (``"/"`` when missing).

    Scope:
        Body paragraphs and table-cell paragraphs are processed. ``row.cells``
        repeats a merged cell once per grid position it spans. Each underlying
        cell is processed only once, so replacements are not re-applied to
        already-substituted text and the count is not inflated.

    Returns:
        Number of paragraphs whose text changed.
    """
    product = _field_value(merged_fields, "product_name")
    applicant = _field_value(merged_fields, "applicant_name")
    today = timezone.localdate().strftime("%Y年%m月%d日")
    replacements = {
        "xxxx年xx月xx日": today,
        "XXXX年XX月XX日": today,
        "xxxx 年 xx 月 xx 日": today,
        "XXXX 年 XX 月 XX 日": today,
        "2023年09月20日": today,
        # First 8 characters of today's date = "YYYY年MM月".
        "2023 年 10 月": today[:8],
    }
    if not template_code.startswith("ch1_11"):
        replacements.update({
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒(荧光PCR法)": product,
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒": product,
            # The sample name is split across two paragraphs in one template:
            # the first part is rewritten, and the dangling tail below is removed.
            "呼吸道合胞病毒 、肺炎支产品名称: 原体核酸检测试剂盒(荧": f"产品名称:{product}",
            "光PCR法)": "",
            "卡尤迪生物科技宜兴有限公司": applicant,
        })

    changed = 0
    for paragraph in document.paragraphs:
        changed += _replace_text_in_paragraph(paragraph, replacements, merged_fields)

    # Keep references to visited <w:tc> elements so merged cells are handled once.
    seen_cells = set()
    for table in document.tables:
        for row in table.rows:
            for cell in row.cells:
                if cell._tc in seen_cells:
                    continue
                seen_cells.add(cell._tc)
                for paragraph in cell.paragraphs:
                    changed += _replace_text_in_paragraph(paragraph, replacements, merged_fields)
    return changed
|
||
|
||
|
||
def _default_placeholder_field(key: str, merged_fields: dict[str, MergedField]) -> MergedField:
    """Build the fallback field for a ``{{key}}`` placeholder with no merged entry.

    ``declaration_date`` becomes today's local date, wrapped with
    ``_plain_field`` as a rule-sourced field.

    Any other key yields a ``"/"`` value with ``source="missing"``,
    ``highlight_reason="missing"`` and ``needs_review=True``. Its label comes
    from the first merged field whose ``.key`` equals *key*, falling back to the
    key itself.
    """
    if key == "declaration_date":
        return _plain_field(key, "日期", timezone.localdate().strftime("%Y年%m月%d日"))

    label = next(
        (candidate.label for candidate in merged_fields.values() if candidate.key == key),
        key,
    )
    return MergedField(
        key=key,
        label=label,
        value="/",
        source="missing",
        evidence="模板字段未从说明书中抽取到",
        confidence=0.0,
        highlight_reason="missing",
        needs_review=True,
    )
|
||
|
||
|
||
def _replace_text_in_paragraph(paragraph, replacements: dict[str, str], merged_fields: dict[str, MergedField]) -> int:
    """Apply literal *replacements* to *paragraph* and rewrite it if anything changed.

    All keys are matched in a single left-to-right pass, preferring the longest
    key at each position, so text that was just substituted is never scanned
    again. The previous sequential ``str.replace`` loop re-scanned substituted
    values: a product name ending in "(荧光PCR法)" inserted for the sample name
    was then cut to "…(荧" by the later ``"光PCR法)" -> ""`` rule.

    The rewritten paragraph is styled with the merged ``product_name`` field, or
    an ad-hoc rule field when that is absent, regardless of which key matched.

    Returns:
        1 if the paragraph text changed, otherwise 0.
    """
    keys = sorted((old for old in replacements if old), key=len, reverse=True)
    if not keys:
        # An empty alternation would match the empty string everywhere.
        return 0

    text = paragraph.text
    pattern = re.compile("|".join(map(re.escape, keys)))
    new_text = pattern.sub(lambda match: replacements[match.group(0)], text)
    if new_text == text:
        return 0

    field = merged_fields.get("product_name") or MergedField(
        key="product_name",
        label="产品名称",
        value=new_text,
        source="rule",
        evidence="",
        confidence=0.0,
    )
    _replace_paragraph_text(paragraph, new_text, field)
    return 1
|
||
|
||
|
||
def _rebuild_product_list_table(document, merged_fields: dict[str, MergedField]) -> None:
    """Fill the ``ch1_5_product_list`` section of *document* in place.

    * The intro sentence containing "的包装规格、货号、组分及主要组成成分见下表"
      is rewritten with the merged product name.
    * The "规格A和规格B的区别" note is replaced by ``component_notes`` when a
      value other than "/" was merged.
    * The first table whose header starts with
      包装规格 / 货号 / 组成 / 组分 / 主要组成成分 / 规格/数量 has its body
      emptied and refilled. The data comes from the parsed ``component_table``
      when available. Otherwise up to 8 rows are built from the semicolon-separated
      ``package_specification``, each carrying the single-valued composition
      fields.
    * When a component table is available, the "组分名称" comparison table is
      rebuilt as well.
    """
    product = _field_value(merged_fields, "product_name")
    package_specification = _field_value(merged_fields, "package_specification")
    component_table = _component_table_payload(merged_fields)
    component_notes = _field_value(merged_fields, "component_notes")

    for paragraph in document.paragraphs:
        if "的包装规格、货号、组分及主要组成成分见下表" in paragraph.text:
            intro_field = merged_fields.get("product_name") or _plain_field("product_name", "产品名称", product)
            _replace_paragraph_text(paragraph, f"{product}的包装规格、货号、组分及主要组成成分见下表:", intro_field)
        if component_notes != "/" and "规格A和规格B的区别" in paragraph.text:
            notes_field = merged_fields.get("component_notes") or _plain_field(
                "component_notes", "主要组成成分备注", component_notes
            )
            _replace_paragraph_text(paragraph, component_notes, notes_field)

    expected_header = ["包装规格", "货号", "组成", "组分", "主要组成成分", "规格/数量"]
    target = next(
        (
            table
            for table in document.tables
            if table.rows and [cell.text.strip() for cell in table.rows[0].cells][:6] == expected_header
        ),
        None,
    )

    # Spec columns come from the component table header; otherwise fall back to
    # the package specification list (no source column -> index None).
    specs = _component_specs(component_table)
    if not specs:
        pieces = package_specification.replace(";", ";").split(";")
        specs = [(piece.strip(), None) for piece in pieces if piece.strip()]

    if target is not None:
        _clear_table_body(target)
        if component_table:
            _fill_product_component_table(target, component_table, specs)
        else:
            fallback_specs = specs[:8] if specs else [("/", None)]
            shared_columns = (
                "/",
                _field_value(merged_fields, "composition"),
                _field_value(merged_fields, "component_name"),
                _field_value(merged_fields, "main_component"),
                _field_value(merged_fields, "quantity"),
            )
            for spec, _index in fallback_specs:
                cells = target.add_row().cells
                cells[0].text = spec
                for column, value in enumerate(shared_columns, start=1):
                    cells[column].text = value

    if component_table:
        _rebuild_component_comparison_table(document, component_table, specs)
|
||
|
||
|
||
def _field_value(merged_fields: dict[str, MergedField], key: str) -> str:
    """Return the merged value for *key*, or ``"/"`` when the field is absent or empty."""
    field = merged_fields.get(key)
    value = field.value if field else None
    return value if value else "/"
|
||
|
||
|
||
def _plain_field(key: str, label: str, value: str) -> MergedField:
    """Wrap *value* in a rule-sourced MergedField with empty evidence and zero confidence."""
    return MergedField(
        source="rule",
        key=key,
        label=label,
        value=value,
        confidence=0.0,
        evidence="",
    )
|
||
|
||
|
||
def _component_table_payload(merged_fields: dict[str, MergedField]) -> dict:
    """Parse the ``component_table`` merged field into ``{"header": [...], "rows": [...]}``.

    The field value is expected to be a JSON object with list-valued ``header``
    and ``rows`` keys. ``{}`` is returned when:

    * the field is missing, empty or ``"/"``;
    * the value is not valid JSON (or not a string);
    * the JSON is not an object;
    * ``header`` or ``rows`` is not a list;
    * there are no rows.

    A table without rows carries no data. Treating it like a missing table lets
    ``_rebuild_product_list_table`` use its package-specification fallback
    instead of emitting an empty product table.
    """
    field = merged_fields.get("component_table")
    if not field or not field.value or field.value == "/":
        return {}
    try:
        payload = json.loads(field.value)
    except (json.JSONDecodeError, TypeError):
        return {}
    if not isinstance(payload, dict):
        return {}
    rows = payload.get("rows") or []
    header = payload.get("header") or []
    if not isinstance(header, list) or not isinstance(rows, list) or not rows:
        return {}
    return {"header": header, "rows": rows}
|
||
|
||
|
||
def _component_specs(component_table: dict) -> list[tuple[str, int]]:
    """List the spec columns of a parsed component table as ``(label, column_index)``.

    The first two header columns (component name and main component) are
    skipped, as are blank headers. A remaining header has a leading "规格("
    removed and trailing ")" characters stripped to form the label, and is
    paired with its column index.
    """
    columns = component_table.get("header") or []
    cleaned = ((position, str(cell or "").strip()) for position, cell in enumerate(columns[2:], start=2))
    return [
        (text.replace("规格(", "").replace("规格(", "").rstrip("))"), position)
        for position, text in cleaned
        if text
    ]
|
||
|
||
|
||
def _clear_table_body(table) -> None:
    """Delete every row of *table* except the first (header) row, in place."""
    body_rows = list(table.rows)[1:]
    # Remove bottom-up, matching the original last-row-first deletion order.
    for row in reversed(body_rows):
        table._tbl.remove(row._tr)
|
||
|
||
|
||
def _fill_product_component_table(table, component_table: dict, specs: list[tuple[str, int | None]]) -> None:
    """Append one product-list row per (spec, component) pair to *table*.

    Columns written: spec label, "/" for 货号, "/" for 组成, component name
    (row column 0), main component (row column 1), and the component's value in
    the spec's source column.

    A spec whose column index is ``None`` has no per-spec column. This happens
    when the specs came from the package-specification fallback. Its quantity is
    written as "/", rather than the previous ``spec_index or 0``, which copied
    the component *name* into the 规格/数量 column.
    """
    rows = component_table.get("rows") or []
    for spec_label, spec_index in specs:
        for row in rows:
            cells = table.add_row().cells
            cells[0].text = spec_label
            cells[1].text = "/"
            cells[2].text = "/"
            cells[3].text = _row_value(row, 0)
            cells[4].text = _row_value(row, 1)
            cells[5].text = "/" if spec_index is None else _row_value(row, spec_index)
|
||
|
||
|
||
def _rebuild_component_comparison_table(document, component_table: dict, specs: list[tuple[str, int | None]]) -> None:
    """Rewrite the first table whose header starts with "组分名称" from *component_table*.

    Does nothing if no such table exists. Otherwise:

    * the body rows are removed;
    * the header is relabelled "组分名称" followed by the spec labels, with any
      extra columns labelled "备注";
    * one row per component is appended: the component name, its value in each
      spec column, and "/" in unused columns.

    A spec whose column index is ``None`` (the package-specification fallback)
    is rendered as "/" instead of being passed to ``_row_value``, which raised
    ``TypeError`` on ``None``.
    """
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header and header[0] == "组分名称":
            target = table
            break
    if target is None:
        return

    _clear_table_body(target)

    header_cells = target.rows[0].cells
    labels = ["组分名称", *[spec for spec, _index in specs[: len(header_cells) - 1]]]
    labels.extend(["备注"] * (len(header_cells) - len(labels)))
    for index, label in enumerate(labels[: len(header_cells)]):
        header_cells[index].text = label

    for row in component_table.get("rows") or []:
        cells = target.add_row().cells
        cells[0].text = _row_value(row, 0)
        row_specs = specs[: len(cells) - 1]
        for cell_index, (_spec_label, spec_index) in enumerate(row_specs, start=1):
            cells[cell_index].text = "/" if spec_index is None else _row_value(row, spec_index)
        for cell_index in range(len(row_specs) + 1, len(cells)):
            cells[cell_index].text = "/"
|
||
|
||
|
||
def _row_value(row, index: int | None) -> str:
    """Return the stripped text of ``row[index]``, or ``"/"`` when it is unavailable.

    ``"/"`` is returned when:

    * *row* is not a list;
    * *index* is ``None``, negative, or past the end of the row;
    * the cell is falsy or blank after stripping.

    ``None`` is accepted because fallback specs carry no source column. Negative
    indexes are rejected so they cannot silently read from the end of the row.
    """
    if index is None or not isinstance(row, list) or not 0 <= index < len(row):
        return "/"
    value = str(row[index] or "").strip()
    return value or "/"
|