159 lines
6.0 KiB
Python
159 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from review_agent.application_form_fill.constants import (
|
|
TEMPLATE_CHANGE_REGISTRATION,
|
|
TEMPLATE_ESSENTIAL_PRINCIPLES,
|
|
TEMPLATE_REGISTRATION_CERTIFICATE,
|
|
)
|
|
from review_agent.application_form_fill.schemas import TemplateSpec
|
|
from review_agent.models import ApplicationFormFillBatch
|
|
|
|
|
|
# Every template code this module can select, in the order used when the
# user asks for the complete set of templates.
ALL_TEMPLATE_CODES = [
    TEMPLATE_REGISTRATION_CERTIFICATE,
    TEMPLATE_CHANGE_REGISTRATION,
    TEMPLATE_ESSENTIAL_PRINCIPLES,
]
|
|
|
|
|
|
def parse_requested_templates(message: str) -> list[str]:
    """Extract the template codes a user explicitly asks for in *message*.

    The message is lowercased and scanned for fixed keywords. A request for
    "all templates" returns a copy of ``ALL_TEMPLATE_CODES``. Otherwise each
    template is matched on its own keywords and the hits are returned in a
    fixed order: registration certificate, change registration, essential
    principles.

    Args:
        message: Free-form user text; ``None`` or empty is treated as "".

    Returns:
        The requested template codes, de-duplicated; empty when no keyword
        is found.
    """
    text = (message or "").lower()

    def mentions(*keywords: str) -> bool:
        return any(keyword in text for keyword in keywords)

    if mentions("全部模板", "所有模板", "全套模板", "全部表格", "所有表格"):
        return ALL_TEMPLATE_CODES.copy()

    matches = (
        # "注册证" on its own means the certificate; once the message also
        # says "变更注册" the certificate template is not selected.
        (
            mentions("注册证") and not mentions("变更注册", "变更 注册"),
            TEMPLATE_REGISTRATION_CERTIFICATE,
        ),
        (
            mentions("变更注册", "变更 注册", "变更备案", "备案文件"),
            TEMPLATE_CHANGE_REGISTRATION,
        ),
        (
            mentions("安全和性能基本原则", "基本原则清单", "原则清单"),
            TEMPLATE_ESSENTIAL_PRINCIPLES,
        ),
    )
    return _dedupe([code for hit, code in matches if hit])
|
|
|
|
|
|
def detect_registration_type(
    *,
    batch: ApplicationFormFillBatch | None = None,
    message: str = "",
    file_candidates: dict[str, Any] | None = None,
) -> tuple[str, str]:
    """Resolve the registration type and report which source supplied it.

    Sources are consulted in priority order and the first non-empty answer
    wins: the user's message, then the regulatory batch linked to *batch*,
    then the values extracted from files.

    Args:
        batch: Form-fill batch whose linked regulatory batch may carry the
            registration type; optional.
        message: User message to scan for registration-type keywords.
        file_candidates: Candidate values extracted from files; ``None`` is
            treated as an empty dict.

    Returns:
        ``(registration_type, source)`` where ``source`` is a member of
        ``ApplicationFormFillBatch.RegistrationTypeSource``. When no source
        yields a type, the type is the literal ``"unknown"`` and the source
        is ``UNKNOWN``.
    """
    source_kind = ApplicationFormFillBatch.RegistrationTypeSource
    # Callables keep the lookups lazy: a lower-priority source is inspected
    # only after every higher-priority one came back empty.
    lookups = (
        (lambda: _registration_type_from_text(message), source_kind.USER_MESSAGE),
        (
            lambda: _registration_type_from_regulatory_batch(batch),
            source_kind.REGULATORY_BATCH,
        ),
        (
            lambda: _registration_type_from_candidates(file_candidates or {}),
            source_kind.FILE_EXTRACT,
        ),
    )
    for lookup, origin in lookups:
        found = lookup()
        if found:
            return found, origin
    return "unknown", source_kind.UNKNOWN
|
|
|
|
|
|
def select_templates(
    config: dict[str, Any],
    requested_templates: list[str],
    registration_type: str,
) -> tuple[list[TemplateSpec], list[dict[str, str]]]:
    """Choose the templates to generate and collect risk notes about them.

    Explicitly requested templates take precedence. Without a request, a
    default pair is derived from the registration type: the
    change-registration template for ``变更注册`` / ``备案`` and the
    registration-certificate template for anything else, each followed by
    the essential-principles template.

    Args:
        config: Configuration whose ``"templates"`` entry lists raw template
            dicts; they are indexed by their ``"code"`` value.
        requested_templates: Template codes asked for explicitly; may be empty.
        registration_type: Registration type used for the default choice and
            for the applicability check of requested templates.

    Returns:
        ``(specs, risk_notes)``. A code missing from the configuration is
        skipped and reported as ``unknown_template``. A requested template
        that does not apply to the registration type is still returned,
        accompanied by a ``template_registration_mismatch`` note.
    """
    raw_by_code = {entry.get("code"): entry for entry in config.get("templates") or []}

    if requested_templates:
        codes = _dedupe(requested_templates)
    else:
        primary = (
            TEMPLATE_CHANGE_REGISTRATION
            if registration_type in {"变更注册", "备案"}
            else TEMPLATE_REGISTRATION_CERTIFICATE
        )
        codes = [primary, TEMPLATE_ESSENTIAL_PRINCIPLES]

    chosen: list[TemplateSpec] = []
    notes: list[dict[str, str]] = []
    for code in codes:
        raw = raw_by_code.get(code)
        if raw:
            spec = _to_template_spec(raw)
            # Only user-specified templates are checked; the defaults are
            # already derived from the registration type.
            if requested_templates and not _template_applies(spec, registration_type):
                notes.append(
                    {
                        "type": "template_registration_mismatch",
                        "message": f"用户指定模板 {spec.name} 与注册类型 {registration_type or 'unknown'} 可能不匹配,仍按指定生成。",
                    }
                )
            chosen.append(spec)
        else:
            notes.append({"type": "unknown_template", "message": f"模板不存在:{code}"})
    return chosen, notes
|
|
|
|
|
|
def _to_template_spec(raw: dict[str, Any]) -> TemplateSpec:
    """Build a ``TemplateSpec`` from one raw template configuration dict.

    Missing or falsy entries are replaced by empty values of the expected
    kind (``""``, ``{}`` or ``[]``). ``output_label`` falls back to the
    template name when it is not set.
    """

    def as_text(key: str) -> str:
        return str(raw.get(key) or "")

    def as_list(key: str) -> list[Any]:
        return list(raw.get(key) or [])

    return TemplateSpec(
        code=as_text("code"),
        name=as_text("name"),
        source_file=as_text("source_file"),
        output_label=str(raw.get("output_label") or raw.get("name") or ""),
        applies_when=dict(raw.get("applies_when") or {}),
        file_format=as_text("file_format"),
        fields=as_list("fields"),
        checklist_items=as_list("checklist_items"),
    )
|
|
|
|
|
|
def _template_applies(spec: TemplateSpec, registration_type: str) -> bool:
    """Tell whether *spec* is declared applicable to *registration_type*.

    Args:
        spec: Template whose ``applies_when["registration_type"]`` lists the
            registration types it is meant for.
        registration_type: Detected registration type. An empty value is
            treated as ``"unknown"``.

    Returns:
        ``True`` when the template declares no restriction, or when the
        (possibly defaulted) registration type is among the allowed ones.
    """
    allowed = spec.applies_when.get("registration_type") or []
    if not allowed:
        # No restriction configured: the template fits every type.
        return True
    # An empty type means "not determined". Compare it as "unknown" so a
    # template that explicitly allows unknown types is not flagged.
    return (registration_type or "unknown") in allowed
|
|
|
|
|
|
def _registration_type_from_text(message: str) -> str:
    """Detect a registration type from keywords in a user message.

    Rules are tried in order and the first matching rule wins:
    first registration, then change registration, then filing.

    Returns:
        ``"首次注册"``, ``"变更注册"``, ``"备案"``, or ``""`` when no keyword
        is present (including an empty or ``None`` message).
    """
    text = (message or "").lower()
    # NOTE(review): "新注册" is a plain substring test, so it also matches
    # inside "重新注册" - confirm that is intended.
    rules = (
        (("首次注册", "初次注册", "新注册"), "首次注册"),
        (("变更注册",), "变更注册"),
        (("备案",), "备案"),
    )
    for keywords, registration_type in rules:
        if any(keyword in text for keyword in keywords):
            return registration_type
    return ""
|
|
|
|
|
|
def _registration_type_from_regulatory_batch(batch: ApplicationFormFillBatch | None) -> str:
    """Read the registration type from the regulatory batch linked to *batch*.

    The linked batch's ``condition_json`` is searched in this order: its
    ``confirmed_conditions`` entry, the top-level object itself, and finally
    ``candidates["registration_type"]``. In each dict the keys
    ``registration_type``, ``suggested`` and ``value`` are tried in turn.

    Args:
        batch: Form-fill batch; ``None`` or a batch without a linked
            regulatory batch yields ``""``.

    Returns:
        The normalized registration type, or ``""`` when nothing usable is
        found or the stored JSON does not have the expected dict shape.
    """
    if not batch or not batch.source_regulatory_batch_id:
        return ""

    condition_json = batch.source_regulatory_batch.condition_json or {}
    if not isinstance(condition_json, dict):
        # Unexpected JSON shape (e.g. a list): nothing can be looked up.
        return ""

    confirmed = condition_json.get("confirmed_conditions") or {}
    candidates = condition_json.get("candidates") or {}
    candidate = candidates.get("registration_type") if isinstance(candidates, dict) else None
    if isinstance(candidate, str):
        # A bare string candidate is the value itself; wrap it so it goes
        # through the same lookup as the dict form.
        candidate = {"value": candidate}

    for payload in (confirmed, condition_json, candidate):
        if not isinstance(payload, dict):
            continue
        value = payload.get("registration_type") or payload.get("suggested") or payload.get("value")
        normalized = _normalize_registration_type(value)
        if normalized:
            return normalized
    return ""
|
|
|
|
|
|
def _registration_type_from_candidates(candidates: dict[str, Any]) -> str:
    """Read the registration type from file-extracted candidate values.

    The ``registration_type`` entry is preferred, then ``suggested``. When
    the entry found is itself a dict, its ``value`` (or else ``suggested``)
    is used. The result is passed through ``_normalize_registration_type``.
    """
    raw = candidates.get("registration_type") or candidates.get("suggested")
    resolved = (raw.get("value") or raw.get("suggested")) if isinstance(raw, dict) else raw
    return _normalize_registration_type(resolved)
|
|
|
|
|
|
def _normalize_registration_type(value: Any) -> str:
    """Map a loosely worded registration type onto its canonical label.

    The value is converted to text (falsy values become ``""``) and scanned
    for marker substrings in priority order; the first marker found decides.

    Returns:
        ``"首次注册"``, ``"变更注册"``, ``"备案"``, or ``""`` when no marker
        is present.
    """
    text = str(value or "")
    markers = (
        ("首次", "首次注册"),
        ("初次", "首次注册"),
        ("变更", "变更注册"),
        ("备案", "备案"),
    )
    return next((canonical for marker, canonical in markers if marker in text), "")
|
|
|
|
|
|
def _dedupe(values: list[str]) -> list[str]:
    """Return the non-empty items of *values* without duplicates.

    The first occurrence of each item is kept and the original order is
    preserved; falsy items (such as ``""`` or ``None``) are dropped.
    """
    # dict keys are unique and keep insertion order, which gives an
    # order-preserving dedupe without rescanning a result list per item.
    return list(dict.fromkeys(value for value in values if value))
|