# Files
# DEMO-AGENT/review_agent/regulatory_review/workflow.py
#
# 209 lines
# 7.7 KiB
# Python

from __future__ import annotations
import logging
from pathlib import Path
from threading import Thread
from uuid import uuid4
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from review_agent.models import (
Conversation,
FileSummaryBatch,
Message,
RegulatoryReviewBatch,
WorkflowNodeRun,
)
from review_agent.regulatory_review.services.completeness_check import run_completeness_check
from review_agent.regulatory_review.services.consistency_check import run_consistency_check
from review_agent.regulatory_review.services.export import build_assistant_summary, export_review_results
from review_agent.regulatory_review.services.risk_assess import persist_findings
from review_agent.regulatory_review.services.rule_loader import load_rule_file
from review_agent.regulatory_review.services.structure_check import run_structure_check
from review_agent.regulatory_review.services.text_extract import extract_text
from .events import record_event
# Workflow steps as (node_code, node_name, node_group) tuples.
# create_regulatory_review_batch() creates one WorkflowNodeRun row per entry,
# in this order, and the executor later iterates those rows ordered by id.
# "prepare" and "completed" have no handler in the executor, so they only
# record progress.
NODE_DEFINITIONS = [
    ("prepare", "准备", "prepare"),
    ("rule_scope", "规则范围", "rule_scope"),
    ("completeness_check", "完整性核查", "completeness_check"),
    ("text_extract", "文本抽取", "text_extract"),
    ("structure_check", "章节核查", "structure_check"),
    ("consistency_check", "一致性核查", "consistency_check"),
    ("risk_assess", "风险评估", "risk_assess"),
    ("report_export", "报告输出", "report_export"),
    ("completed", "完成", "completed"),
]
# Module logger; used by the executor to report workflow failures.
logger = logging.getLogger("review_agent.regulatory_review.workflow")
def build_batch_no() -> str:
    """Build a batch number shaped like ``RR-<YYYYmmddHHMMSS>-<6 hex chars>``.

    The timestamp is taken from ``timezone.localtime()`` and the suffix is the
    first six hex digits of a freshly generated UUID4.
    """
    timestamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return f"RR-{timestamp}-{suffix}"
def build_batch_work_dir(batch_no: str) -> Path:
    """Return ``<MEDIA_ROOT>/regulatory_review/work/<batch_no>``.

    Only the path is computed here; the directory is not created.
    """
    media_root = Path(settings.MEDIA_ROOT)
    return media_root.joinpath("regulatory_review", "work", batch_no)
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest successful file-summary batch of *conversation*.

    Candidates are ordered by ``finished_at``, then ``created_at``, then ``id``,
    all descending. Returns ``None`` when the conversation has no batch with
    status ``SUCCESS``.
    """
    successful = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
@transaction.atomic
def create_regulatory_review_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
) -> RegulatoryReviewBatch:
    """Create a regulatory review batch together with its workflow node rows.

    The database writes run inside one transaction. The batch work directory
    is created on disk first (existing directories are accepted).

    Args:
        conversation: Conversation the review belongs to.
        user: User stored on the batch.
        source_summary_batch: File-summary batch whose files are reviewed.
        trigger_message: Optional message stored on the batch as its trigger.

    Returns:
        The newly created ``RegulatoryReviewBatch``.
    """
    new_batch_no = build_batch_no()
    batch_dir = build_batch_work_dir(new_batch_no)
    batch_dir.mkdir(parents=True, exist_ok=True)

    batch_fields = {
        "conversation": conversation,
        "user": user,
        "trigger_message": trigger_message,
        "source_summary_batch": source_summary_batch,
        "batch_no": new_batch_no,
        "work_dir": str(batch_dir),
    }
    batch = RegulatoryReviewBatch.objects.create(**batch_fields)

    # One node row per definition, created in definition order.
    for node_code, node_name, node_group in NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type="regulatory_review",
            workflow_batch_id=batch.pk,
            node_group=node_group,
            node_code=node_code,
            node_name=node_name,
        )

    created_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", created_payload)
    return batch
class RegulatoryWorkflowExecutor:
    """Run the workflow nodes of one regulatory review batch in order.

    State shared between nodes is kept in memory on the executor: the loaded
    rule set, the findings collected by the check nodes, and the text
    extracted from the source files.
    """

    def __init__(self, batch: RegulatoryReviewBatch):
        self.batch = batch
        # Set by the "rule_scope" node, or lazily by ``_rules()``.
        self.rule_set: dict | None = None
        # Extended by the check nodes; persisted by the "risk_assess" node.
        self.findings = []
        # File name -> extracted text; filled by the "text_extract" node.
        self.document_texts: dict[str, str] = {}

    def run(self) -> None:
        """Execute every node and record the final batch status.

        The batch is marked RUNNING, each node is run in turn, and the batch
        ends as SUCCESS. If a node raises, the exception is logged, the batch
        is marked FAILED with an error message, a ``workflow_failed`` event is
        recorded, and the remaining nodes are not run. The exception is not
        re-raised.
        """
        self.batch.status = RegulatoryReviewBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})

        try:
            for node in self._nodes():
                self._run_node(node)
        except Exception as exc:
            logger.exception("Regulatory workflow failed", extra={"batch_id": self.batch.pk})
            error_text = self._describe_error(exc)
            self.batch.status = RegulatoryReviewBatch.Status.FAILED
            self.batch.error_message = error_text
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": error_text})
            return

        self.batch.status = RegulatoryReviewBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty message for *exc*.

        ``str(exc)`` is empty for exceptions raised without arguments; the
        exception class name is used in that case so a failed batch never
        carries a blank error message.
        """
        return str(exc) or type(exc).__name__

    def _nodes(self):
        """Return this batch's node rows ordered by id."""
        return WorkflowNodeRun.objects.filter(
            workflow_type="regulatory_review",
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _emit_node_progress(self, node: WorkflowNodeRun) -> None:
        """Record a ``node_progress`` event describing the node's current state."""
        record_event(
            self.batch,
            "node_progress",
            {
                "node_code": node.node_code,
                "status": node.status,
                "progress": node.progress,
                "message": node.message,
            },
        )

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark *node* as running, execute it, then mark it as successful.

        A progress event is recorded before and after execution. Exceptions
        from the node's work propagate to the caller.
        """
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        self._emit_node_progress(node)

        # NOTE(review): if this raises, the node row is left in RUNNING state
        # (only the batch is marked FAILED by run()). Marking the node itself
        # as failed needs a failure member on WorkflowNodeRun.Status, which is
        # not visible from this module -- confirm and handle there.
        self._execute_node(node.node_code)

        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        self._emit_node_progress(node)

    def _execute_node(self, node_code: str) -> None:
        """Run the work associated with *node_code*.

        Codes without a handler (for example "prepare" and "completed") do
        nothing.
        """
        handlers = {
            "rule_scope": self._load_rule_scope,
            "completeness_check": self._check_completeness,
            "text_extract": self._extract_texts,
            "structure_check": self._check_structure,
            "consistency_check": self._check_consistency,
            "risk_assess": self._persist_findings,
            "report_export": self._export_report,
        }
        handler = handlers.get(node_code)
        if handler is not None:
            handler()

    def _load_rule_scope(self) -> None:
        """Load the rule file, replacing any previously loaded rule set."""
        self.rule_set = load_rule_file()

    def _check_completeness(self) -> None:
        """Append the completeness-check findings for the source summary batch."""
        self.findings.extend(run_completeness_check(self.batch.source_summary_batch, self._rules()))

    def _extract_texts(self) -> None:
        """Extract and store the text of the source files."""
        self.document_texts = self._extract_source_texts()

    def _check_structure(self) -> None:
        """Append the structure-check findings for the extracted texts."""
        self.findings.extend(run_structure_check(self.document_texts, self._rules()))

    def _check_consistency(self) -> None:
        """Append the consistency-check findings for the extracted texts."""
        self.findings.extend(run_consistency_check(self.document_texts))

    def _persist_findings(self) -> None:
        """Persist all findings collected so far for this batch."""
        persist_findings(self.batch, self.findings)

    def _export_report(self) -> None:
        """Export the review results and post an assistant summary message."""
        exports = export_review_results(self.batch)
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, exports),
        )

    def _rules(self) -> dict:
        """Return the rule set, loading it on first use if no node loaded it."""
        if self.rule_set is None:
            self.rule_set = load_rule_file()
        return self.rule_set

    def _extract_source_texts(self) -> dict[str, str]:
        """Extract text from each source file, keyed by file name.

        Items are visited in ``file_index`` order. Relative storage paths are
        resolved against ``MEDIA_ROOT``. Files that are missing, or for which
        extraction does not succeed with non-empty text, are left out of the
        result; each such skip is logged so incomplete input is traceable.
        """
        texts: dict[str, str] = {}
        for item in self.batch.source_summary_batch.items.order_by("file_index"):
            path = Path(item.storage_path)
            if not path.is_absolute():
                path = Path(settings.MEDIA_ROOT) / item.storage_path
            if not path.exists():
                logger.warning(
                    "Skipping source file %s: %s does not exist",
                    item.file_name,
                    path,
                )
                continue
            result = extract_text(path)
            if result.status == "success" and result.text:
                texts[item.file_name] = result.text
            else:
                logger.warning(
                    "No text extracted from source file %s (status=%s)",
                    item.file_name,
                    result.status,
                )
        return texts
def start_regulatory_review_workflow(batch: RegulatoryReviewBatch, *, async_run: bool = True) -> None:
    """Run the regulatory review workflow for *batch*.

    Args:
        batch: The batch whose nodes should be executed.
        async_run: When true (the default) the workflow runs in a daemon
            thread and this function returns immediately. When false it runs
            in the calling thread and returns once the workflow has finished.
    """
    executor = RegulatoryWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_and_release_connections() -> None:
        # Imported locally so this block brings its own dependency into scope.
        from django.db import connections

        try:
            executor.run()
        finally:
            # Django database connections are per-thread and are only cleaned
            # up automatically around the request/response cycle, so the
            # worker thread closes the ones it opened before it exits.
            connections.close_all()

    Thread(target=_run_and_release_connections, daemon=True).start()