"""Fill Word (DOCX) templates with merged field values.

The entry point is :func:`write_docx_from_template`. It rewrites known
boilerplate strings, optionally rebuilds the product-list tables or fills the
directory page numbers (depending on ``template_code``), substitutes
``{{placeholder}}`` tokens, and saves the resulting document.
"""

from __future__ import annotations

import json
import re
from pathlib import Path

from django.utils import timezone
from docx import Document
from docx.enum.text import WD_COLOR_INDEX
from docx.shared import RGBColor

from review_agent.regulatory_info_package.schemas import MergedField

# Matches ``{{field_key}}`` tokens; group 1 is the bare key.
PLACEHOLDER_RE = re.compile(r"\{\{([a-zA-Z0-9_]+)\}\}")

# Package specifications may be separated by an ASCII or a full-width semicolon.
_SPEC_SEPARATOR_RE = re.compile(r"[;；]")


def write_docx_from_template(
    source_path: str | Path,
    output_path: str | Path,
    merged_fields: dict[str, MergedField],
    *,
    template_code: str = "",
    directory_page_numbers: dict[str, str] | None = None,
) -> tuple[int, int, int]:
    """Render ``source_path`` with ``merged_fields`` and save it to ``output_path``.

    Args:
        source_path: Template document. When the path does not exist a blank
            document is used instead.
        output_path: Destination file; missing parent directories are created.
        merged_fields: Field objects keyed by field key. Each key ``k`` is
            substituted for the ``{{k}}`` placeholder.
        template_code: Selects template-specific processing:
            ``"ch1_5_product_list"`` rebuilds the product list tables and
            ``"ch1_2_directory"`` fills the directory page numbers. Codes
            starting with ``"ch1_11"`` skip the sample product/applicant
            name replacements.
        directory_page_numbers: Page number text keyed by the directory code
            found in the first column of the directory table.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)``.
        ``highlight_count`` is the number of paragraphs changed by the known
        boilerplate replacements plus the number of placeholder fields whose
        ``highlight_reason`` is not ``"none"``. The other two count the
        placeholder fields with reason ``"missing"`` and ``"llm_only"``.
    """
    source = Path(source_path)
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    # python-docx documents ``str`` paths (or file-like objects), so convert
    # explicitly instead of relying on Path objects being accepted.
    document = Document(str(source)) if source.exists() else Document()

    # "{{key}}" -> field
    replacements = {f"{{{{{key}}}}}": field for key, field in merged_fields.items()}

    highlight_count = _apply_known_template_replacements(
        document, merged_fields, template_code=template_code
    )
    if template_code == "ch1_5_product_list":
        _rebuild_product_list_table(document, merged_fields)
    if template_code == "ch1_2_directory":
        _apply_directory_page_numbers(document, directory_page_numbers or {})

    placeholder_highlights, missing_count, llm_only_count = _replace_placeholders(
        document, replacements, merged_fields
    )
    highlight_count += placeholder_highlights

    document.save(str(output))
    return highlight_count, missing_count, llm_only_count


def _replace_paragraph_text(paragraph, text: str, field: MergedField) -> None:
    """Replace the whole text of ``paragraph`` with ``text``.

    Every existing run is emptied and a single new run carrying ``text`` is
    appended. The new run is highlighted yellow unless
    ``field.highlight_reason`` is ``"none"``, and additionally coloured red
    when the reason is ``"conflict"``.
    """
    for existing_run in paragraph.runs:
        existing_run.text = ""
    run = paragraph.add_run(text)
    if field.highlight_reason != "none":
        run.font.highlight_color = WD_COLOR_INDEX.YELLOW
    if field.highlight_reason == "conflict":
        run.font.color.rgb = RGBColor(255, 0, 0)


def _apply_directory_page_numbers(document, page_numbers: dict[str, str]) -> None:
    """Write page numbers into the directory table.

    The directory table is the first table whose header row has at least five
    cells, with ``"RPS目录"`` in the first and ``"页码"`` in the fifth. For each
    body row whose first cell matches a key of ``page_numbers``, the fifth
    cell is overwritten. Only the first matching table is processed.
    """
    for table in document.tables:
        rows = table.rows
        if not rows:
            continue
        header = [cell.text.strip() for cell in rows[0].cells]
        if len(header) < 5 or header[0] != "RPS目录" or header[4] != "页码":
            continue
        for row in rows[1:]:
            cells = row.cells
            # A body row may expose fewer cells than the header; skip it
            # rather than fail with IndexError.
            if len(cells) < 5:
                continue
            code = cells[0].text.strip()
            if code in page_numbers:
                cells[4].text = page_numbers[code]
        return


def _replace_placeholders(
    document,
    replacements: dict[str, MergedField],
    merged_fields: dict[str, MergedField],
) -> tuple[int, int, int]:
    """Substitute ``{{key}}`` placeholders in body and table paragraphs.

    Placeholders without an entry in ``replacements`` fall back to
    :func:`_default_placeholder_field`. A rewritten paragraph is styled after
    the first used field that needs highlighting, or after the first used
    field when none does.

    Returns:
        ``(highlight_count, missing_count, llm_only_count)`` counted per
        substituted placeholder in paragraphs that were rewritten.
    """
    highlight_count = 0
    missing_count = 0
    llm_only_count = 0

    for paragraph in _iter_paragraphs(document):
        text = paragraph.text
        if "{{" not in text or "}}" not in text:
            continue

        used_fields: list[MergedField] = []

        def replace(match: re.Match[str]) -> str:
            key = match.group(1)
            placeholder = match.group(0)
            field = replacements.get(placeholder) or _default_placeholder_field(key, merged_fields)
            used_fields.append(field)
            return field.value

        new_text = PLACEHOLDER_RE.sub(replace, text)
        if new_text == text:
            # Nothing matched, or the substitution reproduced the same text.
            continue

        # new_text differs from text, so at least one placeholder matched and
        # used_fields is non-empty here.
        highlighted = next(
            (field for field in used_fields if field.highlight_reason != "none"),
            None,
        )
        _replace_paragraph_text(paragraph, new_text, highlighted or used_fields[0])

        for field in used_fields:
            reason = field.highlight_reason
            if reason != "none":
                highlight_count += 1
            if reason == "missing":
                missing_count += 1
            if reason == "llm_only":
                llm_only_count += 1

    return highlight_count, missing_count, llm_only_count


def _iter_paragraphs(document):
    """Yield the body paragraphs, then every paragraph of every table cell."""
    yield from document.paragraphs
    for table in document.tables:
        for row in table.rows:
            for cell in row.cells:
                yield from cell.paragraphs


def _apply_known_template_replacements(
    document,
    merged_fields: dict[str, MergedField],
    *,
    template_code: str = "",
) -> int:
    """Replace known boilerplate strings of the templates.

    Date placeholders are replaced by today's local date. Unless
    ``template_code`` starts with ``"ch1_11"``, the sample product name and
    applicant name baked into the templates are replaced by the
    ``product_name`` and ``applicant_name`` field values.

    Returns:
        The number of paragraphs whose text was changed.
    """
    product = _field_value(merged_fields, "product_name")
    applicant = _field_value(merged_fields, "applicant_name")
    today = timezone.localdate().strftime("%Y年%m月%d日")
    replacements = {
        "xxxx年xx月xx日": today,
        "XXXX年XX月XX日": today,
        "xxxx 年 xx 月 xx 日": today,
        "XXXX 年 XX 月 XX 日": today,
        "2023年09月20日": today,
        # First eight characters of the date string: the "YYYY年MM月" part.
        "2023 年 10 月": today[:8],
    }
    if not template_code.startswith("ch1_11"):
        # Insertion order matters: the longer product name must be tried
        # before its shorter prefix.
        replacements.update({
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒(荧光PCR法)": product,
            "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒": product,
            "呼吸道合胞病毒 、肺炎支产品名称: 原体核酸检测试剂盒(荧": f"产品名称:{product}",
            "光PCR法)": "",
            "卡尤迪生物科技宜兴有限公司": applicant,
        })

    return sum(
        _replace_text_in_paragraph(paragraph, replacements, merged_fields)
        for paragraph in _iter_paragraphs(document)
    )


def _default_placeholder_field(key: str, merged_fields: dict[str, MergedField]) -> MergedField:
    """Build the field used for a placeholder that has no direct replacement.

    ``declaration_date`` resolves to today's local date without highlighting.
    Any other key yields a ``"missing"`` field with value ``"/"`` that is
    flagged for review; its label is taken from a merged field with the same
    ``key`` attribute when one exists, otherwise the key itself.
    """
    if key == "declaration_date":
        return _plain_field(key, "日期", timezone.localdate().strftime("%Y年%m月%d日"))

    label = next(
        (field.label for field in merged_fields.values() if field.key == key),
        key,
    )
    return MergedField(
        key=key,
        label=label,
        value="/",
        source="missing",
        evidence="模板字段未从说明书中抽取到",
        confidence=0.0,
        highlight_reason="missing",
        needs_review=True,
    )


def _replace_text_in_paragraph(
    paragraph,
    replacements: dict[str, str],
    merged_fields: dict[str, MergedField],
) -> int:
    """Apply literal ``replacements`` to one paragraph.

    Replacements are applied in dictionary order. A changed paragraph is
    rewritten and styled after the ``product_name`` field (or an
    unhighlighted fallback field when that field is absent).

    Returns:
        ``1`` when the paragraph text changed, otherwise ``0``.
    """
    text = paragraph.text
    new_text = text
    for old, new in replacements.items():
        if old in new_text:
            new_text = new_text.replace(old, new)
    if new_text == text:
        return 0

    field = merged_fields.get("product_name") or _plain_field("product_name", "产品名称", new_text)
    _replace_paragraph_text(paragraph, new_text, field)
    return 1


def _rebuild_product_list_table(document, merged_fields: dict[str, MergedField]) -> None:
    """Rebuild the product list section of the ``ch1_5_product_list`` template.

    Rewrites the introductory sentence and, when component notes are
    available, the specification-difference note. It then refills the table
    whose header starts with 包装规格/货号/组成/组分/主要组成成分/规格/数量. When
    a component table payload exists, the component comparison table is
    rebuilt as well.
    """
    product = _field_value(merged_fields, "product_name")
    package_specification = _field_value(merged_fields, "package_specification")
    component_table = _component_table_payload(merged_fields)
    component_notes = _field_value(merged_fields, "component_notes")

    for paragraph in document.paragraphs:
        if "的包装规格、货号、组分及主要组成成分见下表" in paragraph.text:
            _replace_paragraph_text(
                paragraph,
                f"{product}的包装规格、货号、组分及主要组成成分见下表:",
                merged_fields.get("product_name") or _plain_field("product_name", "产品名称", product),
            )
        if "规格A和规格B的区别" in paragraph.text and component_notes != "/":
            _replace_paragraph_text(
                paragraph,
                component_notes,
                merged_fields.get("component_notes")
                or _plain_field("component_notes", "主要组成成分备注", component_notes),
            )

    expected_header = ["包装规格", "货号", "组成", "组分", "主要组成成分", "规格/数量"]
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header[:6] == expected_header:
            target = table
            break

    # Prefer the spec columns of the component table. Otherwise fall back to
    # the package specification text, split on ASCII or full-width
    # semicolons; those fallback specs have no column index (None).
    specs: list[tuple[str, int | None]] = _component_specs(component_table) or [
        (spec, None)
        for spec in (item.strip() for item in _SPEC_SEPARATOR_RE.split(package_specification))
        if spec
    ]

    if target is not None:
        _clear_table_body(target)
        if component_table:
            _fill_product_component_table(target, component_table, specs)
        else:
            if not specs:
                specs = [("/", None)]
            composition = _field_value(merged_fields, "composition")
            component_name = _field_value(merged_fields, "component_name")
            main_component = _field_value(merged_fields, "main_component")
            quantity = _field_value(merged_fields, "quantity")
            # At most eight specification rows are written.
            for spec, _index in specs[:8]:
                cells = target.add_row().cells
                cells[0].text = spec
                cells[1].text = "/"
                cells[2].text = composition
                cells[3].text = component_name
                cells[4].text = main_component
                cells[5].text = quantity

    if component_table:
        _rebuild_component_comparison_table(document, component_table, specs)


def _field_value(merged_fields: dict[str, MergedField], key: str) -> str:
    """Return the value of field ``key``, or ``"/"`` when absent or empty."""
    field = merged_fields.get(key)
    if not field or not field.value:
        return "/"
    return field.value


def _plain_field(key: str, label: str, value: str) -> MergedField:
    """Build a rule-sourced field with no evidence and zero confidence."""
    return MergedField(key=key, label=label, value=value, source="rule", evidence="", confidence=0.0)


def _component_table_payload(merged_fields: dict[str, MergedField]) -> dict:
    """Decode the ``component_table`` field into ``{"header": [...], "rows": [...]}``.

    Returns an empty dict when the field is absent, empty, ``"/"``, not valid
    JSON, not a JSON object, or when ``header``/``rows`` are not lists.
    """
    field = merged_fields.get("component_table")
    if not field or not field.value or field.value == "/":
        return {}
    try:
        payload = json.loads(field.value)
    except json.JSONDecodeError:
        return {}
    if not isinstance(payload, dict):
        return {}
    rows = payload.get("rows") or []
    header = payload.get("header") or []
    if not isinstance(header, list) or not isinstance(rows, list):
        return {}
    return {"header": header, "rows": rows}


def _component_specs(component_table: dict) -> list[tuple[str, int]]:
    """Extract ``(spec_label, column_index)`` pairs from the table header.

    Header columns from index 2 onwards are treated as specification
    columns. A ``规格(`` prefix and trailing closing parentheses are stripped
    from each label, in both ASCII and full-width forms. Blank header cells
    are skipped.
    """
    header = component_table.get("header") or []
    specs: list[tuple[str, int]] = []
    for index, value in enumerate(header[2:], start=2):
        label = str(value or "").strip()
        if not label:
            continue
        label = label.replace("规格(", "").replace("规格（", "").rstrip(")）")
        specs.append((label, index))
    return specs


def _clear_table_body(table) -> None:
    """Remove every row of ``table`` except the first (header) row."""
    # Snapshot the rows first so removal does not disturb the iteration.
    for row in list(table.rows)[1:]:
        table._tbl.remove(row._tr)


def _fill_product_component_table(
    table,
    component_table: dict,
    specs: list[tuple[str, int | None]],
) -> None:
    """Append one row per (specification, component row) combination.

    Columns: spec label, ``"/"``, ``"/"``, row[0], row[1], and the row value
    at the spec's column index. A spec without a column index yields ``"/"``
    in the last column.
    """
    rows = component_table.get("rows") or []
    for spec_label, spec_index in specs:
        for row in rows:
            cells = table.add_row().cells
            cells[0].text = spec_label
            cells[1].text = "/"
            cells[2].text = "/"
            cells[3].text = _row_value(row, 0)
            cells[4].text = _row_value(row, 1)
            cells[5].text = _row_value(row, spec_index)


def _rebuild_component_comparison_table(
    document,
    component_table: dict,
    specs: list[tuple[str, int | None]],
) -> None:
    """Rebuild the first table whose header starts with ``"组分名称"``.

    The header becomes ``组分名称`` followed by the spec labels, padded with
    ``备注`` up to the existing column count. Each component row gets its name
    followed by the value for every spec; leftover columns are filled with
    ``"/"``. Does nothing when no such table exists.
    """
    target = None
    for table in document.tables:
        header = [cell.text.strip() for cell in table.rows[0].cells] if table.rows else []
        if header and header[0] == "组分名称":
            target = table
            break
    if target is None:
        return

    _clear_table_body(target)

    header_cells = target.rows[0].cells
    column_count = len(header_cells)
    labels = ["组分名称", *[spec for spec, _index in specs[: column_count - 1]]]
    labels.extend(["备注"] * (column_count - len(labels)))
    for index, label in enumerate(labels[:column_count]):
        header_cells[index].text = label

    for row in component_table.get("rows") or []:
        cells = target.add_row().cells
        cells[0].text = _row_value(row, 0)
        visible_specs = specs[: len(cells) - 1]
        for cell_index, (_spec_label, spec_index) in enumerate(visible_specs, start=1):
            cells[cell_index].text = _row_value(row, spec_index)
        for cell_index in range(len(visible_specs) + 1, len(cells)):
            cells[cell_index].text = "/"


def _row_value(row, index: int | None) -> str:
    """Return the stripped text of ``row[index]``, or ``"/"`` when unavailable.

    ``"/"`` is returned when ``row`` is not a list, ``index`` is ``None`` (a
    spec with no source column) or past the end of the row, or the cell is
    empty or falsy.
    """
    if not isinstance(row, list) or index is None or index >= len(row):
        return "/"
    value = str(row[index] or "").strip()
    return value or "/"