Files
DEMO-AGENT/review_agent/regulatory_review/services/rule_loader.py

128 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from pathlib import Path
import yaml
from django.conf import settings
from review_agent.models import RegulatoryRuleVersion
# Code of the only rule set this loader accepts (load_rule_file rejects any other).
DEFAULT_RULE_CODE = "nmpa_ivd_registration_v1"

# Bundled YAML definition of that rule set, resolved under the project's BASE_DIR.
# The file name is derived from the code so the two cannot drift apart.
DEFAULT_RULE_PATH = Path(
    settings.BASE_DIR,
    "review_agent",
    "regulatory_review",
    "rules",
    f"{DEFAULT_RULE_CODE}.yaml",
)
@dataclass(frozen=True)
class RuleVersionCheck:
    """Immutable result of comparing the rule YAML on disk with its database record.

    Instances are built by ``check_rule_version`` in this module.
    """

    # Outcome of the check; check_rule_version returns one of
    # "missing", "created", "mismatch" or "ok".
    status: str
    # Rule set code (the YAML "code" value, or the matching record's code).
    code: str
    # Path of the YAML rule file that was checked.
    path: Path
    # SHA-256 hex digest of the YAML file as currently on disk.
    current_hash: str
    # Hash stored on the database record; left empty when no record exists.
    database_hash: str = ""
    # The matching database record, or None when it is missing and was not created.
    record: RegulatoryRuleVersion | None = None
def compute_file_sha256(path: str | Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in binary mode in 1 MiB pieces, so arbitrarily large
    files are hashed without being loaded into memory at once.
    """
    piece_size = 1024 * 1024
    hasher = hashlib.sha256()
    with Path(path).open("rb") as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                break
            hasher.update(piece)
    return hasher.hexdigest()
def load_rule_file(path: str | Path | None = None) -> dict:
    """Load the regulatory rule YAML file and validate its basic structure.

    Args:
        path: Rule file to read. A falsy value (``None`` or ``""``) selects
            ``DEFAULT_RULE_PATH``.

    Returns:
        The parsed YAML mapping. An empty document is treated as ``{}`` and
        therefore fails the ``code`` check.

    Raises:
        ValueError: If the top level of the document is not a mapping, if its
            ``code`` differs from ``DEFAULT_RULE_CODE``, if ``requirements`` is
            missing, empty or not a list, or if
            ``_validate_attachment4_requirements`` rejects the payload.
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not well-formed YAML.
    """
    rule_path = Path(path) if path else DEFAULT_RULE_PATH
    with rule_path.open("r", encoding="utf-8") as handle:
        # safe_load only constructs plain Python objects, never arbitrary classes.
        payload = yaml.safe_load(handle) or {}
    # A top-level list or scalar would otherwise crash below with an
    # AttributeError on ``.get``; report it as a validation error instead.
    if not isinstance(payload, dict):
        raise ValueError("规则文件顶层必须为 YAML 映射。")
    if payload.get("code") != DEFAULT_RULE_CODE:
        raise ValueError(f"规则 code 必须为 {DEFAULT_RULE_CODE}")
    if not isinstance(payload.get("requirements"), list) or not payload["requirements"]:
        raise ValueError("规则文件必须包含 requirements 列表。")
    _validate_attachment4_requirements(payload)
    return payload
def _validate_attachment4_requirements(payload: dict) -> None:
    """Check the attachment-4 ("附件4") requirements of a parsed rule file.

    Every entry of ``payload["requirements"]`` must be a mapping. Each entry
    with a truthy ``attachment4_code`` must also have a truthy value for
    every field in ``required_fields``. Every code listed in
    ``payload["attachment4_required_codes"]`` must be covered by at least one
    such entry.

    Raises:
        ValueError: On the first malformed entry or missing field, or when
            required codes are not covered (listed in numeric order).
    """
    required_fields = ("code", "rule_id", "title", "severity", "file_keywords", "citation_query")
    requirements = payload.get("requirements") or []
    # NOTE(review): unquoted YAML codes such as 4.10 parse as floats and
    # stringify as "4.1"; codes are expected to be quoted in the YAML — confirm.
    required_codes = {str(code) for code in payload.get("attachment4_required_codes") or []}

    present_codes: set[str] = set()
    for index, requirement in enumerate(requirements, start=1):
        # A non-mapping entry would otherwise crash with AttributeError on .get().
        if not isinstance(requirement, dict):
            raise ValueError(f"requirements 第 {index} 项必须为映射。")
        attachment4_code = requirement.get("attachment4_code")
        if not attachment4_code:
            continue  # Only attachment-4 requirements are checked here.
        present_codes.add(str(attachment4_code))
        for field_name in required_fields:
            if not requirement.get(field_name):
                raise ValueError(f"附件4规则 {attachment4_code} 缺少 {field_name}")

    missing = required_codes - present_codes
    if missing:
        ordered = sorted(missing, key=_attachment4_sort_key)
        raise ValueError(f"附件4目录项缺少规则{', '.join(ordered)}")
def _attachment4_sort_key(value: str) -> tuple[int, ...]:
    """Return a numeric sort key for a dotted attachment-4 code such as "2.10".

    Each dot-separated segment that (after trimming whitespace) consists only
    of decimal digits becomes an ``int``, so "2.10" sorts after "2.9". Any
    other segment is ignored.
    """
    segments = (part.strip() for part in value.split("."))
    # isdecimal, not isdigit: isdigit also accepts characters such as "²"
    # that int() rejects, which used to crash the sort with ValueError.
    return tuple(int(segment) for segment in segments if segment.isdecimal())
def _relative_yaml_path(rule_path: Path) -> str:
    """Return *rule_path* relative to ``settings.BASE_DIR`` when it lies under it.

    Otherwise (e.g. a relative path or a fixture in a temporary directory)
    the path is returned as given, instead of ``Path.relative_to`` raising
    ``ValueError``.
    """
    try:
        return str(rule_path.relative_to(Path(settings.BASE_DIR)))
    except ValueError:
        return str(rule_path)


def check_rule_version(
    *,
    path: str | Path | None = None,
    update_missing: bool = True,
) -> RuleVersionCheck:
    """Compare the rule YAML on disk with its ``RegulatoryRuleVersion`` record.

    The file is validated with ``load_rule_file`` and hashed with
    ``compute_file_sha256``. The record is then looked up by the file's
    ``code``.

    Args:
        path: Rule file to check. A falsy value selects ``DEFAULT_RULE_PATH``.
        update_missing: When no record exists, create one (status ``ACTIVE``)
            instead of reporting it as missing.

    Returns:
        A ``RuleVersionCheck`` whose ``status`` is one of:

        - ``"missing"``: no record exists and ``update_missing`` is false.
        - ``"created"``: a new record was just created.
        - ``"mismatch"``: the stored ``yaml_hash`` differs from the file.
        - ``"ok"``: the stored hash matches the file.

    Raises:
        ValueError, OSError: Propagated from ``load_rule_file``.
    """
    rule_path = Path(path) if path else DEFAULT_RULE_PATH
    rule_set = load_rule_file(rule_path)
    # NOTE(review): the file is read twice (parse, then hash); an edit between
    # the two reads would pair a validated payload with a different hash.
    current_hash = compute_file_sha256(rule_path)
    code = rule_set["code"]
    record = RegulatoryRuleVersion.objects.filter(code=code).first()

    if record is None:
        if not update_missing:
            return RuleVersionCheck(
                status="missing",
                code=code,
                path=rule_path,
                current_hash=current_hash,
            )
        record = RegulatoryRuleVersion.objects.create(
            code=code,
            name=rule_set.get("name") or code,
            # Only computed here: relative_to used to raise for any path outside
            # BASE_DIR even when no record needed to be written.
            yaml_path=_relative_yaml_path(rule_path),
            yaml_hash=current_hash,
            # An explicit `rag_collection: null` in YAML should not become None.
            rag_collection=rule_set.get("rag_collection") or "",
            status=RegulatoryRuleVersion.Status.ACTIVE,
        )
        status = "created"
    elif record.yaml_hash != current_hash:
        status = "mismatch"
    else:
        status = "ok"

    return RuleVersionCheck(
        status=status,
        code=record.code,
        path=rule_path,
        current_hash=current_hash,
        database_hash=record.yaml_hash,
        record=record,
    )