diff --git a/review_agent/application_form_fill/services/template_repository.py b/review_agent/application_form_fill/services/template_repository.py
new file mode 100644
index 0000000..0b9f691
--- /dev/null
+++ b/review_agent/application_form_fill/services/template_repository.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import shutil
+from pathlib import Path
+from typing import Any
+
+from django.conf import settings
+
+from review_agent.application_form_fill.schemas import TemplateSpec
+from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
+from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch
+
+
+class TemplateUnavailableError(Exception):
+    """Raised when a template's source file is missing or still in .doc format."""
+
+
+def resolve_source_template(spec: TemplateSpec, config: dict[str, Any]) -> Path:
+    """Return the path of the template file to copy for ``spec``.
+
+    The file is looked up under ``settings.BASE_DIR / config["source_dir"]``.
+    For ``doc`` templates the optional ``working_template`` attribute of the
+    spec is preferred over ``source_file``.
+
+    Raises:
+        TemplateUnavailableError: if the resolved path is not an existing file,
+            or if a ``doc`` template still resolves to a ``.doc`` file.
+    """
+    source_dir = Path(settings.BASE_DIR) / str(config.get("source_dir") or "")
+    working_template = getattr(spec, "working_template", "") or ""
+    if spec.file_format == "doc" and working_template:
+        relative_name = working_template
+    else:
+        relative_name = spec.source_file
+    candidate = source_dir / relative_name
+    # is_file() rather than exists(): an empty file name resolves to the source
+    # directory itself, which exists but can never be copied as a template.
+    if not candidate.is_file():
+        # Report the name that was actually probed, which may be the working copy.
+        raise TemplateUnavailableError(f"模板文件不存在:{relative_name}")
+    if spec.file_format == "doc" and candidate.suffix.lower() == ".doc":
+        raise TemplateUnavailableError(f"模板 {spec.code} 为 .doc,当前阶段需预转换为 .docx 后使用。")
+    return candidate
+
+
+def copy_template_to_batch(
+    spec: TemplateSpec,
+    batch: ApplicationFormFillBatch,
+    config: dict[str, Any],
+) -> ApplicationFormFillArtifact:
+    """Copy the template for ``spec`` into the batch's ``templates`` sub-directory.
+
+    Returns the artifact produced by ``create_artifact_for_file`` for the copy.
+
+    Raises:
+        TemplateUnavailableError: propagated from ``resolve_source_template``.
+        ValueError: if the copy target would fall outside ``batch.work_dir``.
+    """
+    source = resolve_source_template(spec, config)
+    target_dir = ensure_batch_subdir(batch, "templates")
+    target = target_dir / f"{spec.code}.source{source.suffix.lower()}"
+    # Validate before writing so a target outside the work dir is never created.
+    _ensure_under(target, Path(batch.work_dir))
+    shutil.copy2(source, target)
+    return create_artifact_for_file(
+        batch,
+        path=target,
+        artifact_type=ApplicationFormFillArtifact.ArtifactType.TEMPLATE_COPY,
+        file_format=source.suffix.lower().lstrip(".") or spec.file_format,
+        name=spec.name,
+        metadata={"template_code": spec.code, "source_file": spec.source_file},
+        created_by_node="template_copy",
+    )
+
+
+def _ensure_under(path: Path, root: Path) -> None:
+    """Raise ``ValueError`` unless ``path`` resolves to ``root`` or a descendant of it."""
+    resolved_path = path.resolve()
+    resolved_root = root.resolve()
+    if resolved_path != resolved_root and resolved_root not in resolved_path.parents:
+        raise ValueError(f"模板复制目标不在批次工作目录内:{path}")
diff --git a/review_agent/application_form_fill/services/template_select.py b/review_agent/application_form_fill/services/template_select.py
new file mode 100644
index 0000000..11c770d
--- /dev/null
+++ b/review_agent/application_form_fill/services/template_select.py
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+from typing import Any
+
+from review_agent.application_form_fill.constants import (
+    TEMPLATE_CHANGE_REGISTRATION,
+    TEMPLATE_ESSENTIAL_PRINCIPLES,
+    TEMPLATE_REGISTRATION_CERTIFICATE,
+)
+from review_agent.application_form_fill.schemas import TemplateSpec
+from review_agent.models import ApplicationFormFillBatch
+
+
+ALL_TEMPLATE_CODES = [
+    TEMPLATE_REGISTRATION_CERTIFICATE,
+    TEMPLATE_CHANGE_REGISTRATION,
+    TEMPLATE_ESSENTIAL_PRINCIPLES,
+]
+
+
+def parse_requested_templates(message: str) -> list[str]:
+    """Return the template codes explicitly requested in ``message``.
+
+    An "all templates" keyword selects a copy of ``ALL_TEMPLATE_CODES``;
+    otherwise each template is matched by its own keywords. The result is
+    empty when no keyword matches.
+    """
+    normalized = (message or "").lower()
+    if any(keyword in normalized for keyword in ["全部模板", "所有模板", "全套模板", "全部表格", "所有表格"]):
+        return ALL_TEMPLATE_CODES.copy()
+
+    requested: list[str] = []
+    # The certificate keyword is ignored when change registration is also mentioned.
+    if "注册证" in normalized and "变更注册" not in normalized and "变更 注册" not in normalized:
+        requested.append(TEMPLATE_REGISTRATION_CERTIFICATE)
+    if any(keyword in normalized for keyword in ["变更注册", "变更 注册", "变更备案", "备案文件"]):
+        requested.append(TEMPLATE_CHANGE_REGISTRATION)
+    if any(keyword in normalized for keyword in ["安全和性能基本原则", "基本原则清单", "原则清单"]):
+        requested.append(TEMPLATE_ESSENTIAL_PRINCIPLES)
+    return _dedupe(requested)
+
+
+def detect_registration_type(
+    *,
+    batch: ApplicationFormFillBatch | None = None,
+    message: str = "",
+    file_candidates: dict[str, Any] | None = None,
+) -> tuple[str, str]:
+    """Return ``(registration_type, source)`` from the first source that yields a value.
+
+    Sources are tried in order: the user message, the linked regulatory batch,
+    then the file-extracted candidates. Falls back to ``"unknown"``.
+    """
+    user_value = _registration_type_from_text(message)
+    if user_value:
+        return user_value, ApplicationFormFillBatch.RegistrationTypeSource.USER_MESSAGE
+
+    regulatory_value = _registration_type_from_regulatory_batch(batch)
+    if regulatory_value:
+        return regulatory_value, ApplicationFormFillBatch.RegistrationTypeSource.REGULATORY_BATCH
+
+    file_value = _registration_type_from_candidates(file_candidates or {})
+    if file_value:
+        return file_value, ApplicationFormFillBatch.RegistrationTypeSource.FILE_EXTRACT
+
+    return "unknown", ApplicationFormFillBatch.RegistrationTypeSource.UNKNOWN
+
+
+def select_templates(
+    config: dict[str, Any],
+    requested_templates: list[str],
+    registration_type: str,
+) -> tuple[list[TemplateSpec], list[dict[str, str]]]:
+    """Choose the templates to generate and collect risk notes.
+
+    Explicitly requested codes win; otherwise a default pair is chosen from the
+    registration type. Codes missing from ``config["templates"]`` are skipped
+    with an ``unknown_template`` note, and a requested template that does not
+    apply to the registration type is kept but flagged.
+    """
+    template_map = {item.get("code"): item for item in config.get("templates") or []}
+    risk_notes: list[dict[str, str]] = []
+    if requested_templates:
+        selected_codes = _dedupe(requested_templates)
+    elif registration_type in {"变更注册", "备案"}:
+        selected_codes = [TEMPLATE_CHANGE_REGISTRATION, TEMPLATE_ESSENTIAL_PRINCIPLES]
+    else:
+        selected_codes = [TEMPLATE_REGISTRATION_CERTIFICATE, TEMPLATE_ESSENTIAL_PRINCIPLES]
+
+    specs: list[TemplateSpec] = []
+    for code in selected_codes:
+        raw = template_map.get(code)
+        if not raw:
+            risk_notes.append({"type": "unknown_template", "message": f"模板不存在:{code}"})
+            continue
+        spec = _to_template_spec(raw)
+        if requested_templates and not _template_applies(spec, registration_type):
+            risk_notes.append(
+                {
+                    "type": "template_registration_mismatch",
+                    "message": f"用户指定模板 {spec.name} 与注册类型 {registration_type or 'unknown'} 可能不匹配,仍按指定生成。",
+                }
+            )
+        specs.append(spec)
+    return specs, risk_notes
+
+
+def _to_template_spec(raw: dict[str, Any]) -> TemplateSpec:
+    """Build a ``TemplateSpec`` from a raw config entry, defaulting missing keys."""
+    # NOTE(review): ``working_template`` is not forwarded here although
+    # resolve_source_template() reads it via getattr — confirm this is intended.
+    return TemplateSpec(
+        code=str(raw.get("code") or ""),
+        name=str(raw.get("name") or ""),
+        source_file=str(raw.get("source_file") or ""),
+        output_label=str(raw.get("output_label") or raw.get("name") or ""),
+        applies_when=dict(raw.get("applies_when") or {}),
+        file_format=str(raw.get("file_format") or ""),
+        fields=list(raw.get("fields") or []),
+        checklist_items=list(raw.get("checklist_items") or []),
+    )
+
+
+def _template_applies(spec: TemplateSpec, registration_type: str) -> bool:
+    """Return True when ``spec`` has no registration-type restriction or lists this type."""
+    allowed = spec.applies_when.get("registration_type") or []
+    if not allowed:
+        return True
+    # Membership alone also covers "unknown" when the config lists it.
+    return registration_type in allowed
+
+
+def _registration_type_from_text(message: str) -> str:
+    """Return the registration type named by keywords in ``message``, or ``""``."""
+    normalized = (message or "").lower()
+    if any(keyword in normalized for keyword in ["首次注册", "初次注册", "新注册"]):
+        return "首次注册"
+    if "变更注册" in normalized:
+        return "变更注册"
+    if "备案" in normalized:
+        return "备案"
+    return ""
+
+
+def _registration_type_from_regulatory_batch(batch: ApplicationFormFillBatch | None) -> str:
+    """Return the registration type recorded on the linked regulatory batch, or ``""``.
+
+    Looks in ``confirmed_conditions``, then the top level of ``condition_json``,
+    then ``candidates["registration_type"]``; non-dict shapes are skipped.
+    """
+    if not batch or not batch.source_regulatory_batch_id:
+        return ""
+    condition_json = batch.source_regulatory_batch.condition_json or {}
+    if not isinstance(condition_json, dict):
+        return ""
+    confirmed = condition_json.get("confirmed_conditions") or {}
+    candidates = condition_json.get("candidates") or {}
+    # Guard the lookup itself: a non-dict ``candidates`` has no ``.get``.
+    candidate_entry = candidates.get("registration_type") if isinstance(candidates, dict) else None
+    for payload in (confirmed, condition_json, candidate_entry):
+        if not isinstance(payload, dict):
+            continue
+        value = payload.get("registration_type") or payload.get("suggested") or payload.get("value")
+        normalized = _normalize_registration_type(value)
+        if normalized:
+            return normalized
+    return ""
+
+
+def _registration_type_from_candidates(candidates: dict[str, Any]) -> str:
+    """Return the normalized registration type found in file-extracted ``candidates``."""
+    value = candidates.get("registration_type") or candidates.get("suggested")
+    if isinstance(value, dict):
+        value = value.get("value") or value.get("suggested")
+    return _normalize_registration_type(value)
+
+
+def _normalize_registration_type(value: Any) -> str:
+    """Map free-form text onto one of the canonical registration types, or ``""``."""
+    text = str(value or "")
+    if "首次" in text or "初次" in text:
+        return "首次注册"
+    if "变更" in text:
+        return "变更注册"
+    if "备案" in text:
+        return "备案"
+    return ""
+
+
+def _dedupe(values: list[str]) -> list[str]:
+    """Return the non-empty items of ``values`` in first-seen order without duplicates."""
+    # dict preserves insertion order, so its keys are the de-duplicated sequence.
+    return list(dict.fromkeys(value for value in values if value))
diff --git a/tests/test_application_form_fill_template_repository.py b/tests/test_application_form_fill_template_repository.py
new file mode 100644
index 0000000..aafa001
--- /dev/null
+++ b/tests/test_application_form_fill_template_repository.py
@@ -0,0 +1,60 @@
+import pytest
+
+from review_agent.application_form_fill.services.template_config import load_template_config
+from review_agent.application_form_fill.services.template_repository import (
+    TemplateUnavailableError,
+    copy_template_to_batch,
+    resolve_source_template,
+)
+from review_agent.application_form_fill.services.template_select import select_templates
+from review_agent.models import (
+    ApplicationFormFillArtifact,
+    ApplicationFormFillBatch,
+    Conversation,
+    FileSummaryBatch,
+)
+
+
+pytestmark = pytest.mark.django_db
+
+
+def test_resolve_source_template_finds_registration_docx():
+    config = load_template_config()
+    specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册")
+
+    path = resolve_source_template(specs[0], config)
+
+    assert path.exists()
+    assert path.name == "中华人民共和国医疗器械注册证(体外诊断试剂)(格式).docx"
+
+
+def test_copy_template_to_batch_creates_artifact(settings, tmp_path, django_user_model):
+    settings.MEDIA_ROOT = tmp_path
+    user = django_user_model.objects.create_user(username="owner", password="pass")
+    conversation = Conversation.objects.create(user=user, title="会话")
+    summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-REPO")
+    batch = ApplicationFormFillBatch.objects.create(
+        conversation=conversation,
+        user=user,
+        source_summary_batch=summary,
+        batch_no="AFF-REPO",
+        work_dir=str(tmp_path / "aff" / "AFF-REPO"),
+    )
+    config = load_template_config()
+    specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册")
+
+    artifact = copy_template_to_batch(specs[0], batch, config)
+
+    assert artifact.artifact_type == ApplicationFormFillArtifact.ArtifactType.TEMPLATE_COPY
+    assert artifact.file_format == "docx"
+    assert artifact.content_hash
+    assert artifact.metadata["template_code"] == "registration_certificate"
+    assert artifact.storage_path.startswith(batch.work_dir)
+
+
+def test_doc_template_without_working_docx_is_unavailable():
+    config = load_template_config()
+    specs, _risk_notes = select_templates(config, ["change_registration"], "变更注册")
+
+    with pytest.raises(TemplateUnavailableError):
+        resolve_source_template(specs[0], config)
diff --git a/tests/test_application_form_fill_template_select.py b/tests/test_application_form_fill_template_select.py
new file mode 100644
index 0000000..dada57e
--- /dev/null
+++ b/tests/test_application_form_fill_template_select.py
@@ -0,0 +1,114 @@
+import pytest
+
+from review_agent.application_form_fill.services.template_config import load_template_config
+from review_agent.application_form_fill.services.template_select import (
+    detect_registration_type,
+    parse_requested_templates,
+    select_templates,
+)
+from review_agent.models import ApplicationFormFillBatch, Conversation, FileSummaryBatch, RegulatoryReviewBatch
+
+
+pytestmark = pytest.mark.django_db
+
+
+@pytest.mark.parametrize(
+    ("message", "expected"),
+    [
+        ("帮我填注册证", ["registration_certificate"]),
+        ("生成变更注册备案文件", ["change_registration"]),
+        ("生成安全和性能基本原则清单", ["essential_principles"]),
+        ("请生成全部模板", ["registration_certificate", "change_registration", "essential_principles"]),
+        ("普通聊天", []),
+    ],
+)
+def test_parse_requested_templates(message, expected):
+    assert parse_requested_templates(message) == expected
+
+
+def test_detect_registration_type_prefers_user_message(django_user_model):
+    user = django_user_model.objects.create_user(username="owner", password="pass")
+    conversation = Conversation.objects.create(user=user, title="会话")
+    summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-SEL")
+    regulatory = RegulatoryReviewBatch.objects.create(
+        conversation=conversation,
+        user=user,
+        source_summary_batch=summary,
+        batch_no="RR-SEL",
+        condition_json={"confirmed_conditions": {"registration_type": "变更注册"}},
+    )
+    batch = ApplicationFormFillBatch.objects.create(
+        conversation=conversation,
+        user=user,
+        source_summary_batch=summary,
+        source_regulatory_batch=regulatory,
+        batch_no="AFF-SEL",
+    )
+
+    value, source = detect_registration_type(batch=batch, message="首次注册资料,请填注册证")
+
+    assert value == "首次注册"
+    assert source == ApplicationFormFillBatch.RegistrationTypeSource.USER_MESSAGE
+
+
+def test_detect_registration_type_falls_back_to_regulatory_batch_and_file_candidates(django_user_model):
+    user = django_user_model.objects.create_user(username="owner", password="pass")
+    conversation = Conversation.objects.create(user=user, title="会话")
+    summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-SEL-2")
+    regulatory = RegulatoryReviewBatch.objects.create(
+        conversation=conversation,
+        user=user,
+        source_summary_batch=summary,
+        batch_no="RR-SEL-2",
+        condition_json={"confirmed_conditions": {"registration_type": "变更注册"}},
+    )
+    batch = ApplicationFormFillBatch.objects.create(
+        conversation=conversation,
+        user=user,
+        source_summary_batch=summary,
+        source_regulatory_batch=regulatory,
+        batch_no="AFF-SEL-2",
+    )
+
+    regulatory_value, regulatory_source = detect_registration_type(batch=batch, message="")
+    file_value, file_source = detect_registration_type(
+        message="",
+        file_candidates={"registration_type": {"suggested": "备案"}},
+    )
+
+    assert (regulatory_value, regulatory_source) == (
+        "变更注册",
+        ApplicationFormFillBatch.RegistrationTypeSource.REGULATORY_BATCH,
+    )
+    assert (file_value, file_source) == (
+        "备案",
+        ApplicationFormFillBatch.RegistrationTypeSource.FILE_EXTRACT,
+    )
+
+
+def test_select_default_templates_for_initial_registration():
+    config = load_template_config()
+
+    specs, risk_notes = select_templates(config, [], "首次注册")
+
+    assert [spec.code for spec in specs] == ["registration_certificate", "essential_principles"]
+    assert risk_notes == []
+
+
+def test_select_default_templates_for_change_registration():
+    config = load_template_config()
+
+    specs, risk_notes = select_templates(config, [], "变更注册")
+
+    assert [spec.code for spec in specs] == ["change_registration", "essential_principles"]
+    assert risk_notes == []
+
+
+def test_select_user_requested_mismatch_is_allowed_with_risk_note():
+    config = load_template_config()
+
+    specs, risk_notes = select_templates(config, ["change_registration"], "首次注册")
+
+    assert [spec.code for spec in specs] == ["change_registration"]
+    assert risk_notes
+    assert risk_notes[0]["type"] == "template_registration_mismatch"