Files
DEMO-AGENT/review_agent/regulatory_review/workflow.py

576 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from threading import Thread
from uuid import uuid4
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from review_agent.models import (
Conversation,
FileSummaryBatch,
Message,
RegulatoryReviewBatch,
WorkflowNodeRun,
)
from review_agent.notifications.dispatcher import dispatch_workflow_notification
from review_agent.notifications.workflow_adapters import build_regulatory_review_context
from review_agent.regulatory_review.services.completeness_check import run_completeness_check
from review_agent.regulatory_review.services.consistency_check import run_consistency_check
from review_agent.regulatory_review.services.export import build_assistant_summary, export_review_results
from review_agent.regulatory_review.services.feishu_notifier import create_mock_notifications
from review_agent.regulatory_review.services.info_extract import detect_regulatory_condition_candidates
from review_agent.regulatory_review.services.llm_review import review_condition_fields, review_workflow_payload
from review_agent.regulatory_review.services.risk_assess import persist_findings
from review_agent.regulatory_review.services.rule_loader import load_rule_file
from review_agent.regulatory_review.services.structure_check import run_structure_check
from review_agent.regulatory_review.services.text_extract import extract_text
from .events import record_event
from .storage import save_artifact
# Workflow nodes as (node_code, node_name, node_group) tuples. One
# WorkflowNodeRun row is created per entry, in this order, when a batch is
# created; the executor later processes those rows ordered by id.
NODE_DEFINITIONS = [
    ("prepare", "准备", "prepare"),
    ("condition_confirm", "适用条件确认", "condition_confirm"),
    ("rule_scope", "规则范围", "rule_scope"),
    ("completeness_check", "完整性核查", "completeness_check"),
    ("text_extract", "文本抽取", "text_extract"),
    ("structure_check", "章节核查", "structure_check"),
    ("consistency_check", "一致性核查", "consistency_check"),
    ("risk_assess", "风险评估", "risk_assess"),
    ("report_export", "报告输出", "report_export"),
    ("completed", "完成", "completed"),
]
# Module logger; the name is spelled out explicitly rather than using __name__.
logger = logging.getLogger("review_agent.regulatory_review.workflow")
# Chapter number (as a string) -> display label. Only these keys are accepted
# as an "attachment4_chapter" rule scope by the functions in this module.
ATTACHMENT4_CHAPTER_LABELS = {
    "1": "第1章 监管信息",
    "2": "第2章 综述资料",
    "3": "第3章 非临床资料",
    "4": "第4章 临床评价资料",
    "5": "第5章 产品说明书和标签样稿",
    "6": "第6章 质量管理体系文件",
}
class WorkflowPausedForUser(Exception):
    """Signal that the workflow stopped to wait for user input.

    Raised by the condition-confirmation step and caught by the executor's
    ``run`` method, which returns without finalizing the batch.
    """
def build_batch_no() -> str:
    """Return a new batch number of the form ``RR-<YYYYmmddHHMMSS>-<6 hex chars>``.

    The timestamp is the current local time; the suffix is the first six hex
    characters of a random UUID4.
    """
    timestamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return f"RR-{timestamp}-{suffix}"
def build_batch_work_dir(batch_no: str) -> Path:
    """Return ``<MEDIA_ROOT>/regulatory_review/work/<batch_no>`` (not created here)."""
    media_root = Path(settings.MEDIA_ROOT)
    return media_root.joinpath("regulatory_review", "work", batch_no)
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest successful file-summary batch of *conversation*.

    Batches are ranked by ``finished_at``, then ``created_at``, then ``id``,
    all descending. Returns ``None`` when the conversation has no batch with
    status ``SUCCESS``.
    """
    successful = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
@transaction.atomic
def create_regulatory_review_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
) -> RegulatoryReviewBatch:
    """Create a review batch, its work directory and its workflow node rows.

    Database writes run inside one transaction. The work directory is created
    on disk before the rows are written.

    Args:
        conversation: Conversation the batch belongs to.
        user: Owner stored on the batch.
        source_summary_batch: File-summary batch whose files are reviewed.
        trigger_message: Optional message that started the review; its content
            is scanned for a chapter scope to seed ``condition_json``.

    Returns:
        The newly created batch.
    """
    new_batch_no = build_batch_no()
    batch_dir = build_batch_work_dir(new_batch_no)
    batch_dir.mkdir(parents=True, exist_ok=True)

    batch_fields = {
        "conversation": conversation,
        "user": user,
        "trigger_message": trigger_message,
        "source_summary_batch": source_summary_batch,
        "batch_no": new_batch_no,
        "work_dir": str(batch_dir),
        "condition_json": _initial_condition_json(trigger_message),
    }
    batch = RegulatoryReviewBatch.objects.create(**batch_fields)

    # One node row per definition, created in definition order.
    for node_code, node_name, node_group in NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type="regulatory_review",
            workflow_batch_id=batch.pk,
            node_group=node_group,
            node_code=node_code,
            node_name=node_name,
        )

    created_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", created_payload)
    return batch
class RegulatoryWorkflowExecutor:
    """Run the workflow nodes of one regulatory review batch.

    The executor walks the batch's ``WorkflowNodeRun`` rows ordered by id,
    skips rows already marked ``SUCCESS`` and executes the rest. Findings,
    extracted texts and LLM review results are kept on the instance.

    NOTE(review): because that state lives only in memory, a run that resumes
    with a fresh executor and skips an already successful ``text_extract`` or
    check node starts with empty ``document_texts`` / ``findings``.
    """

    def __init__(self, batch: RegulatoryReviewBatch):
        self.batch = batch
        # Scoped rule set; filled by the rule_scope node or lazily by _rules().
        self.rule_set: dict | None = None
        # Findings collected by the check nodes and persisted by risk_assess.
        self.findings: list = []
        # file name -> extracted text, filled by the text_extract node.
        self.document_texts: dict[str, str] = {}
        # file name -> extraction status record, filled by the text_extract node.
        self.text_extract_status: dict[str, dict[str, object]] = {}
        # stage -> result returned by review_workflow_payload.
        self.llm_reviews: dict[str, dict[str, object]] = {}

    def run(self) -> None:
        """Execute all pending nodes and record the final batch status.

        The batch is first marked RUNNING. If a node raises
        ``WorkflowPausedForUser`` the method returns without finalizing the
        batch. Any other exception raised while running nodes marks the batch
        FAILED, stores the error text, dispatches the completion notification
        and returns; it is not re-raised. Otherwise the batch is marked
        SUCCESS and the completion notification is dispatched.
        """
        logger.info("法规核查工作流开始 batch_no=%s batch_id=%s", self.batch.batch_no, self.batch.pk)
        self.batch.status = RegulatoryReviewBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})
        try:
            for node in self._nodes():
                if node.status == WorkflowNodeRun.Status.SUCCESS:
                    continue
                self._run_node(node)
        except WorkflowPausedForUser:
            logger.info("法规核查工作流等待用户 batch_no=%s node=condition_confirm", self.batch.batch_no)
            return
        except Exception as exc:
            logger.exception("Regulatory workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = RegulatoryReviewBatch.Status.FAILED
            self.batch.error_message = str(exc)
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": str(exc)})
            self._dispatch_completion_notification()
            return
        self.batch.status = RegulatoryReviewBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})
        self._dispatch_completion_notification()
        logger.info("法规核查工作流完成 batch_no=%s findings=%s", self.batch.batch_no, len(self.findings))

    def _dispatch_completion_notification(self) -> None:
        """Send the workflow notification; failures are logged, never raised."""
        try:
            dispatch_workflow_notification(build_regulatory_review_context(self.batch))
        except Exception as exc:
            # Deliberate best-effort: a notification problem must not change
            # the outcome of the workflow itself.
            logger.warning(
                "Regulatory review notification failed without blocking workflow",
                extra={"batch_id": self.batch.pk, "error": str(exc)},
            )

    def _nodes(self):
        """Return this batch's node rows ordered by id."""
        return WorkflowNodeRun.objects.filter(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _condition_data(self) -> dict:
        """Return the batch's ``condition_json``, or an empty dict when it is unset."""
        return self.batch.condition_json or {}

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark *node* RUNNING, execute it, then mark it SUCCESS.

        Both transitions are saved and published as ``node_progress`` events.

        NOTE(review): when ``_execute_node`` raises, the SUCCESS transition is
        skipped and nothing here resets the row, so it keeps the RUNNING state
        written above (the pause path updates the row separately).
        """
        logger.info(
            "节点开始 batch_no=%s node=%s name=%s",
            self.batch.batch_no,
            node.node_code,
            node.node_name,
        )
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        record_event(
            self.batch,
            "node_progress",
            {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
        )
        self._execute_node(node)
        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        record_event(
            self.batch,
            "node_progress",
            {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
        )
        logger.info(
            "节点完成 batch_no=%s node=%s name=%s progress=%s",
            self.batch.batch_no,
            node.node_code,
            node.node_name,
            node.progress,
        )

    def _update_node_progress(
        self,
        node: WorkflowNodeRun,
        *,
        processed: int,
        total: int,
        message: str,
    ) -> None:
        """Save and publish intermediate progress for *node*.

        Progress is ``10 + 85 * processed / total`` truncated to an int and
        capped at 95. Nothing happens when *total* is zero or negative.
        """
        if total <= 0:
            return
        progress = min(95, 10 + int((max(processed, 0) / total) * 85))
        node.progress = progress
        node.message = message
        node.save(update_fields=["progress", "message"])
        record_event(
            self.batch,
            "node_progress",
            {
                "node_code": node.node_code,
                "status": node.status,
                "progress": node.progress,
                "message": node.message,
                "processed": processed,
                "total": total,
            },
        )
        logger.info(
            "节点进度 batch_no=%s node=%s progress=%s processed=%s total=%s message=%s",
            self.batch.batch_no,
            node.node_code,
            progress,
            processed,
            total,
            message,
        )

    def _progress_reporter(self, node: WorkflowNodeRun, title: str):
        """Return a ``progress_callback`` that forwards updates to *node*.

        The callback takes one mapping and reads its ``processed``, ``total``,
        ``label`` and ``finding_count`` entries; *title* prefixes the message.
        """

        def report(update: dict) -> None:
            self._update_node_progress(
                node,
                processed=int(update.get("processed") or 0),
                total=int(update.get("total") or 0),
                message=(
                    f"{title} {update.get('processed')}/{update.get('total')}"
                    f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题"
                ),
            )

        return report

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Run the handler registered for ``node.node_code``.

        Codes without a handler (for example ``prepare`` and ``completed``)
        are no-ops.
        """
        handlers = {
            "condition_confirm": self._node_condition_confirm,
            "rule_scope": self._node_rule_scope,
            "completeness_check": self._node_completeness_check,
            "text_extract": self._node_text_extract,
            "structure_check": self._node_structure_check,
            "consistency_check": self._node_consistency_check,
            "risk_assess": self._node_risk_assess,
            "report_export": self._node_report_export,
        }
        handler = handlers.get(node.node_code)
        if handler is not None:
            handler(node)

    def _node_condition_confirm(self, node: WorkflowNodeRun) -> None:
        """Pause for user confirmation unless the conditions are already confirmed."""
        self._pause_for_condition_confirmation()

    def _node_rule_scope(self, node: WorkflowNodeRun) -> None:
        """Reload the rule file and narrow it to the batch's rule scope."""
        self.rule_set = self._load_scoped_rules()
        logger.info(
            "方法执行 batch_no=%s method=apply_rule_scope requirements=%s scope=%s",
            self.batch.batch_no,
            len(self.rule_set.get("requirements", [])),
            self._condition_data().get("rule_scope") or {},
        )

    def _node_completeness_check(self, node: WorkflowNodeRun) -> None:
        """Run the completeness check and store its findings and LLM review."""
        findings = run_completeness_check(
            self.batch.source_summary_batch,
            self._rules(),
            progress_callback=self._progress_reporter(node, "完整性核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_completeness_check findings=%s source_summary=%s",
            self.batch.batch_no,
            len(findings),
            self.batch.source_summary_batch.batch_no,
        )
        self._save_llm_review(
            "completeness_check",
            {
                "findings": [finding.to_dict() for finding in findings],
                "rules_count": len(self._rules().get("requirements", [])),
            },
        )

    def _node_text_extract(self, node: WorkflowNodeRun) -> None:
        """Extract text from the source files and save the status artifact."""
        self.document_texts = self._extract_source_texts(node)
        logger.info(
            "方法执行 batch_no=%s method=_extract_source_texts success_docs=%s total_files=%s",
            self.batch.batch_no,
            len(self.document_texts),
            len(self.text_extract_status),
        )
        self._save_llm_review("text_extract", {"files": self.text_extract_status})
        save_artifact(
            self.batch,
            name="text_extract_status.json",
            artifact_type="json",
            content=json.dumps(self.text_extract_status, ensure_ascii=False, indent=2),
            metadata={"artifact": "text_extract_status"},
        )

    def _node_structure_check(self, node: WorkflowNodeRun) -> None:
        """Run the structure check on the extracted texts."""
        findings = run_structure_check(
            self.document_texts,
            self._rules(),
            progress_callback=self._progress_reporter(node, "章节核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_structure_check findings=%s docs=%s",
            self.batch.batch_no,
            len(findings),
            len(self.document_texts),
        )
        self._save_llm_review("structure_check", {"findings": [finding.to_dict() for finding in findings]})

    def _node_consistency_check(self, node: WorkflowNodeRun) -> None:
        """Run the consistency check on the extracted texts."""
        findings = run_consistency_check(
            self.document_texts,
            progress_callback=self._progress_reporter(node, "一致性核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_consistency_check findings=%s docs=%s",
            self.batch.batch_no,
            len(findings),
            len(self.document_texts),
        )
        self._save_llm_review("consistency_check", {"findings": [finding.to_dict() for finding in findings]})

    def _node_risk_assess(self, node: WorkflowNodeRun) -> None:
        """Persist all collected findings and save the combined result artifact."""
        self._save_llm_review("risk_assess", {"findings": [finding.to_dict() for finding in self.findings]})
        issues = persist_findings(self.batch, self.findings)
        create_mock_notifications(self.batch)
        logger.info(
            "方法执行 batch_no=%s method=persist_findings issues=%s findings=%s",
            self.batch.batch_no,
            len(issues),
            len(self.findings),
        )
        save_artifact(
            self.batch,
            name="rag_result_json.json",
            artifact_type="json",
            content=json.dumps(
                {
                    "batch_no": self.batch.batch_no,
                    "text_extract_status": self.text_extract_status,
                    "issues": [
                        {
                            "rule_code": issue.rule_code,
                            "title": issue.title,
                            "citations": issue.citations,
                        }
                        for issue in issues
                    ],
                    "llm_reviews": self.llm_reviews,
                },
                ensure_ascii=False,
                indent=2,
            ),
            metadata={"artifact": "rag_result_json"},
        )

    def _node_report_export(self, node: WorkflowNodeRun) -> None:
        """Export the review results and post an assistant summary message."""
        exports = export_review_results(self.batch)
        logger.info(
            "方法执行 batch_no=%s method=export_review_results exports=%s",
            self.batch.batch_no,
            len(exports),
        )
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, exports),
        )

    def _pause_for_condition_confirmation(self) -> None:
        """Put the batch into WAITING_USER unless conditions are already confirmed.

        When ``condition_json`` has a truthy ``confirmed`` entry this returns
        immediately. Otherwise candidate conditions are detected and stored on
        the batch, the ``condition_confirm`` node row is set to WAITING_USER,
        a ``waiting_user`` event is recorded and the workflow is interrupted.

        Raises:
            WorkflowPausedForUser: whenever confirmation is still pending.
        """
        condition = self._condition_data()
        if condition.get("confirmed"):
            return
        candidates = detect_regulatory_condition_candidates(self.batch.source_summary_batch)
        logger.info(
            "方法执行 batch_no=%s method=detect_regulatory_condition_candidates product_category=%s product_name=%s",
            self.batch.batch_no,
            (candidates.get("product_category") or {}).get("suggested"),
            (candidates.get("product_name") or {}).get("suggested"),
        )
        self.batch.condition_json = {
            **condition,
            "confirmed": False,
            "resume_from": "rule_scope",
            "candidates": candidates,
        }
        self.batch.status = RegulatoryReviewBatch.Status.WAITING_USER
        self.batch.save(update_fields=["status", "condition_json"])
        # The row is re-fetched here, so the instance held by _run_node is
        # not the one that receives the WAITING_USER state.
        node = WorkflowNodeRun.objects.get(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
            node_code="condition_confirm",
        )
        node.status = WorkflowNodeRun.Status.WAITING_USER
        node.progress = 50
        node.message = "请确认产品类别、注册类型、临床评价路径等适用条件"
        node.save(update_fields=["status", "progress", "message"])
        record_event(
            self.batch,
            "waiting_user",
            {"node_code": "condition_confirm", "candidates": candidates, "resume_from": "rule_scope"},
        )
        raise WorkflowPausedForUser()

    def _load_scoped_rules(self) -> dict:
        """Load the rule file and apply the batch's ``rule_scope`` (empty scope if unset)."""
        return apply_rule_scope(load_rule_file(), self._condition_data().get("rule_scope") or {})

    def _rules(self) -> dict:
        """Return the scoped rule set, loading it on first use."""
        if self.rule_set is None:
            self.rule_set = self._load_scoped_rules()
        return self.rule_set

    def _extract_source_texts(self, node: WorkflowNodeRun | None = None) -> dict[str, str]:
        """Extract text from every file of the source summary batch.

        A status record is written to ``self.text_extract_status`` for each
        file. Relative storage paths are resolved against ``MEDIA_ROOT``;
        files that do not exist get a ``missing`` record and are skipped.

        NOTE(review): records and texts are keyed by ``item.file_name``, so
        two items sharing a file name overwrite each other.

        Args:
            node: Optional node row that receives per-file progress updates.

        Returns:
            Mapping of file name to text for files whose extraction status is
            ``"success"`` and whose text is non-empty.
        """
        texts: dict[str, str] = {}
        items = list(self.batch.source_summary_batch.items.order_by("file_index"))
        total = len(items)
        for index, item in enumerate(items, start=1):
            path = Path(item.storage_path)
            if not path.is_absolute():
                path = Path(settings.MEDIA_ROOT) / item.storage_path
            if not path.exists():
                logger.info("文本抽取跳过 batch_no=%s file=%s reason=missing", self.batch.batch_no, item.file_name)
                self.text_extract_status[item.file_name] = {
                    "status": "missing",
                    "path": str(path),
                    "content_hash": "",
                    "section_candidates": [],
                    "field_candidates": {},
                    "front_text": "",
                }
                if node:
                    self._update_node_progress(
                        node,
                        processed=index,
                        total=total,
                        message=f"文本抽取 {index}/{total}{item.file_name}(文件不存在)",
                    )
                continue
            result = extract_text(path)
            field_review = review_condition_fields(
                text=result.front_text or result.text,
                rule_fields=result.field_candidates or {},
                file_context=f"{item.directory_level}\n{item.file_name}\n{item.relative_path}",
            )
            self.text_extract_status[item.file_name] = {
                "status": result.status,
                "path": str(path),
                "content_hash": result.content_hash,
                "section_candidates": result.section_candidates,
                # Fields chosen by the review take precedence over the raw candidates.
                "field_candidates": field_review.get("selected_fields", result.field_candidates),
                "field_review": field_review,
                "front_text": result.front_text,
                "error_message": result.error_message,
            }
            if result.status == "success" and result.text:
                texts[item.file_name] = result.text
            logger.info(
                "文本抽取文件 batch_no=%s file=%s status=%s fields=%s chars=%s",
                self.batch.batch_no,
                item.file_name,
                result.status,
                len((field_review.get("selected_fields") or {})),
                len(result.text or ""),
            )
            if node:
                self._update_node_progress(
                    node,
                    processed=index,
                    total=total,
                    message=f"文本抽取 {index}/{total}{item.file_name}{result.status}",
                )
        return texts

    def _save_llm_review(self, stage: str, payload: dict[str, object]) -> dict[str, object]:
        """Review *payload* for *stage*, remember the result and save it as an artifact.

        Returns:
            The review returned by ``review_workflow_payload``.
        """
        review = review_workflow_payload(stage=stage, payload=payload)
        self.llm_reviews[stage] = review
        logger.info(
            "方法执行 batch_no=%s method=review_workflow_payload stage=%s status=%s",
            self.batch.batch_no,
            stage,
            review.get("status"),
        )
        save_artifact(
            self.batch,
            name=f"llm_review_{stage}.json",
            artifact_type="json",
            content=json.dumps(review, ensure_ascii=False, indent=2),
            metadata={"artifact": "llm_review", "stage": stage},
        )
        return review
def start_regulatory_review_workflow(batch: RegulatoryReviewBatch, *, async_run: bool = True) -> None:
    """Run the review workflow for *batch*, inline or on a daemon thread.

    Args:
        batch: Batch whose workflow nodes should be executed.
        async_run: When true (default) the workflow runs on a background
            daemon thread and this function returns immediately; when false
            it runs in the calling thread and returns once it has finished.
    """
    executor = RegulatoryWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_in_background() -> None:
        # Imported here so the block does not depend on a new module-level import.
        from django.db import connections

        try:
            executor.run()
        finally:
            # Database connections are per-thread and this thread is outside
            # the request cycle, so nothing else would close them.
            connections.close_all()

    Thread(target=_run_in_background, daemon=True).start()
def _initial_condition_json(trigger_message: Message | None) -> dict:
    """Build the initial ``condition_json`` for a new batch.

    Returns ``{"rule_scope": <scope>}`` when the trigger message names a
    known chapter, otherwise an empty dict.
    """
    content = trigger_message.content if trigger_message else ""
    scope = detect_attachment4_chapter_scope(content)
    if scope:
        return {"rule_scope": scope}
    return {}
def detect_attachment4_chapter_scope(content: str) -> dict[str, str] | None:
    """Detect a chapter restriction in free-form message text.

    Returns:
        ``{"attachment4_chapter": <number>, "label": <label>}`` when the text
        names a chapter listed in ``ATTACHMENT4_CHAPTER_LABELS``; ``None`` for
        empty or blank text or when no such chapter is found.
    """
    text = (content or "").strip()
    if not text:
        return None
    chapter = _extract_chapter_number(text)
    try:
        label = ATTACHMENT4_CHAPTER_LABELS[chapter]
    except KeyError:
        return None
    return {"attachment4_chapter": chapter, "label": label}
def apply_rule_scope(rule_set: dict, rule_scope: dict) -> dict:
    """Restrict *rule_set* to the chapter named in *rule_scope*.

    When ``rule_scope["attachment4_chapter"]`` is not a key of
    ``ATTACHMENT4_CHAPTER_LABELS`` the original *rule_set* object is returned
    unchanged. Otherwise a shallow copy is returned whose ``requirements``
    keep only the entries belonging to that chapter and whose
    ``active_rule_scope`` is set to *rule_scope*.
    """
    chapter = str(rule_scope.get("attachment4_chapter") or "")
    if chapter not in ATTACHMENT4_CHAPTER_LABELS:
        return rule_set

    kept = []
    for requirement in rule_set.get("requirements", []):
        if _requirement_in_chapter(requirement, chapter):
            kept.append(requirement)

    return {**rule_set, "requirements": kept, "active_rule_scope": rule_scope}
def _requirement_in_chapter(requirement: dict, chapter: str) -> bool:
    """Return whether the requirement's ``attachment4_code`` lies in *chapter*.

    A code matches when it equals *chapter* or starts with ``"<chapter>."``,
    so chapter ``"1"`` covers ``"1"`` and ``"1.2"`` but not ``"10.1"``. A
    missing or falsy code never matches a non-empty chapter.
    """
    code = str(requirement.get("attachment4_code") or "")
    if code == chapter:
        return True
    return code.startswith(f"{chapter}.")
# Chinese numerals accepted as chapter numbers, mapped to their ASCII digit.
_CHINESE_CHAPTER_DIGITS = {"一": "1", "二": "2", "三": "3", "四": "4", "五": "5", "六": "6"}

# A chapter reference: an optional "第", one numeral for chapters 1-6, then one
# of "章", "节" or "张". The lookbehind rejects a numeral that is only the tail
# of a longer number, so "16章" or "十一章" is not read as chapter 6 or 1.
_CHAPTER_REFERENCE = re.compile(
    r"(?<![\d一二三四五六七八九十百])(?:第\s*)?([一二三四五六1-6])\s*[章节张]"
)


def _extract_chapter_number(content: str) -> str:
    """Return the first chapter number (``"1"``-``"6"``) referenced in *content*.

    Both ASCII digits and the Chinese numerals 一 to 六 are recognised, with
    or without a leading "第"; the result is always an ASCII digit.

    Returns:
        The chapter number as a one-character string, or ``""`` when *content*
        is empty or contains no chapter reference.
    """
    match = _CHAPTER_REFERENCE.search(content or "")
    if match is None:
        return ""
    return _normalize_chapter_number(match.group(1))


def _normalize_chapter_number(value: str) -> str:
    """Map a Chinese numeral 一 to 六 to its ASCII digit; return other values unchanged."""
    return _CHINESE_CHAPTER_DIGITS.get(value, value)