feat(application-form-fill): 实现自动填表模板选择
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from review_agent.application_form_fill.schemas import TemplateSpec
|
||||
from review_agent.application_form_fill.storage import create_artifact_for_file, ensure_batch_subdir
|
||||
from review_agent.models import ApplicationFormFillArtifact, ApplicationFormFillBatch
|
||||
|
||||
|
||||
class TemplateUnavailableError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def resolve_source_template(spec: TemplateSpec, config: dict[str, Any]) -> Path:
|
||||
source_dir = Path(settings.BASE_DIR) / str(config.get("source_dir") or "")
|
||||
working_template = getattr(spec, "working_template", "") or ""
|
||||
if spec.file_format == "doc" and working_template:
|
||||
candidate = source_dir / working_template
|
||||
else:
|
||||
candidate = source_dir / spec.source_file
|
||||
if not candidate.exists():
|
||||
raise TemplateUnavailableError(f"模板文件不存在:{spec.source_file}")
|
||||
if spec.file_format == "doc" and candidate.suffix.lower() == ".doc":
|
||||
raise TemplateUnavailableError(f"模板 {spec.code} 为 .doc,当前阶段需预转换为 .docx 后使用。")
|
||||
return candidate
|
||||
|
||||
|
||||
def copy_template_to_batch(
|
||||
spec: TemplateSpec,
|
||||
batch: ApplicationFormFillBatch,
|
||||
config: dict[str, Any],
|
||||
) -> ApplicationFormFillArtifact:
|
||||
source = resolve_source_template(spec, config)
|
||||
target_dir = ensure_batch_subdir(batch, "templates")
|
||||
target = target_dir / f"{spec.code}.source{source.suffix.lower()}"
|
||||
shutil.copy2(source, target)
|
||||
_ensure_under(target, Path(batch.work_dir))
|
||||
return create_artifact_for_file(
|
||||
batch,
|
||||
path=target,
|
||||
artifact_type=ApplicationFormFillArtifact.ArtifactType.TEMPLATE_COPY,
|
||||
file_format=source.suffix.lower().lstrip(".") or spec.file_format,
|
||||
name=spec.name,
|
||||
metadata={"template_code": spec.code, "source_file": spec.source_file},
|
||||
created_by_node="template_copy",
|
||||
)
|
||||
|
||||
|
||||
def _ensure_under(path: Path, root: Path) -> None:
|
||||
resolved_path = path.resolve()
|
||||
resolved_root = root.resolve()
|
||||
if resolved_path != resolved_root and resolved_root not in resolved_path.parents:
|
||||
raise ValueError(f"模板复制目标不在批次工作目录内:{path}")
|
||||
158
review_agent/application_form_fill/services/template_select.py
Normal file
158
review_agent/application_form_fill/services/template_select.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from review_agent.application_form_fill.constants import (
|
||||
TEMPLATE_CHANGE_REGISTRATION,
|
||||
TEMPLATE_ESSENTIAL_PRINCIPLES,
|
||||
TEMPLATE_REGISTRATION_CERTIFICATE,
|
||||
)
|
||||
from review_agent.application_form_fill.schemas import TemplateSpec
|
||||
from review_agent.models import ApplicationFormFillBatch
|
||||
|
||||
|
||||
ALL_TEMPLATE_CODES = [
|
||||
TEMPLATE_REGISTRATION_CERTIFICATE,
|
||||
TEMPLATE_CHANGE_REGISTRATION,
|
||||
TEMPLATE_ESSENTIAL_PRINCIPLES,
|
||||
]
|
||||
|
||||
|
||||
def parse_requested_templates(message: str) -> list[str]:
|
||||
normalized = (message or "").lower()
|
||||
if any(keyword in normalized for keyword in ["全部模板", "所有模板", "全套模板", "全部表格", "所有表格"]):
|
||||
return ALL_TEMPLATE_CODES.copy()
|
||||
|
||||
requested: list[str] = []
|
||||
if "注册证" in normalized and "变更注册" not in normalized and "变更 注册" not in normalized:
|
||||
requested.append(TEMPLATE_REGISTRATION_CERTIFICATE)
|
||||
if any(keyword in normalized for keyword in ["变更注册", "变更 注册", "变更备案", "备案文件"]):
|
||||
requested.append(TEMPLATE_CHANGE_REGISTRATION)
|
||||
if any(keyword in normalized for keyword in ["安全和性能基本原则", "基本原则清单", "原则清单"]):
|
||||
requested.append(TEMPLATE_ESSENTIAL_PRINCIPLES)
|
||||
return _dedupe(requested)
|
||||
|
||||
|
||||
def detect_registration_type(
|
||||
*,
|
||||
batch: ApplicationFormFillBatch | None = None,
|
||||
message: str = "",
|
||||
file_candidates: dict[str, Any] | None = None,
|
||||
) -> tuple[str, str]:
|
||||
user_value = _registration_type_from_text(message)
|
||||
if user_value:
|
||||
return user_value, ApplicationFormFillBatch.RegistrationTypeSource.USER_MESSAGE
|
||||
|
||||
regulatory_value = _registration_type_from_regulatory_batch(batch)
|
||||
if regulatory_value:
|
||||
return regulatory_value, ApplicationFormFillBatch.RegistrationTypeSource.REGULATORY_BATCH
|
||||
|
||||
file_value = _registration_type_from_candidates(file_candidates or {})
|
||||
if file_value:
|
||||
return file_value, ApplicationFormFillBatch.RegistrationTypeSource.FILE_EXTRACT
|
||||
|
||||
return "unknown", ApplicationFormFillBatch.RegistrationTypeSource.UNKNOWN
|
||||
|
||||
|
||||
def select_templates(
|
||||
config: dict[str, Any],
|
||||
requested_templates: list[str],
|
||||
registration_type: str,
|
||||
) -> tuple[list[TemplateSpec], list[dict[str, str]]]:
|
||||
template_map = {item.get("code"): item for item in config.get("templates") or []}
|
||||
risk_notes: list[dict[str, str]] = []
|
||||
if requested_templates:
|
||||
selected_codes = _dedupe(requested_templates)
|
||||
elif registration_type in {"变更注册", "备案"}:
|
||||
selected_codes = [TEMPLATE_CHANGE_REGISTRATION, TEMPLATE_ESSENTIAL_PRINCIPLES]
|
||||
else:
|
||||
selected_codes = [TEMPLATE_REGISTRATION_CERTIFICATE, TEMPLATE_ESSENTIAL_PRINCIPLES]
|
||||
|
||||
specs: list[TemplateSpec] = []
|
||||
for code in selected_codes:
|
||||
raw = template_map.get(code)
|
||||
if not raw:
|
||||
risk_notes.append({"type": "unknown_template", "message": f"模板不存在:{code}"})
|
||||
continue
|
||||
spec = _to_template_spec(raw)
|
||||
if requested_templates and not _template_applies(spec, registration_type):
|
||||
risk_notes.append(
|
||||
{
|
||||
"type": "template_registration_mismatch",
|
||||
"message": f"用户指定模板 {spec.name} 与注册类型 {registration_type or 'unknown'} 可能不匹配,仍按指定生成。",
|
||||
}
|
||||
)
|
||||
specs.append(spec)
|
||||
return specs, risk_notes
|
||||
|
||||
|
||||
def _to_template_spec(raw: dict[str, Any]) -> TemplateSpec:
|
||||
return TemplateSpec(
|
||||
code=str(raw.get("code") or ""),
|
||||
name=str(raw.get("name") or ""),
|
||||
source_file=str(raw.get("source_file") or ""),
|
||||
output_label=str(raw.get("output_label") or raw.get("name") or ""),
|
||||
applies_when=dict(raw.get("applies_when") or {}),
|
||||
file_format=str(raw.get("file_format") or ""),
|
||||
fields=list(raw.get("fields") or []),
|
||||
checklist_items=list(raw.get("checklist_items") or []),
|
||||
)
|
||||
|
||||
|
||||
def _template_applies(spec: TemplateSpec, registration_type: str) -> bool:
|
||||
allowed = spec.applies_when.get("registration_type") or []
|
||||
if not allowed:
|
||||
return True
|
||||
return registration_type in allowed or (registration_type == "unknown" and "unknown" in allowed)
|
||||
|
||||
|
||||
def _registration_type_from_text(message: str) -> str:
|
||||
normalized = (message or "").lower()
|
||||
if any(keyword in normalized for keyword in ["首次注册", "初次注册", "新注册"]):
|
||||
return "首次注册"
|
||||
if "变更注册" in normalized:
|
||||
return "变更注册"
|
||||
if "备案" in normalized:
|
||||
return "备案"
|
||||
return ""
|
||||
|
||||
|
||||
def _registration_type_from_regulatory_batch(batch: ApplicationFormFillBatch | None) -> str:
|
||||
if not batch or not batch.source_regulatory_batch_id:
|
||||
return ""
|
||||
condition_json = batch.source_regulatory_batch.condition_json or {}
|
||||
confirmed = condition_json.get("confirmed_conditions") or {}
|
||||
candidates = condition_json.get("candidates") or {}
|
||||
for payload in [confirmed, condition_json, candidates.get("registration_type") or {}]:
|
||||
if isinstance(payload, dict):
|
||||
value = payload.get("registration_type") or payload.get("suggested") or payload.get("value")
|
||||
normalized = _normalize_registration_type(value)
|
||||
if normalized:
|
||||
return normalized
|
||||
return ""
|
||||
|
||||
|
||||
def _registration_type_from_candidates(candidates: dict[str, Any]) -> str:
|
||||
value = candidates.get("registration_type") or candidates.get("suggested")
|
||||
if isinstance(value, dict):
|
||||
value = value.get("value") or value.get("suggested")
|
||||
return _normalize_registration_type(value)
|
||||
|
||||
|
||||
def _normalize_registration_type(value: Any) -> str:
|
||||
text = str(value or "")
|
||||
if "首次" in text or "初次" in text:
|
||||
return "首次注册"
|
||||
if "变更" in text:
|
||||
return "变更注册"
|
||||
if "备案" in text:
|
||||
return "备案"
|
||||
return ""
|
||||
|
||||
|
||||
def _dedupe(values: list[str]) -> list[str]:
|
||||
result: list[str] = []
|
||||
for value in values:
|
||||
if value and value not in result:
|
||||
result.append(value)
|
||||
return result
|
||||
60
tests/test_application_form_fill_template_repository.py
Normal file
60
tests/test_application_form_fill_template_repository.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
|
||||
from review_agent.application_form_fill.services.template_config import load_template_config
|
||||
from review_agent.application_form_fill.services.template_repository import (
|
||||
TemplateUnavailableError,
|
||||
copy_template_to_batch,
|
||||
resolve_source_template,
|
||||
)
|
||||
from review_agent.application_form_fill.services.template_select import select_templates
|
||||
from review_agent.models import (
|
||||
ApplicationFormFillArtifact,
|
||||
ApplicationFormFillBatch,
|
||||
Conversation,
|
||||
FileSummaryBatch,
|
||||
)
|
||||
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def test_resolve_source_template_finds_registration_docx():
|
||||
config = load_template_config()
|
||||
specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册")
|
||||
|
||||
path = resolve_source_template(specs[0], config)
|
||||
|
||||
assert path.exists()
|
||||
assert path.name == "中华人民共和国医疗器械注册证(体外诊断试剂)(格式).docx"
|
||||
|
||||
|
||||
def test_copy_template_to_batch_creates_artifact(settings, tmp_path, django_user_model):
|
||||
settings.MEDIA_ROOT = tmp_path
|
||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||
conversation = Conversation.objects.create(user=user, title="会话")
|
||||
summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-REPO")
|
||||
batch = ApplicationFormFillBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
batch_no="AFF-REPO",
|
||||
work_dir=str(tmp_path / "aff" / "AFF-REPO"),
|
||||
)
|
||||
config = load_template_config()
|
||||
specs, _risk_notes = select_templates(config, ["registration_certificate"], "首次注册")
|
||||
|
||||
artifact = copy_template_to_batch(specs[0], batch, config)
|
||||
|
||||
assert artifact.artifact_type == ApplicationFormFillArtifact.ArtifactType.TEMPLATE_COPY
|
||||
assert artifact.file_format == "docx"
|
||||
assert artifact.content_hash
|
||||
assert artifact.metadata["template_code"] == "registration_certificate"
|
||||
assert artifact.storage_path.startswith(batch.work_dir)
|
||||
|
||||
|
||||
def test_doc_template_without_working_docx_is_unavailable():
|
||||
config = load_template_config()
|
||||
specs, _risk_notes = select_templates(config, ["change_registration"], "变更注册")
|
||||
|
||||
with pytest.raises(TemplateUnavailableError):
|
||||
resolve_source_template(specs[0], config)
|
||||
114
tests/test_application_form_fill_template_select.py
Normal file
114
tests/test_application_form_fill_template_select.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
|
||||
from review_agent.application_form_fill.services.template_config import load_template_config
|
||||
from review_agent.application_form_fill.services.template_select import (
|
||||
detect_registration_type,
|
||||
parse_requested_templates,
|
||||
select_templates,
|
||||
)
|
||||
from review_agent.models import ApplicationFormFillBatch, Conversation, FileSummaryBatch, RegulatoryReviewBatch
|
||||
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "expected"),
|
||||
[
|
||||
("帮我填注册证", ["registration_certificate"]),
|
||||
("生成变更注册备案文件", ["change_registration"]),
|
||||
("生成安全和性能基本原则清单", ["essential_principles"]),
|
||||
("请生成全部模板", ["registration_certificate", "change_registration", "essential_principles"]),
|
||||
("普通聊天", []),
|
||||
],
|
||||
)
|
||||
def test_parse_requested_templates(message, expected):
|
||||
assert parse_requested_templates(message) == expected
|
||||
|
||||
|
||||
def test_detect_registration_type_prefers_user_message(django_user_model):
|
||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||
conversation = Conversation.objects.create(user=user, title="会话")
|
||||
summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-SEL")
|
||||
regulatory = RegulatoryReviewBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
batch_no="RR-SEL",
|
||||
condition_json={"confirmed_conditions": {"registration_type": "变更注册"}},
|
||||
)
|
||||
batch = ApplicationFormFillBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
source_regulatory_batch=regulatory,
|
||||
batch_no="AFF-SEL",
|
||||
)
|
||||
|
||||
value, source = detect_registration_type(batch=batch, message="首次注册资料,请填注册证")
|
||||
|
||||
assert value == "首次注册"
|
||||
assert source == ApplicationFormFillBatch.RegistrationTypeSource.USER_MESSAGE
|
||||
|
||||
|
||||
def test_detect_registration_type_falls_back_to_regulatory_batch_and_file_candidates(django_user_model):
|
||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||
conversation = Conversation.objects.create(user=user, title="会话")
|
||||
summary = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-SEL-2")
|
||||
regulatory = RegulatoryReviewBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
batch_no="RR-SEL-2",
|
||||
condition_json={"confirmed_conditions": {"registration_type": "变更注册"}},
|
||||
)
|
||||
batch = ApplicationFormFillBatch.objects.create(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
source_summary_batch=summary,
|
||||
source_regulatory_batch=regulatory,
|
||||
batch_no="AFF-SEL-2",
|
||||
)
|
||||
|
||||
regulatory_value, regulatory_source = detect_registration_type(batch=batch, message="")
|
||||
file_value, file_source = detect_registration_type(
|
||||
message="",
|
||||
file_candidates={"registration_type": {"suggested": "备案"}},
|
||||
)
|
||||
|
||||
assert (regulatory_value, regulatory_source) == (
|
||||
"变更注册",
|
||||
ApplicationFormFillBatch.RegistrationTypeSource.REGULATORY_BATCH,
|
||||
)
|
||||
assert (file_value, file_source) == (
|
||||
"备案",
|
||||
ApplicationFormFillBatch.RegistrationTypeSource.FILE_EXTRACT,
|
||||
)
|
||||
|
||||
|
||||
def test_select_default_templates_for_initial_registration():
|
||||
config = load_template_config()
|
||||
|
||||
specs, risk_notes = select_templates(config, [], "首次注册")
|
||||
|
||||
assert [spec.code for spec in specs] == ["registration_certificate", "essential_principles"]
|
||||
assert risk_notes == []
|
||||
|
||||
|
||||
def test_select_default_templates_for_change_registration():
|
||||
config = load_template_config()
|
||||
|
||||
specs, risk_notes = select_templates(config, [], "变更注册")
|
||||
|
||||
assert [spec.code for spec in specs] == ["change_registration", "essential_principles"]
|
||||
assert risk_notes == []
|
||||
|
||||
|
||||
def test_select_user_requested_mismatch_is_allowed_with_risk_note():
|
||||
config = load_template_config()
|
||||
|
||||
specs, risk_notes = select_templates(config, ["change_registration"], "首次注册")
|
||||
|
||||
assert [spec.code for spec in specs] == ["change_registration"]
|
||||
assert risk_notes
|
||||
assert risk_notes[0]["type"] == "template_registration_mismatch"
|
||||
Reference in New Issue
Block a user