563 lines
22 KiB
Python
563 lines
22 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
from threading import Thread
|
||
from uuid import uuid4
|
||
|
||
from django.conf import settings
|
||
from django.db import transaction
|
||
from django.utils import timezone
|
||
|
||
from review_agent.models import (
|
||
Conversation,
|
||
FileSummaryBatch,
|
||
Message,
|
||
RegulatoryReviewBatch,
|
||
WorkflowNodeRun,
|
||
)
|
||
from review_agent.regulatory_review.services.completeness_check import run_completeness_check
|
||
from review_agent.regulatory_review.services.consistency_check import run_consistency_check
|
||
from review_agent.regulatory_review.services.export import build_assistant_summary, export_review_results
|
||
from review_agent.regulatory_review.services.feishu_notifier import create_mock_notifications
|
||
from review_agent.regulatory_review.services.info_extract import detect_regulatory_condition_candidates
|
||
from review_agent.regulatory_review.services.llm_review import review_condition_fields, review_workflow_payload
|
||
from review_agent.regulatory_review.services.risk_assess import persist_findings
|
||
from review_agent.regulatory_review.services.rule_loader import load_rule_file
|
||
from review_agent.regulatory_review.services.structure_check import run_structure_check
|
||
from review_agent.regulatory_review.services.text_extract import extract_text
|
||
|
||
from .events import record_event
|
||
from .storage import save_artifact
|
||
|
||
|
||
# Workflow steps in execution order, each as (node_code, node_name, node_group).
# One WorkflowNodeRun row is created per entry when a batch is created.
NODE_DEFINITIONS = [
    ("prepare", "准备", "prepare"),
    ("condition_confirm", "适用条件确认", "condition_confirm"),
    ("rule_scope", "规则范围", "rule_scope"),
    ("completeness_check", "完整性核查", "completeness_check"),
    ("text_extract", "文本抽取", "text_extract"),
    ("structure_check", "章节核查", "structure_check"),
    ("consistency_check", "一致性核查", "consistency_check"),
    ("risk_assess", "风险评估", "risk_assess"),
    ("report_export", "报告输出", "report_export"),
    ("completed", "完成", "completed"),
]

# Module logger shared by every workflow function and the executor below.
logger = logging.getLogger("review_agent.regulatory_review.workflow")

# Chapter number (as a string) -> display label for the chapters that a
# review can be scoped to.
ATTACHMENT4_CHAPTER_LABELS = {
    "1": "第1章 监管信息",
    "2": "第2章 综述资料",
    "3": "第3章 非临床资料",
    "4": "第4章 临床评价资料",
    "5": "第5章 产品说明书和标签样稿",
    "6": "第6章 质量管理体系文件",
}
|
||
|
||
|
||
class WorkflowPausedForUser(Exception):
    """Raised to stop the workflow run while it waits for user input."""
|
||
|
||
|
||
def build_batch_no() -> str:
    """Return a batch number of the form ``RR-<local timestamp>-<6 hex chars>``."""
    timestamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    random_suffix = uuid4().hex[:6]
    return f"RR-{timestamp}-{random_suffix}"
|
||
|
||
|
||
def build_batch_work_dir(batch_no: str) -> Path:
    """Return ``<MEDIA_ROOT>/regulatory_review/work/<batch_no>`` (not created here)."""
    media_root = Path(settings.MEDIA_ROOT)
    return media_root.joinpath("regulatory_review", "work", batch_no)
|
||
|
||
|
||
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest successful summary batch of *conversation*, or ``None``.

    "Newest" is decided by ``finished_at``, then ``created_at``, then ``id``,
    all descending.
    """
    successful = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
|
||
|
||
|
||
@transaction.atomic
def create_regulatory_review_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
) -> RegulatoryReviewBatch:
    """Create a review batch, its work directory and one node row per workflow step.

    The database writes run inside a single transaction. The initial
    ``condition_json`` is derived from the trigger message, and a
    ``workflow_created`` event is recorded before the batch is returned.
    """
    batch_no = build_batch_no()

    # NOTE: the directory is created on disk immediately; a rolled-back
    # transaction does not remove it.
    work_dir = build_batch_work_dir(batch_no)
    work_dir.mkdir(parents=True, exist_ok=True)

    initial_condition = _initial_condition_json(trigger_message)
    batch = RegulatoryReviewBatch.objects.create(
        conversation=conversation,
        user=user,
        trigger_message=trigger_message,
        source_summary_batch=source_summary_batch,
        batch_no=batch_no,
        work_dir=str(work_dir),
        condition_json=initial_condition,
    )

    # Rows are created in NODE_DEFINITIONS order.
    for node_code, node_name, node_group in NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type="regulatory_review",
            workflow_batch_id=batch.pk,
            node_group=node_group,
            node_code=node_code,
            node_name=node_name,
        )

    event_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", event_payload)
    return batch
|
||
|
||
|
||
class RegulatoryWorkflowExecutor:
    """Run one regulatory-review batch through its persisted workflow nodes.

    Findings, extracted document texts, per-file extraction status and LLM
    review results are accumulated on the instance while the nodes run.
    """

    def __init__(self, batch: RegulatoryReviewBatch):
        self.batch = batch
        # Scoped rule set; filled by the ``rule_scope`` node or lazily by ``_rules``.
        self.rule_set: dict | None = None
        # Findings collected by the check nodes and persisted by ``risk_assess``.
        self.findings = []
        # file_name -> extracted text, only for successfully extracted files.
        self.document_texts: dict[str, str] = {}
        # file_name -> extraction status record, for every source file.
        self.text_extract_status: dict[str, dict[str, object]] = {}
        # stage -> result returned by ``review_workflow_payload``.
        self.llm_reviews: dict[str, dict[str, object]] = {}

    def run(self) -> None:
        """Execute every node that has not succeeded yet and record the outcome.

        The batch is marked SUCCESS when all nodes finish. If
        ``WorkflowPausedForUser`` is raised the method returns without
        touching the batch status again. Any other exception marks the batch
        FAILED and stores the error message. No exception is propagated.
        """
        batch = self.batch
        logger.info("法规核查工作流开始 batch_no=%s batch_id=%s", batch.batch_no, batch.pk)
        batch.status = RegulatoryReviewBatch.Status.RUNNING
        batch.started_at = timezone.now()
        batch.save(update_fields=["status", "started_at"])
        record_event(batch, "workflow_started", {"batch_id": batch.pk})

        try:
            for node in self._nodes():
                # Nodes that already succeeded are skipped on this run.
                if node.status == WorkflowNodeRun.Status.SUCCESS:
                    continue
                self._run_node(node)
        except WorkflowPausedForUser:
            logger.info("法规核查工作流等待用户 batch_no=%s node=condition_confirm", batch.batch_no)
            return
        except Exception as exc:
            # Top-level boundary: record the failure instead of raising.
            # NOTE(review): the node that raised is left in RUNNING state;
            # consider marking it failed if WorkflowNodeRun.Status has such a value.
            logger.exception("Regulatory workflow failed", extra={"batch_id": batch.pk})
            batch.status = RegulatoryReviewBatch.Status.FAILED
            batch.error_message = str(exc)
            batch.finished_at = timezone.now()
            batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(batch, "workflow_failed", {"message": str(exc)})
            return

        batch.status = RegulatoryReviewBatch.Status.SUCCESS
        batch.finished_at = timezone.now()
        batch.save(update_fields=["status", "finished_at"])
        record_event(batch, "workflow_completed", {"batch_id": batch.pk})
        logger.info("法规核查工作流完成 batch_no=%s findings=%s", batch.batch_no, len(self.findings))

    def _nodes(self):
        """Return this batch's node rows ordered by ``id``."""
        return WorkflowNodeRun.objects.filter(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _condition(self) -> dict:
        """Return the batch's ``condition_json``, treating ``None`` as empty."""
        return self.batch.condition_json or {}

    def _rule_scope(self) -> dict:
        """Return the ``rule_scope`` entry of the condition data, or ``{}``."""
        return self._condition().get("rule_scope") or {}

    def _record_node_progress(self, node: WorkflowNodeRun, extra: dict | None = None) -> None:
        """Record a ``node_progress`` event describing the node's current state."""
        payload = {
            "node_code": node.node_code,
            "status": node.status,
            "progress": node.progress,
            "message": node.message,
        }
        if extra:
            payload.update(extra)
        record_event(self.batch, "node_progress", payload)

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark *node* running, execute it, then mark it successful.

        If ``_execute_node`` raises, the success bookkeeping is skipped and
        the exception propagates to ``run``.
        """
        logger.info(
            "节点开始 batch_no=%s node=%s name=%s",
            self.batch.batch_no,
            node.node_code,
            node.node_name,
        )
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        self._record_node_progress(node)

        self._execute_node(node)

        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        self._record_node_progress(node)
        logger.info(
            "节点完成 batch_no=%s node=%s name=%s progress=%s",
            self.batch.batch_no,
            node.node_code,
            node.node_name,
            node.progress,
        )

    def _update_node_progress(
        self,
        node: WorkflowNodeRun,
        *,
        processed: int,
        total: int,
        message: str,
    ) -> None:
        """Persist and publish intermediate progress for a running node.

        Progress is mapped onto the 10..95 range. Nothing happens when
        *total* is not positive.
        """
        if total <= 0:
            return
        progress = min(95, 10 + int((max(processed, 0) / total) * 85))
        node.progress = progress
        node.message = message
        node.save(update_fields=["progress", "message"])
        self._record_node_progress(node, {"processed": processed, "total": total})
        logger.info(
            "节点进度 batch_no=%s node=%s progress=%s processed=%s total=%s message=%s",
            self.batch.batch_no,
            node.node_code,
            progress,
            processed,
            total,
            message,
        )

    def _progress_callback(self, node: WorkflowNodeRun, stage_label: str):
        """Build the ``progress_callback`` handed to a check service.

        The callback receives an update mapping and reads its ``processed``,
        ``total``, ``label`` and ``finding_count`` keys.
        """

        def report(update) -> None:
            self._update_node_progress(
                node,
                processed=int(update.get("processed") or 0),
                total=int(update.get("total") or 0),
                message=(
                    f"{stage_label} {update.get('processed')}/{update.get('total')}:"
                    f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题"
                ),
            )

        return report

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Dispatch *node* to its handler; codes without a handler do nothing."""
        handlers = {
            "condition_confirm": lambda _node: self._pause_for_condition_confirmation(),
            "rule_scope": self._node_rule_scope,
            "completeness_check": self._node_completeness_check,
            "text_extract": self._node_text_extract,
            "structure_check": self._node_structure_check,
            "consistency_check": self._node_consistency_check,
            "risk_assess": self._node_risk_assess,
            "report_export": self._node_report_export,
        }
        handler = handlers.get(node.node_code)
        if handler is not None:
            handler(node)

    def _node_rule_scope(self, node: WorkflowNodeRun) -> None:
        """Reload the rule file and narrow it to the batch's rule scope."""
        scope = self._rule_scope()
        self.rule_set = apply_rule_scope(load_rule_file(), scope)
        logger.info(
            "方法执行 batch_no=%s method=apply_rule_scope requirements=%s scope=%s",
            self.batch.batch_no,
            len(self.rule_set.get("requirements", [])),
            scope,
        )

    def _node_completeness_check(self, node: WorkflowNodeRun) -> None:
        """Run the completeness check against the source summary batch."""
        findings = run_completeness_check(
            self.batch.source_summary_batch,
            self._rules(),
            progress_callback=self._progress_callback(node, "完整性核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_completeness_check findings=%s source_summary=%s",
            self.batch.batch_no,
            len(findings),
            self.batch.source_summary_batch.batch_no,
        )
        self._save_llm_review(
            "completeness_check",
            {
                "findings": [finding.to_dict() for finding in findings],
                "rules_count": len(self._rules().get("requirements", [])),
            },
        )

    def _node_text_extract(self, node: WorkflowNodeRun) -> None:
        """Extract text from every source file and save the status artifact."""
        self.document_texts = self._extract_source_texts(node)
        logger.info(
            "方法执行 batch_no=%s method=_extract_source_texts success_docs=%s total_files=%s",
            self.batch.batch_no,
            len(self.document_texts),
            len(self.text_extract_status),
        )
        self._save_llm_review("text_extract", {"files": self.text_extract_status})
        save_artifact(
            self.batch,
            name="text_extract_status.json",
            artifact_type="json",
            content=json.dumps(self.text_extract_status, ensure_ascii=False, indent=2),
            metadata={"artifact": "text_extract_status"},
        )

    def _node_structure_check(self, node: WorkflowNodeRun) -> None:
        """Run the structure check over the extracted document texts."""
        findings = run_structure_check(
            self.document_texts,
            self._rules(),
            progress_callback=self._progress_callback(node, "章节核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_structure_check findings=%s docs=%s",
            self.batch.batch_no,
            len(findings),
            len(self.document_texts),
        )
        self._save_llm_review("structure_check", {"findings": [finding.to_dict() for finding in findings]})

    def _node_consistency_check(self, node: WorkflowNodeRun) -> None:
        """Run the consistency check over the extracted document texts."""
        findings = run_consistency_check(
            self.document_texts,
            progress_callback=self._progress_callback(node, "一致性核查"),
        )
        self.findings.extend(findings)
        logger.info(
            "方法执行 batch_no=%s method=run_consistency_check findings=%s docs=%s",
            self.batch.batch_no,
            len(findings),
            len(self.document_texts),
        )
        self._save_llm_review("consistency_check", {"findings": [finding.to_dict() for finding in findings]})

    def _node_risk_assess(self, node: WorkflowNodeRun) -> None:
        """Persist all collected findings and save the combined result artifact."""
        self._save_llm_review("risk_assess", {"findings": [finding.to_dict() for finding in self.findings]})
        issues = persist_findings(self.batch, self.findings)
        create_mock_notifications(self.batch)
        logger.info(
            "方法执行 batch_no=%s method=persist_findings issues=%s findings=%s",
            self.batch.batch_no,
            len(issues),
            len(self.findings),
        )
        result = {
            "batch_no": self.batch.batch_no,
            "text_extract_status": self.text_extract_status,
            "issues": [
                {
                    "rule_code": issue.rule_code,
                    "title": issue.title,
                    "citations": issue.citations,
                }
                for issue in issues
            ],
            "llm_reviews": self.llm_reviews,
        }
        save_artifact(
            self.batch,
            name="rag_result_json.json",
            artifact_type="json",
            content=json.dumps(result, ensure_ascii=False, indent=2),
            metadata={"artifact": "rag_result_json"},
        )

    def _node_report_export(self, node: WorkflowNodeRun) -> None:
        """Export the review results and post an assistant summary message."""
        exports = export_review_results(self.batch)
        logger.info(
            "方法执行 batch_no=%s method=export_review_results exports=%s",
            self.batch.batch_no,
            len(exports),
        )
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, exports),
        )

    def _pause_for_condition_confirmation(self) -> None:
        """Pause the workflow until the applicable conditions are confirmed.

        Returns immediately when ``condition_json`` is already confirmed.
        Otherwise stores the detected candidates, puts the batch and the
        ``condition_confirm`` node into WAITING_USER, records a
        ``waiting_user`` event and raises ``WorkflowPausedForUser``.

        Raises:
            WorkflowPausedForUser: when confirmation is still outstanding.
        """
        if self._condition().get("confirmed"):
            return
        candidates = detect_regulatory_condition_candidates(self.batch.source_summary_batch)
        logger.info(
            "方法执行 batch_no=%s method=detect_regulatory_condition_candidates product_category=%s product_name=%s",
            self.batch.batch_no,
            (candidates.get("product_category") or {}).get("suggested"),
            (candidates.get("product_name") or {}).get("suggested"),
        )
        self.batch.condition_json = {
            **self._condition(),
            "confirmed": False,
            "resume_from": "rule_scope",
            "candidates": candidates,
        }
        self.batch.status = RegulatoryReviewBatch.Status.WAITING_USER
        self.batch.save(update_fields=["status", "condition_json"])

        # A fresh row is loaded here; the instance held by ``_run_node`` is
        # not written again because the exception below skips its success path.
        waiting_node = WorkflowNodeRun.objects.get(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
            node_code="condition_confirm",
        )
        waiting_node.status = WorkflowNodeRun.Status.WAITING_USER
        waiting_node.progress = 50
        waiting_node.message = "请确认产品类别、注册类型、临床评价路径等适用条件"
        waiting_node.save(update_fields=["status", "progress", "message"])
        record_event(
            self.batch,
            "waiting_user",
            {"node_code": "condition_confirm", "candidates": candidates, "resume_from": "rule_scope"},
        )
        raise WorkflowPausedForUser()

    def _rules(self) -> dict:
        """Return the scoped rule set, loading it on first use."""
        if self.rule_set is None:
            self.rule_set = apply_rule_scope(load_rule_file(), self._rule_scope())
        return self.rule_set

    def _extract_source_texts(self, node: WorkflowNodeRun | None = None) -> dict[str, str]:
        """Extract text from each source file, in ``file_index`` order.

        A status record is stored in ``self.text_extract_status`` for every
        file, keyed by file name. When *node* is given, progress is reported
        after each file.

        Returns:
            Mapping of file name to extracted text, containing only files
            whose extraction status is ``"success"`` with non-empty text.
        """
        texts: dict[str, str] = {}
        items = list(self.batch.source_summary_batch.items.order_by("file_index"))
        total = len(items)

        def report(position: int, file_name: str, outcome: str) -> None:
            if node:
                self._update_node_progress(
                    node,
                    processed=position,
                    total=total,
                    message=f"文本抽取 {position}/{total}:{file_name}({outcome})",
                )

        for index, item in enumerate(items, start=1):
            # Relative storage paths are resolved against MEDIA_ROOT.
            path = Path(item.storage_path)
            if not path.is_absolute():
                path = Path(settings.MEDIA_ROOT) / item.storage_path

            if not path.exists():
                logger.info("文本抽取跳过 batch_no=%s file=%s reason=missing", self.batch.batch_no, item.file_name)
                self.text_extract_status[item.file_name] = {
                    "status": "missing",
                    "path": str(path),
                    "content_hash": "",
                    "section_candidates": [],
                    "field_candidates": {},
                    "front_text": "",
                }
                report(index, item.file_name, "文件不存在")
                continue

            result = extract_text(path)
            field_review = review_condition_fields(
                text=result.front_text or result.text,
                rule_fields=result.field_candidates or {},
                file_context=f"{item.directory_level}\n{item.file_name}\n{item.relative_path}",
            )
            self.text_extract_status[item.file_name] = {
                "status": result.status,
                "path": str(path),
                "content_hash": result.content_hash,
                "section_candidates": result.section_candidates,
                # Reviewed fields take precedence over the raw extraction.
                "field_candidates": field_review.get("selected_fields", result.field_candidates),
                "field_review": field_review,
                "front_text": result.front_text,
                "error_message": result.error_message,
            }
            if result.status == "success" and result.text:
                texts[item.file_name] = result.text
            logger.info(
                "文本抽取文件 batch_no=%s file=%s status=%s fields=%s chars=%s",
                self.batch.batch_no,
                item.file_name,
                result.status,
                len((field_review.get("selected_fields") or {})),
                len(result.text or ""),
            )
            report(index, item.file_name, result.status)
        return texts

    def _save_llm_review(self, stage: str, payload: dict[str, object]) -> dict[str, object]:
        """Run the LLM review for *stage*, cache it and save it as an artifact.

        Returns:
            The review returned by ``review_workflow_payload``.
        """
        review = review_workflow_payload(stage=stage, payload=payload)
        self.llm_reviews[stage] = review
        logger.info(
            "方法执行 batch_no=%s method=review_workflow_payload stage=%s status=%s",
            self.batch.batch_no,
            stage,
            review.get("status"),
        )
        save_artifact(
            self.batch,
            name=f"llm_review_{stage}.json",
            artifact_type="json",
            content=json.dumps(review, ensure_ascii=False, indent=2),
            metadata={"artifact": "llm_review", "stage": stage},
        )
        return review
|
||
|
||
|
||
def start_regulatory_review_workflow(batch: RegulatoryReviewBatch, *, async_run: bool = True) -> None:
    """Run the workflow for *batch*, inline or on a daemon thread.

    Args:
        batch: The review batch to execute.
        async_run: When true (the default) the executor runs on a daemon
            thread and this function returns immediately; otherwise it runs
            in the calling thread.
    """
    executor = RegulatoryWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_in_worker() -> None:
        # Imported locally so the block does not depend on a new top-level import.
        from django.db import connections

        try:
            executor.run()
        finally:
            # Database connections are per thread and are not closed for us
            # when a manually started thread ends, so release them here.
            connections.close_all()

    Thread(target=_run_in_worker, daemon=True).start()
|
||
|
||
|
||
def _initial_condition_json(trigger_message: Message | None) -> dict:
    """Build the initial ``condition_json`` from the trigger message text.

    Returns ``{"rule_scope": scope}`` when a chapter scope is detected in
    the message content, otherwise an empty dict.
    """
    content = trigger_message.content if trigger_message else ""
    scope = detect_attachment4_chapter_scope(content)
    if scope:
        return {"rule_scope": scope}
    return {}
|
||
|
||
|
||
def detect_attachment4_chapter_scope(content: str) -> dict[str, str] | None:
    """Detect a chapter scope mentioned in *content*.

    Returns:
        ``{"attachment4_chapter": <number>, "label": <label>}`` when the text
        names a chapter listed in ``ATTACHMENT4_CHAPTER_LABELS``; ``None``
        for empty text or when no such chapter is found.
    """
    text = (content or "").strip()
    if not text:
        return None

    chapter = _extract_chapter_number(text)
    try:
        label = ATTACHMENT4_CHAPTER_LABELS[chapter]
    except KeyError:
        # Covers both "nothing found" ("") and chapters outside the table.
        return None
    return {"attachment4_chapter": chapter, "label": label}
|
||
|
||
|
||
def apply_rule_scope(rule_set: dict, rule_scope: dict) -> dict:
    """Narrow *rule_set* to the chapter named in *rule_scope*.

    When the scope names no known chapter, *rule_set* itself is returned.
    Otherwise a shallow copy is returned whose ``requirements`` keep only the
    entries belonging to that chapter, with the scope stored under
    ``active_rule_scope``. The input rule set is not modified.
    """
    chapter = str(rule_scope.get("attachment4_chapter") or "")
    if chapter in ATTACHMENT4_CHAPTER_LABELS:
        matching = [
            entry
            for entry in rule_set.get("requirements", [])
            if _requirement_in_chapter(entry, chapter)
        ]
        return {**rule_set, "requirements": matching, "active_rule_scope": rule_scope}
    return rule_set
|
||
|
||
|
||
def _requirement_in_chapter(requirement: dict, chapter: str) -> bool:
|
||
attachment4_code = str(requirement.get("attachment4_code") or "")
|
||
return attachment4_code == chapter or attachment4_code.startswith(f"{chapter}.")
|
||
|
||
|
||
def _extract_chapter_number(content: str) -> str:
|
||
match = re.search(r"第\s*([一二三四五六1-6])\s*[章节张]", content)
|
||
if match:
|
||
return _normalize_chapter_number(match.group(1))
|
||
match = re.search(r"(^|[^\d])([1-6])\s*[章节张]", content)
|
||
if match:
|
||
return match.group(2)
|
||
return ""
|
||
|
||
|
||
def _normalize_chapter_number(value: str) -> str:
|
||
chinese = {"一": "1", "二": "2", "三": "3", "四": "4", "五": "5", "六": "6"}
|
||
return chinese.get(value, value)
|