"""Template selection and registration-type detection for application form filling."""

from __future__ import annotations

from typing import Any

from review_agent.application_form_fill.constants import (
    TEMPLATE_CHANGE_REGISTRATION,
    TEMPLATE_ESSENTIAL_PRINCIPLES,
    TEMPLATE_REGISTRATION_CERTIFICATE,
)
from review_agent.application_form_fill.schemas import TemplateSpec
from review_agent.models import ApplicationFormFillBatch

# Every template code this module knows about, in the order returned when the
# user asks for "all templates".
ALL_TEMPLATE_CODES = [
    TEMPLATE_REGISTRATION_CERTIFICATE,
    TEMPLATE_CHANGE_REGISTRATION,
    TEMPLATE_ESSENTIAL_PRINCIPLES,
]


def parse_requested_templates(message: str) -> list[str]:
    """Return the template codes explicitly requested in a user message.

    A message containing one of the "all templates / all forms" keywords
    selects every code in ``ALL_TEMPLATE_CODES``. Otherwise each template is
    selected by its own keyword list.

    Args:
        message: Free-text user message; ``None``/empty is treated as "".

    Returns:
        De-duplicated template codes in detection order. An empty list means
        the message did not name any template.
    """
    normalized = (message or "").lower()
    if any(keyword in normalized for keyword in ["全部模板", "所有模板", "全套模板", "全部表格", "所有表格"]):
        return ALL_TEMPLATE_CODES.copy()

    requested: list[str] = []
    # "注册证" also occurs inside phrases such as "变更注册证", so it only
    # selects the certificate template when no change registration is mentioned.
    if "注册证" in normalized and "变更注册" not in normalized and "变更 注册" not in normalized:
        requested.append(TEMPLATE_REGISTRATION_CERTIFICATE)
    if any(keyword in normalized for keyword in ["变更注册", "变更 注册", "变更备案", "备案文件"]):
        requested.append(TEMPLATE_CHANGE_REGISTRATION)
    if any(keyword in normalized for keyword in ["安全和性能基本原则", "基本原则清单", "原则清单"]):
        requested.append(TEMPLATE_ESSENTIAL_PRINCIPLES)
    return _dedupe(requested)


def detect_registration_type(
    *,
    batch: ApplicationFormFillBatch | None = None,
    message: str = "",
    file_candidates: dict[str, Any] | None = None,
) -> tuple[str, str]:
    """Determine the registration type and where it came from.

    Sources are consulted in priority order, and the first one that yields a
    recognised type wins:

    1. the user message;
    2. the regulatory batch linked to *batch*;
    3. values extracted from uploaded files (*file_candidates*).

    Returns:
        ``(registration_type, source)``. *source* is an
        ``ApplicationFormFillBatch.RegistrationTypeSource`` member. When nothing
        matches, the result is ``("unknown", RegistrationTypeSource.UNKNOWN)``.
    """
    user_value = _registration_type_from_text(message)
    if user_value:
        return user_value, ApplicationFormFillBatch.RegistrationTypeSource.USER_MESSAGE

    regulatory_value = _registration_type_from_regulatory_batch(batch)
    if regulatory_value:
        return regulatory_value, ApplicationFormFillBatch.RegistrationTypeSource.REGULATORY_BATCH

    file_value = _registration_type_from_candidates(file_candidates or {})
    if file_value:
        return file_value, ApplicationFormFillBatch.RegistrationTypeSource.FILE_EXTRACT

    return "unknown", ApplicationFormFillBatch.RegistrationTypeSource.UNKNOWN


def select_templates(
    config: dict[str, Any],
    requested_templates: list[str],
    registration_type: str,
) -> tuple[list[TemplateSpec], list[dict[str, str]]]:
    """Resolve the template specs to generate, plus any risk notes.

    Explicitly requested templates always win. Without a request, the default
    set depends on *registration_type*:

    - change registration / filing ("变更注册" / "备案") → change-registration
      template plus the essential-principles checklist;
    - anything else → registration-certificate template plus the checklist.

    Args:
        config: Mapping whose ``"templates"`` entry is a list of raw template
            dicts keyed by ``"code"``.
        requested_templates: Codes the user asked for. This may be empty.
        registration_type: Detected registration type, possibly "unknown".

    Returns:
        ``(specs, risk_notes)``.
        - A code missing from *config* produces an ``"unknown_template"`` note
          and is skipped.
        - A user-requested template whose ``applies_when`` does not cover
          *registration_type* produces a ``"template_registration_mismatch"``
          note but is still generated.
    """
    template_map = {item.get("code"): item for item in config.get("templates") or []}
    risk_notes: list[dict[str, str]] = []

    if requested_templates:
        selected_codes = _dedupe(requested_templates)
    elif registration_type in {"变更注册", "备案"}:
        selected_codes = [TEMPLATE_CHANGE_REGISTRATION, TEMPLATE_ESSENTIAL_PRINCIPLES]
    else:
        selected_codes = [TEMPLATE_REGISTRATION_CERTIFICATE, TEMPLATE_ESSENTIAL_PRINCIPLES]

    specs: list[TemplateSpec] = []
    for code in selected_codes:
        raw = template_map.get(code)
        if not raw:
            risk_notes.append({"type": "unknown_template", "message": f"模板不存在:{code}"})
            continue
        spec = _to_template_spec(raw)
        # Only warn about user-chosen templates; defaults are chosen by type already.
        if requested_templates and not _template_applies(spec, registration_type):
            risk_notes.append(
                {
                    "type": "template_registration_mismatch",
                    "message": f"用户指定模板 {spec.name} 与注册类型 {registration_type or 'unknown'} 可能不匹配,仍按指定生成。",
                }
            )
        specs.append(spec)
    return specs, risk_notes


def _to_template_spec(raw: dict[str, Any]) -> TemplateSpec:
    """Build a ``TemplateSpec`` from a raw config dict, coercing missing values.

    Missing or falsy scalar fields become ``""``. ``output_label`` falls back
    to ``name``. The container fields are copied into fresh ``dict``/``list``
    objects, so the spec does not share them with *raw*.
    """
    return TemplateSpec(
        code=str(raw.get("code") or ""),
        name=str(raw.get("name") or ""),
        source_file=str(raw.get("source_file") or ""),
        output_label=str(raw.get("output_label") or raw.get("name") or ""),
        applies_when=dict(raw.get("applies_when") or {}),
        file_format=str(raw.get("file_format") or ""),
        fields=list(raw.get("fields") or []),
        checklist_items=list(raw.get("checklist_items") or []),
    )


def _template_applies(spec: TemplateSpec, registration_type: str) -> bool:
    """Return whether *spec* is declared applicable to *registration_type*.

    ``spec.applies_when["registration_type"]`` lists the registration types
    the template applies to. A missing or empty value means "applies to any
    type". A bare string is treated as a one-element list, so matching is
    exact equality rather than a substring test.

    "unknown" only matches when the config lists it explicitly.
    """
    allowed = spec.applies_when.get("registration_type") or []
    if isinstance(allowed, str):
        # Without this, `x in "首次注册"` would be a substring check, and e.g.
        # an empty registration_type would match every string.
        allowed = [allowed]
    if not allowed:
        return True
    return registration_type in allowed


def _registration_type_from_text(message: str) -> str:
    """Extract a registration type stated in free text, or "" if none.

    Keywords are checked in this order: first registration, then change
    registration, then filing. As a result, "变更备案" (with no "变更注册")
    maps to "备案".
    """
    normalized = (message or "").lower()
    if any(keyword in normalized for keyword in ["首次注册", "初次注册", "新注册"]):
        return "首次注册"
    if "变更注册" in normalized:
        return "变更注册"
    if "备案" in normalized:
        return "备案"
    return ""


def _registration_type_from_regulatory_batch(batch: ApplicationFormFillBatch | None) -> str:
    """Read the registration type from the batch's linked regulatory batch.

    Payloads are checked in this order:

    1. ``condition_json["confirmed_conditions"]``;
    2. ``condition_json`` itself;
    3. ``condition_json["candidates"]["registration_type"]``.

    In each payload, the keys ``registration_type`` / ``suggested`` / ``value``
    are tried. The first value that normalises to a known type wins.

    Returns "" when there is no linked batch, the JSON is not a mapping, or
    nothing recognisable is found.
    """
    if not batch or not batch.source_regulatory_batch_id:
        return ""
    condition_json = batch.source_regulatory_batch.condition_json or {}
    if not isinstance(condition_json, dict):
        # A non-mapping JSON value (e.g. a list) carries no usable conditions;
        # fall through to the next detection source instead of raising.
        return ""
    confirmed = condition_json.get("confirmed_conditions") or {}
    candidates = condition_json.get("candidates") or {}
    candidate_entry = candidates.get("registration_type") if isinstance(candidates, dict) else None

    for payload in (confirmed, condition_json, candidate_entry or {}):
        if not isinstance(payload, dict):
            continue
        value = payload.get("registration_type") or payload.get("suggested") or payload.get("value")
        normalized = _normalize_registration_type(value)
        if normalized:
            return normalized
    return ""


def _registration_type_from_candidates(candidates: dict[str, Any]) -> str:
    """Normalise the registration type found in file-extraction candidates.

    Reads ``registration_type`` (falling back to ``suggested``). If that is a
    nested dict, its ``value`` / ``suggested`` entry is used instead.
    """
    value = candidates.get("registration_type") or candidates.get("suggested")
    if isinstance(value, dict):
        value = value.get("value") or value.get("suggested")
    return _normalize_registration_type(value)


def _normalize_registration_type(value: Any) -> str:
    """Map an arbitrary value onto "首次注册" / "变更注册" / "备案", or "".

    The value is stringified and searched for keywords in this order:
    first/initial, then change, then filing. As a result, "变更备案" maps to
    "变更注册" here.
    """
    text = str(value or "")
    if "首次" in text or "初次" in text:
        return "首次注册"
    if "变更" in text:
        return "变更注册"
    if "备案" in text:
        return "备案"
    return ""


def _dedupe(values: list[str]) -> list[str]:
    """Return *values* without falsy entries or repeats, preserving first-seen order."""
    # dict.fromkeys keeps insertion order and gives O(n) de-duplication.
    return [value for value in dict.fromkeys(values) if value]