362 lines
14 KiB
Python
362 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
from uuid import uuid4
|
|
|
|
from django.conf import settings
|
|
from django.db import transaction
|
|
from django.utils import timezone
|
|
|
|
from review_agent.models import (
|
|
Conversation,
|
|
FileSummaryBatch,
|
|
Message,
|
|
RegulatoryReviewBatch,
|
|
WorkflowNodeRun,
|
|
)
|
|
from review_agent.regulatory_review.services.completeness_check import run_completeness_check
|
|
from review_agent.regulatory_review.services.consistency_check import run_consistency_check
|
|
from review_agent.regulatory_review.services.export import build_assistant_summary, export_review_results
|
|
from review_agent.regulatory_review.services.feishu_notifier import create_mock_notifications
|
|
from review_agent.regulatory_review.services.info_extract import detect_regulatory_condition_candidates
|
|
from review_agent.regulatory_review.services.risk_assess import persist_findings
|
|
from review_agent.regulatory_review.services.rule_loader import load_rule_file
|
|
from review_agent.regulatory_review.services.structure_check import run_structure_check
|
|
from review_agent.regulatory_review.services.text_extract import extract_text
|
|
|
|
from .events import record_event
|
|
from .storage import save_artifact
|
|
|
|
|
|
# Ordered workflow nodes as (node_code, node_name, node_group) tuples.
# create_regulatory_review_batch() creates one WorkflowNodeRun row per entry,
# in this order, and the executor later processes those rows ordered by id.
NODE_DEFINITIONS = [
    ("prepare", "准备", "prepare"),
    ("condition_confirm", "适用条件确认", "condition_confirm"),
    ("rule_scope", "规则范围", "rule_scope"),
    ("completeness_check", "完整性核查", "completeness_check"),
    ("text_extract", "文本抽取", "text_extract"),
    ("structure_check", "章节核查", "structure_check"),
    ("consistency_check", "一致性核查", "consistency_check"),
    ("risk_assess", "风险评估", "risk_assess"),
    ("report_export", "报告输出", "report_export"),
    ("completed", "完成", "completed"),
]
|
|
|
|
|
|
# Module logger; the name is fixed explicitly rather than derived from __name__.
logger = logging.getLogger("review_agent.regulatory_review.workflow")
|
|
|
|
|
|
# Chapter number (as a string) -> display label for the six "attachment 4"
# chapters. Membership in this mapping is also what decides whether a detected
# chapter number is accepted as a valid rule scope.
ATTACHMENT4_CHAPTER_LABELS = {
    "1": "第1章 监管信息",
    "2": "第2章 综述资料",
    "3": "第3章 非临床资料",
    "4": "第4章 临床评价资料",
    "5": "第5章 产品说明书和标签样稿",
    "6": "第6章 质量管理体系文件",
}
|
|
|
|
|
|
class WorkflowPausedForUser(Exception):
    """Signal that the workflow stopped early to wait for user input.

    Raised by the condition-confirmation step and caught by the executor's
    ``run`` method, which returns without marking the batch as failed or
    successful.
    """
|
|
|
|
|
|
def build_batch_no() -> str:
    """Return a new batch number of the form ``RR-<YYYYmmddHHMMSS>-<6 hex chars>``.

    The timestamp is the current local time as reported by Django's
    ``timezone.localtime()``; the suffix is the first six hex digits of a
    random UUID4.
    """
    timestamp = timezone.localtime().strftime('%Y%m%d%H%M%S')
    random_suffix = uuid4().hex[:6]
    return f"RR-{timestamp}-{random_suffix}"
|
|
|
|
|
|
def build_batch_work_dir(batch_no: str) -> Path:
    """Return ``<MEDIA_ROOT>/regulatory_review/work/<batch_no>``.

    The path is only computed here; the directory is not created.
    """
    media_root = Path(settings.MEDIA_ROOT)
    return media_root.joinpath("regulatory_review", "work", batch_no)
|
|
|
|
|
|
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest successful file-summary batch of *conversation*.

    Batches are ranked by ``finished_at``, then ``created_at``, then ``id``,
    all descending. Returns ``None`` when the conversation has no batch with
    status ``SUCCESS``.
    """
    successful_batches = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful_batches.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
|
|
|
|
|
|
@transaction.atomic
def create_regulatory_review_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
) -> RegulatoryReviewBatch:
    """Create a review batch, its work directory and its workflow node rows.

    Args:
        conversation: Conversation the batch belongs to.
        user: User stored on the batch.
        source_summary_batch: File-summary batch whose files will be reviewed.
        trigger_message: Optional message that triggered the review; its
            content is scanned for a chapter scope to seed ``condition_json``.

    Returns:
        The newly created ``RegulatoryReviewBatch``.

    The database writes run inside ``transaction.atomic``. The work directory
    is created on disk before the batch row is inserted.
    """
    batch_no = build_batch_no()
    work_dir = build_batch_work_dir(batch_no)
    work_dir.mkdir(parents=True, exist_ok=True)

    batch_fields = {
        "conversation": conversation,
        "user": user,
        "trigger_message": trigger_message,
        "source_summary_batch": source_summary_batch,
        "batch_no": batch_no,
        "work_dir": str(work_dir),
        "condition_json": _initial_condition_json(trigger_message),
    }
    batch = RegulatoryReviewBatch.objects.create(**batch_fields)

    # One node row per definition, inserted in definition order.
    for node_code, node_name, node_group in NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type="regulatory_review",
            workflow_batch_id=batch.pk,
            node_group=node_group,
            node_code=node_code,
            node_name=node_name,
        )

    created_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", created_payload)
    return batch
|
|
|
|
|
|
class RegulatoryWorkflowExecutor:
    """Run the workflow nodes of one regulatory review batch in order.

    The executor walks the batch's ``WorkflowNodeRun`` rows ordered by ``id``,
    skips rows already marked ``SUCCESS`` and executes the rest. Intermediate
    results (scoped rule set, findings, extracted texts and per-file extract
    status) are held in memory on the instance only.
    """

    def __init__(self, batch: RegulatoryReviewBatch):
        """Bind the executor to *batch* and start with empty in-memory state."""
        self.batch = batch
        self.rule_set: dict | None = None
        self.findings: list = []
        self.document_texts: dict[str, str] = {}
        self.text_extract_status: dict[str, dict[str, object]] = {}

    def run(self) -> None:
        """Execute all pending nodes and record the batch outcome.

        The batch is marked ``RUNNING`` first. If a node pauses for user input
        the method returns and leaves the batch as the pausing step set it.
        Any other exception marks the batch ``FAILED`` and stores the message;
        the exception is logged and not re-raised. Otherwise the batch is
        marked ``SUCCESS``.
        """
        self.batch.status = RegulatoryReviewBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})

        try:
            for node in self._nodes():
                if node.status == WorkflowNodeRun.Status.SUCCESS:
                    continue
                self._run_node(node)
        except WorkflowPausedForUser:
            return
        except Exception as exc:
            # Top-level boundary of the workflow: record the failure on the
            # batch instead of letting it escape (possibly into a bare thread).
            logger.exception("Regulatory workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = RegulatoryReviewBatch.Status.FAILED
            self.batch.error_message = str(exc)
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": str(exc)})
            return

        self.batch.status = RegulatoryReviewBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})

    def _nodes(self):
        """Return this batch's node rows ordered by ``id``."""
        return WorkflowNodeRun.objects.filter(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _condition(self) -> dict:
        """Return the batch's ``condition_json``, treating an unset value as empty."""
        return self.batch.condition_json or {}

    def _load_scoped_rules(self) -> dict:
        """Load the rule file and narrow it to the batch's ``rule_scope``."""
        return apply_rule_scope(load_rule_file(), self._condition().get("rule_scope") or {})

    def _emit_node_progress(self, node: WorkflowNodeRun) -> None:
        """Record a ``node_progress`` event describing *node*'s current state."""
        record_event(
            self.batch,
            "node_progress",
            {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
        )

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark *node* running, execute it, then mark it successful.

        NOTE(review): when the node's handler raises, nothing here updates the
        node row, so it stays in ``RUNNING``; only the batch is marked failed
        by ``run``.
        """
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        self._emit_node_progress(node)

        self._execute_node(node.node_code)

        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        self._emit_node_progress(node)

    def _execute_node(self, node_code: str) -> None:
        """Run the handler registered for *node_code*.

        Codes without a handler (for example ``prepare`` and ``completed``)
        are no-ops.
        """
        handlers = {
            "condition_confirm": self._pause_for_condition_confirmation,
            "rule_scope": self._node_rule_scope,
            "completeness_check": self._node_completeness_check,
            "text_extract": self._node_text_extract,
            "structure_check": self._node_structure_check,
            "consistency_check": self._node_consistency_check,
            "risk_assess": self._node_risk_assess,
            "report_export": self._node_report_export,
        }
        handler = handlers.get(node_code)
        if handler is not None:
            handler()

    def _node_rule_scope(self) -> None:
        """Reload the rule file and store the scoped rule set."""
        self.rule_set = self._load_scoped_rules()

    def _node_completeness_check(self) -> None:
        """Append completeness findings for the source summary batch."""
        self.findings.extend(run_completeness_check(self.batch.source_summary_batch, self._rules()))

    def _node_text_extract(self) -> None:
        """Extract source texts and save the per-file status as a JSON artifact."""
        self.document_texts = self._extract_source_texts()
        save_artifact(
            self.batch,
            name="text_extract_status.json",
            artifact_type="json",
            content=json.dumps(self.text_extract_status, ensure_ascii=False, indent=2),
            metadata={"artifact": "text_extract_status"},
        )

    def _node_structure_check(self) -> None:
        """Append structure findings computed from the extracted texts."""
        self.findings.extend(run_structure_check(self.document_texts, self._rules()))

    def _node_consistency_check(self) -> None:
        """Append consistency findings computed from the extracted texts."""
        self.findings.extend(run_consistency_check(self.document_texts))

    def _node_risk_assess(self) -> None:
        """Persist findings, create mock notifications and save a result artifact."""
        issues = persist_findings(self.batch, self.findings)
        create_mock_notifications(self.batch)
        result_payload = {
            "batch_no": self.batch.batch_no,
            "text_extract_status": self.text_extract_status,
            "issues": [
                {
                    "rule_code": issue.rule_code,
                    "title": issue.title,
                    "citations": issue.citations,
                }
                for issue in issues
            ],
        }
        save_artifact(
            self.batch,
            name="rag_result_json.json",
            artifact_type="json",
            content=json.dumps(result_payload, ensure_ascii=False, indent=2),
            metadata={"artifact": "rag_result_json"},
        )

    def _node_report_export(self) -> None:
        """Export the review results and post an assistant summary message."""
        exports = export_review_results(self.batch)
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, exports),
        )

    def _pause_for_condition_confirmation(self) -> None:
        """Pause the workflow until the user confirms the applicable conditions.

        Returns immediately when ``condition_json["confirmed"]`` is truthy.
        Otherwise stores the detected candidates on the batch, switches batch
        and node to ``WAITING_USER``, records a ``waiting_user`` event and
        raises.

        Raises:
            WorkflowPausedForUser: when confirmation is still outstanding.
        """
        condition = self._condition()
        if condition.get("confirmed"):
            return

        candidates = detect_regulatory_condition_candidates(self.batch.source_summary_batch)
        self.batch.condition_json = {
            **condition,
            "confirmed": False,
            "resume_from": "rule_scope",
            "candidates": candidates,
        }
        self.batch.status = RegulatoryReviewBatch.Status.WAITING_USER
        self.batch.save(update_fields=["status", "condition_json"])

        # Re-read the node row from the database and overwrite the RUNNING
        # state that _run_node saved just before calling this handler.
        node = WorkflowNodeRun.objects.get(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
            node_code="condition_confirm",
        )
        node.status = WorkflowNodeRun.Status.WAITING_USER
        node.progress = 50
        node.message = "请确认产品类别、注册类型、临床评价路径等适用条件"
        node.save(update_fields=["status", "progress", "message"])
        record_event(
            self.batch,
            "waiting_user",
            {"node_code": "condition_confirm", "candidates": candidates, "resume_from": "rule_scope"},
        )
        raise WorkflowPausedForUser()

    def _rules(self) -> dict:
        """Return the scoped rule set, loading it on first use."""
        if self.rule_set is None:
            self.rule_set = self._load_scoped_rules()
        return self.rule_set

    def _extract_source_texts(self) -> dict[str, str]:
        """Extract text from every file of the source summary batch.

        Relative storage paths are resolved against ``settings.MEDIA_ROOT``.
        A status record is written to ``self.text_extract_status`` for every
        file, keyed by file name; files that do not exist get status
        ``"missing"`` with the same keys as extracted files.

        Returns:
            Mapping of file name to extracted text, containing only files
            whose extraction status is ``"success"`` and whose text is
            non-empty.
        """
        texts: dict[str, str] = {}
        for item in self.batch.source_summary_batch.items.order_by("file_index"):
            path = Path(item.storage_path)
            if not path.is_absolute():
                path = Path(settings.MEDIA_ROOT) / item.storage_path

            if not path.exists():
                self.text_extract_status[item.file_name] = {
                    "status": "missing",
                    "path": str(path),
                    "content_hash": "",
                    "section_candidates": [],
                    "field_candidates": {},
                    "front_text": "",
                    # Same key set as the extracted-file record below, so
                    # consumers can read every record uniformly.
                    "error_message": "",
                }
                continue

            result = extract_text(path)
            self.text_extract_status[item.file_name] = {
                "status": result.status,
                "path": str(path),
                "content_hash": result.content_hash,
                "section_candidates": result.section_candidates,
                "field_candidates": result.field_candidates,
                "front_text": result.front_text,
                "error_message": result.error_message,
            }
            if result.status == "success" and result.text:
                texts[item.file_name] = result.text
        return texts
|
|
|
|
|
|
def start_regulatory_review_workflow(batch: RegulatoryReviewBatch, *, async_run: bool = True) -> None:
    """Run the review workflow for *batch*, inline or on a background thread.

    Args:
        batch: The batch to execute.
        async_run: When true (default) the workflow runs on a daemon thread
            and this function returns immediately; when false it runs in the
            calling thread and returns after the workflow finished or paused.
    """
    executor = RegulatoryWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_and_release_connections() -> None:
        # Imported here so the block is self-contained.
        from django.db import connections

        try:
            executor.run()
        finally:
            # This thread is outside the request/response cycle, so the
            # database connections it opened are closed explicitly here
            # before the thread ends.
            connections.close_all()

    Thread(target=_run_and_release_connections, daemon=True).start()
|
|
|
|
|
|
def _initial_condition_json(trigger_message: Message | None) -> dict:
    """Build the initial ``condition_json`` for a new batch.

    Returns ``{"rule_scope": <scope>}`` when a chapter scope is detected in
    the trigger message's content, otherwise an empty dict.
    """
    if trigger_message:
        content = trigger_message.content
    else:
        content = ""

    scope = detect_attachment4_chapter_scope(content)
    if scope:
        return {"rule_scope": scope}
    return {}
|
|
|
|
|
|
def detect_attachment4_chapter_scope(content: str) -> dict[str, str] | None:
    """Detect a request to limit the review to one attachment-4 chapter.

    Returns ``{"attachment4_chapter": <number>, "label": <label>}`` when
    *content* mentions a chapter listed in ``ATTACHMENT4_CHAPTER_LABELS``;
    returns ``None`` for empty/blank content or when no known chapter is found.
    """
    text = (content or "").strip()
    if text:
        chapter = _extract_chapter_number(text)
        if chapter in ATTACHMENT4_CHAPTER_LABELS:
            label = ATTACHMENT4_CHAPTER_LABELS[chapter]
            return {"attachment4_chapter": chapter, "label": label}
    return None
|
|
|
|
|
|
def apply_rule_scope(rule_set: dict, rule_scope: dict) -> dict:
    """Narrow *rule_set* to the chapter named in *rule_scope*.

    When ``rule_scope["attachment4_chapter"]`` is not a known chapter the
    original *rule_set* object is returned unchanged. Otherwise a shallow copy
    is returned whose ``requirements`` list keeps only requirements belonging
    to that chapter and whose ``active_rule_scope`` is set to *rule_scope*.
    """
    chapter = str(rule_scope.get("attachment4_chapter") or "")
    if chapter not in ATTACHMENT4_CHAPTER_LABELS:
        return rule_set

    in_scope = []
    for requirement in rule_set.get("requirements", []):
        if _requirement_in_chapter(requirement, chapter):
            in_scope.append(requirement)

    return {**rule_set, "requirements": in_scope, "active_rule_scope": rule_scope}
|
|
|
|
|
|
def _requirement_in_chapter(requirement: dict, chapter: str) -> bool:
|
|
attachment4_code = str(requirement.get("attachment4_code") or "")
|
|
return attachment4_code == chapter or attachment4_code.startswith(f"{chapter}.")
|
|
|
|
|
|
def _extract_chapter_number(content: str) -> str:
|
|
match = re.search(r"第\s*([一二三四五六1-6])\s*[章节张]", content)
|
|
if match:
|
|
return _normalize_chapter_number(match.group(1))
|
|
match = re.search(r"(^|[^\d])([1-6])\s*[章节张]", content)
|
|
if match:
|
|
return match.group(2)
|
|
return ""
|
|
|
|
|
|
def _normalize_chapter_number(value: str) -> str:
|
|
chinese = {"一": "1", "二": "2", "三": "3", "四": "4", "五": "5", "六": "6"}
|
|
return chinese.get(value, value)
|