339 lines
16 KiB
Python
339 lines
16 KiB
Python
from __future__ import annotations
|
||
|
||
import logging
|
||
from threading import Thread
|
||
from uuid import uuid4
|
||
|
||
from django.conf import settings
|
||
from django.db import transaction
|
||
from django.utils import timezone
|
||
|
||
from review_agent.file_summary.paths import resolve_storage_path
|
||
from review_agent.models import (
|
||
Conversation,
|
||
ExportedSummaryFile,
|
||
Message,
|
||
RegulatoryInfoPackageArtifact,
|
||
RegulatoryInfoPackageBatch,
|
||
RegulatoryInfoPackageNotificationRecord,
|
||
WorkflowNodeRun,
|
||
)
|
||
from review_agent.regulatory_info_package.constants import (
|
||
DEFAULT_ZIP_NAME,
|
||
REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS,
|
||
WORKFLOW_TYPE,
|
||
)
|
||
from review_agent.regulatory_info_package.events import record_event
|
||
from review_agent.regulatory_info_package.services.template_config import (
|
||
compute_config_hash,
|
||
load_template_config,
|
||
validate_template_config,
|
||
)
|
||
from review_agent.regulatory_info_package.services.field_extract import run_parallel_extract, save_field_extract_result
|
||
from review_agent.regulatory_info_package.services.field_merge import merge_fields, save_merged_fields
|
||
from review_agent.regulatory_info_package.services.instruction_extract import parse_instruction_docx, save_instruction_extract_json
|
||
from review_agent.regulatory_info_package.services.package_generate import generate_package_documents
|
||
from review_agent.regulatory_info_package.services.summary import build_assistant_summary
|
||
from review_agent.regulatory_info_package.services.traceability_export import save_traceability_exports
|
||
from review_agent.regulatory_info_package.services.zip_export import create_zip_package
|
||
from review_agent.regulatory_info_package.schemas import GeneratedFileResult, InstructionExtractResult, MergedField
|
||
from review_agent.regulatory_info_package.storage import build_batch_work_dir
|
||
from review_agent.regulatory_info_package.storage import create_artifact_for_file, ensure_batch_subdir
|
||
|
||
|
||
logger = logging.getLogger("review_agent.regulatory_info_package.workflow")
|
||
|
||
|
||
def build_batch_no() -> str:
    """Return a new batch number shaped like ``RIP-<YYYYmmddHHMMSS>-<6 hex chars>``.

    The timestamp is taken from ``timezone.localtime()`` and the suffix is the
    first six hex characters of a fresh ``uuid4``.
    """
    stamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return "-".join(("RIP", stamp, suffix))
|
||
|
||
|
||
@transaction.atomic
def create_regulatory_info_package_batch(
    *,
    conversation: Conversation,
    user,
    trigger_message: Message | None = None,
    source_attachment=None,
    source_summary_batch=None,
    source_summary_item_id: int | None = None,
    source_file_name: str = "",
    source_storage_path: str = "",
    existing_batch: RegulatoryInfoPackageBatch | None = None,
) -> RegulatoryInfoPackageBatch:
    """Create (or reuse) a batch and make sure its workflow node rows exist.

    When ``existing_batch`` is given it is used as-is and none of the other
    batch-field arguments are applied to it. Otherwise a new batch number and
    work directory are built, the directory is created on disk, and a new
    ``RegulatoryInfoPackageBatch`` row is inserted.

    In both cases one ``WorkflowNodeRun`` per entry of
    ``REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS`` is ensured via
    ``get_or_create`` and a ``workflow_created`` event is recorded.

    Args:
        conversation: Conversation the new batch is attached to.
        user: Owner of the new batch.
        trigger_message: Optional message stored on the new batch.
        source_attachment: Optional attachment; its ``original_name`` and
            ``storage_path`` attributes are used as fallbacks when the
            explicit name/path arguments are empty.
        source_summary_batch: Optional summary batch stored on the new batch.
        source_summary_item_id: Optional summary item id stored on the new batch.
        source_file_name: Explicit source file name; wins over the attachment.
        source_storage_path: Explicit storage path; wins over the attachment.
        existing_batch: Batch to reuse instead of creating a new one.

    Returns:
        The created or reused batch.
    """
    if existing_batch is not None:
        batch = existing_batch
    else:
        new_batch_no = build_batch_no()
        batch_dir = build_batch_work_dir(batch_no=new_batch_no)
        batch_dir.mkdir(parents=True, exist_ok=True)

        # Explicit arguments take precedence; the attachment is only a fallback.
        file_name = source_file_name or getattr(source_attachment, "original_name", "")
        storage_path = source_storage_path or getattr(source_attachment, "storage_path", "")

        batch = RegulatoryInfoPackageBatch.objects.create(
            conversation=conversation,
            user=user,
            trigger_message=trigger_message,
            source_attachment=source_attachment,
            source_summary_batch=source_summary_batch,
            source_summary_item_id=source_summary_item_id,
            source_file_name=file_name,
            source_storage_path=storage_path,
            batch_no=new_batch_no,
            output_zip_name=DEFAULT_ZIP_NAME,
            work_dir=str(batch_dir),
        )

    for node_code, node_name, node_group in REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS:
        node_defaults = {"node_group": node_group, "node_name": node_name}
        WorkflowNodeRun.objects.get_or_create(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=batch.pk,
            node_code=node_code,
            defaults=node_defaults,
        )

    created_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", created_payload)
    return batch
|
||
|
||
|
||
class RegulatoryInfoPackageWorkflowExecutor:
    """Runs the Chapter 1 regulatory information package workflow.

    The executor walks the batch's ``WorkflowNodeRun`` rows in ``id`` order and
    dispatches each one to a private ``_node_*`` handler keyed by
    ``node_code``. Outputs of earlier nodes (template config, parsed
    instruction, extracted/merged fields, generated files, exports) are kept
    on the instance and consumed by later nodes.

    NOTE(review): those intermediate outputs live only in memory. ``run``
    skips nodes already marked SUCCESS/SKIPPED, so on a re-run the outputs of
    skipped nodes stay at their empty defaults for the nodes that follow.
    """

    def __init__(self, batch: RegulatoryInfoPackageBatch):
        self.batch = batch
        self.template_config: dict = {}
        self.instruction: InstructionExtractResult | None = None
        self.extract_payload: dict = {}
        self.merged_fields: dict[str, MergedField] = {}
        self.merge_summary: dict[str, list[dict]] = {}
        self.generation_results: list[GeneratedFileResult] = []
        self.exports: list[ExportedSummaryFile] = []

    def run(self) -> None:
        """Execute all pending nodes and record the batch outcome.

        The batch is marked RUNNING, then every node that is not already
        SUCCESS or SKIPPED is run in order. Any exception stops the loop,
        marks the batch FAILED with the exception text, and is not re-raised.
        Otherwise the batch is marked SUCCESS. Events are recorded for start,
        failure and completion.
        """
        logger.info("监管信息材料包工作流开始 batch_no=%s batch_id=%s", self.batch.batch_no, self.batch.pk)
        self.batch.status = RegulatoryInfoPackageBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})
        try:
            for node in self._nodes():
                if node.status in {WorkflowNodeRun.Status.SUCCESS, WorkflowNodeRun.Status.SKIPPED}:
                    continue
                self._run_node(node)
        except Exception as exc:
            # Top-level boundary: log with traceback and persist the failure
            # instead of propagating (this may be running in a worker thread).
            logger.exception("Regulatory info package workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = RegulatoryInfoPackageBatch.Status.FAILED
            self.batch.error_message = str(exc)
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": str(exc)})
            return
        self.batch.status = RegulatoryInfoPackageBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})

    def _nodes(self):
        """Return this batch's workflow node rows ordered by primary key."""
        return WorkflowNodeRun.objects.filter(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark *node* RUNNING, execute it, then mark it SUCCESS.

        A ``node_progress`` event is recorded before and after execution.

        NOTE(review): if ``_execute_node`` raises, the node row keeps the
        RUNNING status written here; only the batch is marked FAILED by
        ``run``. Consider persisting a failed node state as well.
        """
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        record_event(self.batch, "node_progress", {"node_code": node.node_code, "status": node.status})

        self._execute_node(node)

        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        record_event(self.batch, "node_progress", {"node_code": node.node_code, "status": node.status})

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Dispatch *node* to the handler registered for its ``node_code``.

        Unknown codes are treated as no-ops (the caller still marks the node
        SUCCESS), but a warning is logged so a misconfigured node definition
        does not pass unnoticed.
        """
        handlers = {
            "prepare": self._node_prepare,
            "template_copy": self._node_noop,
            "text_extract": self._node_text_extract,
            "field_extract": self._node_field_extract,
            "field_merge": self._node_field_merge,
            "generate_docs": self._node_generate_docs,
            "highlight_review_items": self._node_noop,
            "trace_export": self._node_trace_export,
            "zip_export": self._node_zip_export,
            "notify": self._node_notify,
        }
        handler = handlers.get(node.node_code)
        if handler is None:
            logger.warning(
                "Unknown regulatory info package node code %r; nothing executed",
                node.node_code,
            )
            return
        handler(node)

    def _register_json_artifact(self, node: WorkflowNodeRun, path, artifact_type):
        """Register the JSON file at *path* as an artifact created by *node*."""
        return create_artifact_for_file(
            self.batch,
            path=path,
            artifact_type=artifact_type,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.JSON,
            created_by_node=node.node_code,
        )

    def _node_noop(self, node: WorkflowNodeRun) -> None:
        """Handle node codes that currently have no work to perform."""
        return

    def _node_prepare(self, node: WorkflowNodeRun) -> None:
        """Load and validate the template config and stamp it on the batch.

        Raises:
            ValueError: if validation reports errors (joined with ``;``).
        """
        self.template_config = load_template_config()
        errors = validate_template_config(self.template_config)
        if errors:
            raise ValueError(";".join(errors))
        self.batch.template_config_version = str(self.template_config.get("version") or "")
        self.batch.template_config_hash = compute_config_hash()
        self.batch.save(update_fields=["template_config_version", "template_config_hash"])

    def _node_text_extract(self, node: WorkflowNodeRun) -> None:
        """Parse the source instruction document, if the batch has one."""
        if not self.batch.source_storage_path:
            # No source document: downstream nodes see ``instruction is None``.
            self.instruction = None
            return
        path = resolve_storage_path(self.batch.source_storage_path)
        self.instruction = parse_instruction_docx(path)
        json_path = ensure_batch_subdir(self.batch, "logs") / "instruction_extract.json"
        save_instruction_extract_json(json_path, self.instruction)
        self._register_json_artifact(
            node,
            json_path,
            RegulatoryInfoPackageArtifact.ArtifactType.INSTRUCTION_EXTRACT,
        )

    def _node_field_extract(self, node: WorkflowNodeRun) -> None:
        """Extract fields from the parsed instruction, or store an empty payload."""
        if not self.instruction:
            self.extract_payload = {"regex_results": {}, "llm_results": {}, "llm_error": ""}
            return
        # The LLM extractor passed here is a stub that always returns an empty dict.
        self.extract_payload = run_parallel_extract(self.instruction, llm_extract_func=lambda _instruction: {})
        json_path = ensure_batch_subdir(self.batch, "logs") / "field_extract_result.json"
        save_field_extract_result(json_path, self.extract_payload)
        self._register_json_artifact(
            node,
            json_path,
            RegulatoryInfoPackageArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
        )

    def _node_field_merge(self, node: WorkflowNodeRun) -> None:
        """Merge regex and LLM results and copy the summary onto the batch."""
        self.merged_fields, self.merge_summary = merge_fields(
            self.extract_payload.get("regex_results") or {},
            self.extract_payload.get("llm_results") or {},
        )
        product = self.merged_fields.get("product_name")
        # "/" is not accepted as a product name; the batch keeps its current value.
        if product and product.value and product.value != "/":
            self.batch.product_name = product.value
        self.batch.missing_fields = self.merge_summary.get("missing_fields", [])
        self.batch.llm_only_fields = self.merge_summary.get("llm_only_fields", [])
        self.batch.conflict_fields = self.merge_summary.get("conflict_fields", [])
        self.batch.save(update_fields=["product_name", "missing_fields", "llm_only_fields", "conflict_fields"])
        json_path = ensure_batch_subdir(self.batch, "logs") / "merged_fields.json"
        save_merged_fields(json_path, self.merged_fields, self.merge_summary)
        self._register_json_artifact(
            node,
            json_path,
            RegulatoryInfoPackageArtifact.ArtifactType.MERGED_FIELDS,
        )

    def _node_generate_docs(self, node: WorkflowNodeRun) -> None:
        """Generate the package documents and register artifacts/exports.

        Every result (including ones without a file) is recorded in
        ``batch.generated_files``.
        """
        self.generation_results = generate_package_documents(self.batch, self.template_config, self.merged_fields)
        generated_files = []
        for result in self.generation_results:
            if result.path:
                artifact = create_artifact_for_file(
                    self.batch,
                    path=result.path,
                    artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.GENERATED_DOCUMENT,
                    file_format=result.actual_format,
                    name=result.template_code,
                    # Snapshot copy: the ids assigned below must not leak into
                    # the metadata dict already handed to the artifact.
                    metadata=dict(vars(result)),
                    created_by_node=node.node_code,
                )
                result.artifact_id = artifact.pk
                # NOTE(review): an export needs a file, so it is only created
                # for results that have a path — confirm this matches the
                # intended nesting.
                if result.status in {"success", "fallback_success"}:
                    export = self._create_export(
                        path=result.path,
                        export_type=ExportedSummaryFile.ExportType.WORD,
                        export_category="generated_document",
                    )
                    result.export_id = export.pk
                    self.exports.append(export)
            # Snapshot taken after the ids above were assigned.
            generated_files.append(dict(vars(result)))
        self.batch.generated_files = generated_files
        self.batch.save(update_fields=["generated_files"])

    def _node_trace_export(self, node: WorkflowNodeRun) -> None:
        """Write the traceability files and expose the Excel one as an export."""
        excel_path, json_path = save_traceability_exports(self.batch.work_dir, self.merged_fields)
        self._register_json_artifact(
            node,
            json_path,
            RegulatoryInfoPackageArtifact.ArtifactType.TRACEABILITY,
        )
        artifact = create_artifact_for_file(
            self.batch,
            path=excel_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.TRACEABILITY,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.EXCEL,
            created_by_node=node.node_code,
        )
        export = self._create_export(
            path=str(excel_path),
            export_type=ExportedSummaryFile.ExportType.EXCEL,
            export_category="traceability",
        )
        self.exports.append(export)
        artifact.metadata = {"export_id": export.pk}
        artifact.save(update_fields=["metadata"])

    def _node_zip_export(self, node: WorkflowNodeRun) -> None:
        """Bundle the generated files into a zip and expose it as an export."""
        zip_path = create_zip_package(self.batch.work_dir, self.generation_results, self.batch.output_zip_name)
        artifact = create_artifact_for_file(
            self.batch,
            path=zip_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.ZIP_PACKAGE,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.ZIP,
            created_by_node=node.node_code,
        )
        export = self._create_export(
            path=str(zip_path),
            export_type=ExportedSummaryFile.ExportType.ZIP,
            export_category="regulatory_info_package",
        )
        # The zip goes first so it leads the export list used by the notify node.
        self.exports.insert(0, export)
        artifact.metadata = {"export_id": export.pk}
        artifact.save(update_fields=["metadata"])

    def _node_notify(self, node: WorkflowNodeRun) -> None:
        """Store a notification record summarising exports and failed files."""
        export_entries = [
            {
                "file_name": export.file_name,
                "download_url": f"/api/review-agent/file-summary/exports/{export.pk}/download/",
                "export_type": export.export_type,
            }
            for export in self.exports
        ]
        # ``generated_files`` may be unset when generate_docs did not run here.
        failed_files = [
            item for item in (self.batch.generated_files or []) if item.get("status") == "failed"
        ]
        RegulatoryInfoPackageNotificationRecord.objects.create(
            batch=self.batch,
            recipient=self.batch.user,
            export_ids=[export.pk for export in self.exports],
            message_summary=build_assistant_summary(
                batch_no=self.batch.batch_no,
                exports=export_entries,
                failed_files=failed_files,
            ),
            send_status=RegulatoryInfoPackageNotificationRecord.SendStatus.SUCCESS,
        )

    def _create_export(self, *, path: str, export_type: str, export_category: str) -> ExportedSummaryFile:
        """Create an ``ExportedSummaryFile`` row for the file at *path*.

        The row is linked to this workflow through ``workflow_type`` and
        ``workflow_batch_id``; its ``batch`` field is left as ``None``.
        """
        from pathlib import Path

        resolved = Path(path)
        return ExportedSummaryFile.objects.create(
            batch=None,
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
            export_category=export_category,
            export_type=export_type,
            file_name=resolved.name,
            storage_path=str(resolved),
        )
|
||
|
||
|
||
def start_regulatory_info_package_workflow(
    batch: RegulatoryInfoPackageBatch,
    *,
    async_run: bool | None = None,
) -> None:
    """Run the workflow for *batch*, either inline or in a background thread.

    Args:
        batch: The batch whose workflow nodes should be executed.
        async_run: ``True`` runs the executor in a daemon thread and returns
            immediately; ``False`` runs it inline and returns when it is
            done. ``None`` falls back to the
            ``REGULATORY_INFO_PACKAGE_ASYNC`` setting (``True`` when unset).
    """
    if async_run is None:
        async_run = getattr(settings, "REGULATORY_INFO_PACKAGE_ASYNC", True)
    executor = RegulatoryInfoPackageWorkflowExecutor(batch)

    if not async_run:
        executor.run()
        return

    def _run_and_release_connections() -> None:
        # The worker thread runs outside Django's request/response cycle, so
        # the database connections it opens are not closed for us; release
        # them explicitly once the workflow is done, even if it raised.
        from django.db import connections

        try:
            executor.run()
        finally:
            connections.close_all()

    Thread(
        target=_run_and_release_connections,
        name=f"regulatory-info-package-{batch.pk}",
        daemon=True,
    ).start()
|