Files

376 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import logging
from threading import Thread
from uuid import uuid4
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from review_agent.file_summary.paths import resolve_storage_path
from review_agent.models import (
Conversation,
ExportedSummaryFile,
Message,
RegulatoryInfoPackageArtifact,
RegulatoryInfoPackageBatch,
RegulatoryInfoPackageNotificationRecord,
WorkflowNodeRun,
)
from review_agent.regulatory_info_package.constants import (
DEFAULT_ZIP_NAME,
REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS,
WORKFLOW_TYPE,
)
from review_agent.regulatory_info_package.events import record_event
from review_agent.regulatory_info_package.services.template_config import (
compute_config_hash,
load_template_config,
validate_template_config,
)
from review_agent.regulatory_info_package.services.field_extract import run_parallel_extract, save_field_extract_result
from review_agent.regulatory_info_package.services.field_merge import merge_fields, save_merged_fields
from review_agent.regulatory_info_package.services.instruction_extract import parse_instruction_docx, save_instruction_extract_json
from review_agent.regulatory_info_package.services.package_generate import generate_package_documents
from review_agent.regulatory_info_package.services.summary import build_assistant_summary
from review_agent.regulatory_info_package.services.traceability_export import save_traceability_exports
from review_agent.regulatory_info_package.services.zip_export import create_zip_package
from review_agent.regulatory_info_package.schemas import GeneratedFileResult, InstructionExtractResult, MergedField
from review_agent.regulatory_info_package.storage import build_batch_work_dir
from review_agent.regulatory_info_package.storage import create_artifact_for_file, ensure_batch_subdir
# Module-level logger for the workflow; its name is spelled out explicitly
# instead of being derived from ``__name__``.
logger = logging.getLogger("review_agent.regulatory_info_package.workflow")
def build_batch_no() -> str:
    """Return a fresh batch number shaped like ``RIP-<YYYYmmddHHMMSS>-<hex6>``.

    The timestamp comes from ``timezone.localtime()`` and the suffix is the
    first six hex digits of a random UUID4.
    """
    stamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return f"RIP-{stamp}-{suffix}"
@transaction.atomic
def create_regulatory_info_package_batch(
    *,
    conversation: Conversation,
    user,
    trigger_message: Message | None = None,
    source_attachment=None,
    source_summary_batch=None,
    source_summary_item_id: int | None = None,
    source_file_name: str = "",
    source_storage_path: str = "",
    existing_batch: RegulatoryInfoPackageBatch | None = None,
) -> RegulatoryInfoPackageBatch:
    """Create (or reuse) a batch and make sure its workflow node rows exist.

    When ``existing_batch`` is given it is used as-is.  Otherwise a new batch
    number and work directory are generated, the directory is created on disk
    and a ``RegulatoryInfoPackageBatch`` row is inserted.  ``source_file_name``
    and ``source_storage_path`` fall back to the attachment's ``original_name``
    and ``storage_path`` attributes when left empty.

    In both cases one ``WorkflowNodeRun`` row per entry of
    ``REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS`` is fetched or created, and a
    ``workflow_created`` event is recorded.

    Returns:
        The new or reused batch.
    """
    if existing_batch is not None:
        batch = existing_batch
    else:
        new_batch_no = build_batch_no()
        batch_dir = build_batch_work_dir(batch_no=new_batch_no)
        batch_dir.mkdir(parents=True, exist_ok=True)
        # Explicit arguments win; the attachment only supplies defaults.
        resolved_file_name = source_file_name or getattr(source_attachment, "original_name", "")
        resolved_storage_path = source_storage_path or getattr(source_attachment, "storage_path", "")
        batch = RegulatoryInfoPackageBatch.objects.create(
            conversation=conversation,
            user=user,
            trigger_message=trigger_message,
            source_attachment=source_attachment,
            source_summary_batch=source_summary_batch,
            source_summary_item_id=source_summary_item_id,
            source_file_name=resolved_file_name,
            source_storage_path=resolved_storage_path,
            batch_no=new_batch_no,
            output_zip_name=DEFAULT_ZIP_NAME,
            work_dir=str(batch_dir),
        )

    # get_or_create keeps this step safe to repeat for a reused batch.
    for node_code, node_name, node_group in REGULATORY_INFO_PACKAGE_NODE_DEFINITIONS:
        node_defaults = {"node_group": node_group, "node_name": node_name}
        WorkflowNodeRun.objects.get_or_create(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=batch.pk,
            node_code=node_code,
            defaults=node_defaults,
        )

    event_payload = {"batch_id": batch.pk, "batch_no": batch.batch_no}
    record_event(batch, "workflow_created", event_payload)
    return batch
class RegulatoryInfoPackageWorkflowExecutor:
    """Runs the Chapter 1 regulatory information package workflow.

    The executor walks the batch's ``WorkflowNodeRun`` rows in ``id`` order and
    runs the handler registered for each node code.  Results produced by one
    node (template config, parsed instruction, merged fields, generated files,
    exports) are kept on the instance and consumed by later nodes.
    """

    # Maps a node code to the name of the method implementing it.  Codes that
    # are not listed ("template_copy", "highlight_review_items" and anything
    # unknown) complete without doing any work.
    _NODE_HANDLERS = {
        "prepare": "_node_prepare",
        "text_extract": "_node_text_extract",
        "field_extract": "_node_field_extract",
        "field_merge": "_node_field_merge",
        "generate_docs": "_node_generate_docs",
        "trace_export": "_node_trace_export",
        "zip_export": "_node_zip_export",
        "notify": "_node_notify",
    }

    # Generation statuses for which a downloadable export row is created.
    _EXPORTABLE_STATUSES = frozenset({"success", "fallback_success"})

    def __init__(self, batch: RegulatoryInfoPackageBatch):
        """Bind the executor to ``batch`` and start with empty in-memory state."""
        self.batch = batch
        self.template_config: dict = {}
        self.instruction: InstructionExtractResult | None = None
        self.extract_payload: dict = {}
        self.merged_fields: dict[str, MergedField] = {}
        self.merge_summary: dict[str, list[dict]] = {}
        self.generation_results: list[GeneratedFileResult] = []
        self.exports: list[ExportedSummaryFile] = []

    def run(self) -> None:
        """Execute every pending node and persist the batch outcome.

        Nodes already in ``SUCCESS`` or ``SKIPPED`` state are not re-run.  An
        exception raised by a node is logged, stored on the batch (status
        ``FAILED`` plus ``error_message``), recorded as a ``workflow_failed``
        event and then swallowed; it is not re-raised to the caller.
        """
        logger.info("监管信息材料包工作流开始 batch_no=%s batch_id=%s", self.batch.batch_no, self.batch.pk)
        self.batch.status = RegulatoryInfoPackageBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})

        finished_statuses = {WorkflowNodeRun.Status.SUCCESS, WorkflowNodeRun.Status.SKIPPED}
        try:
            for node in self._nodes():
                # NOTE(review): a skipped node does not repopulate the
                # in-memory state it produced on an earlier run, so nodes
                # that follow it see the empty defaults from __init__.
                if node.status in finished_statuses:
                    continue
                self._run_node(node)
        except Exception as exc:
            # NOTE(review): the node that raised keeps its RUNNING status;
            # only the batch row is marked as failed here.
            failure_text = str(exc)
            logger.exception("Regulatory info package workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = RegulatoryInfoPackageBatch.Status.FAILED
            self.batch.error_message = failure_text
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": failure_text})
            return

        self.batch.status = RegulatoryInfoPackageBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        self._append_completion_message()
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})

    def _nodes(self):
        """Return this batch's node rows ordered by primary key."""
        return WorkflowNodeRun.objects.filter(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Mark ``node`` running, execute it, then mark it successful.

        A ``node_progress`` event is recorded before and after execution.  If
        the node raises, the exception propagates and the success update is
        not written.
        """
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        node.save(update_fields=["status", "progress", "started_at", "message"])
        record_event(self.batch, "node_progress", {"node_code": node.node_code, "status": node.status})

        self._execute_node(node)

        node.status = WorkflowNodeRun.Status.SUCCESS
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = f"{node.node_name}完成"
        node.save(update_fields=["status", "progress", "finished_at", "message"])
        record_event(self.batch, "node_progress", {"node_code": node.node_code, "status": node.status})

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Dispatch ``node`` to its handler; codes without a handler are no-ops."""
        handler_name = self._NODE_HANDLERS.get(node.node_code)
        if handler_name is None:
            return
        getattr(self, handler_name)(node)

    def _node_prepare(self, node: WorkflowNodeRun) -> None:
        """Load and validate the template config and stamp it on the batch.

        Raises:
            ValueError: if validation reports any errors.
        """
        self.template_config = load_template_config()
        errors = validate_template_config(self.template_config)
        if errors:
            # Separate the individual problems so the stored error message
            # stays readable when more than one is reported.
            raise ValueError("; ".join(errors))
        self.batch.template_config_version = str(self.template_config.get("version") or "")
        self.batch.template_config_hash = compute_config_hash()
        self.batch.save(update_fields=["template_config_version", "template_config_hash"])

    def _node_text_extract(self, node: WorkflowNodeRun) -> None:
        """Parse the source instruction document and store the result as JSON.

        Without a source storage path there is nothing to parse and
        ``self.instruction`` is reset to ``None``.
        """
        if not self.batch.source_storage_path:
            self.instruction = None
            return
        source_path = resolve_storage_path(self.batch.source_storage_path)
        self.instruction = parse_instruction_docx(source_path)
        json_path = ensure_batch_subdir(self.batch, "logs") / "instruction_extract.json"
        save_instruction_extract_json(json_path, self.instruction)
        create_artifact_for_file(
            self.batch,
            path=json_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.INSTRUCTION_EXTRACT,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.JSON,
            created_by_node=node.node_code,
        )

    def _node_field_extract(self, node: WorkflowNodeRun) -> None:
        """Run field extraction over the parsed instruction and store the result.

        Without a parsed instruction an empty payload is used and nothing is
        written to disk.
        """
        if not self.instruction:
            self.extract_payload = {"regex_results": {}, "llm_results": {}, "llm_error": ""}
            return
        # The LLM extractor passed here is a stub that always returns {}.
        self.extract_payload = run_parallel_extract(self.instruction, llm_extract_func=lambda _instruction: {})
        json_path = ensure_batch_subdir(self.batch, "logs") / "field_extract_result.json"
        save_field_extract_result(json_path, self.extract_payload)
        create_artifact_for_file(
            self.batch,
            path=json_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.FIELD_EXTRACT_RESULT,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.JSON,
            created_by_node=node.node_code,
        )

    def _node_field_merge(self, node: WorkflowNodeRun) -> None:
        """Merge regex and LLM results, update batch summary fields, store JSON."""
        self.merged_fields, self.merge_summary = merge_fields(
            self.extract_payload.get("regex_results") or {},
            self.extract_payload.get("llm_results") or {},
        )
        product = self.merged_fields.get("product_name")
        # "/" is treated as "no value" for the product name.
        if product and product.value and product.value != "/":
            self.batch.product_name = product.value
        self.batch.missing_fields = self.merge_summary.get("missing_fields", [])
        self.batch.llm_only_fields = self.merge_summary.get("llm_only_fields", [])
        self.batch.conflict_fields = self.merge_summary.get("conflict_fields", [])
        self.batch.save(update_fields=["product_name", "missing_fields", "llm_only_fields", "conflict_fields"])
        json_path = ensure_batch_subdir(self.batch, "logs") / "merged_fields.json"
        save_merged_fields(json_path, self.merged_fields, self.merge_summary)
        create_artifact_for_file(
            self.batch,
            path=json_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.MERGED_FIELDS,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.JSON,
            created_by_node=node.node_code,
        )

    def _node_generate_docs(self, node: WorkflowNodeRun) -> None:
        """Generate the package documents and register artifacts and exports.

        Every result is appended to ``batch.generated_files``; only results
        that have a path get an artifact, and of those only the ones with an
        exportable status also get an export row.
        """
        self.generation_results = generate_package_documents(self.batch, self.template_config, self.merged_fields)
        generated_files = []
        for result in self.generation_results:
            if result.path:
                artifact = create_artifact_for_file(
                    self.batch,
                    path=result.path,
                    artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.GENERATED_DOCUMENT,
                    file_format=result.actual_format,
                    name=result.template_code,
                    metadata=result.__dict__,
                    created_by_node=node.node_code,
                )
                result.artifact_id = artifact.pk
                if result.status in self._EXPORTABLE_STATUSES:
                    export = self._create_export(
                        path=result.path,
                        export_type=ExportedSummaryFile.ExportType.WORD,
                        export_category="generated_document",
                    )
                    result.export_id = export.pk
                    self.exports.append(export)
            generated_files.append(result.__dict__)
        self.batch.generated_files = generated_files
        self.batch.save(update_fields=["generated_files"])

    def _node_trace_export(self, node: WorkflowNodeRun) -> None:
        """Write the traceability files and expose the Excel one as an export."""
        excel_path, json_path = save_traceability_exports(self.batch.work_dir, self.merged_fields)
        create_artifact_for_file(
            self.batch,
            path=json_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.TRACEABILITY,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.JSON,
            created_by_node=node.node_code,
        )
        artifact = create_artifact_for_file(
            self.batch,
            path=excel_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.TRACEABILITY,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.EXCEL,
            created_by_node=node.node_code,
        )
        export = self._create_export(
            path=str(excel_path),
            export_type=ExportedSummaryFile.ExportType.EXCEL,
            export_category="traceability",
        )
        self.exports.append(export)
        artifact.metadata = {"export_id": export.pk}
        artifact.save(update_fields=["metadata"])

    def _node_zip_export(self, node: WorkflowNodeRun) -> None:
        """Bundle the generated documents into a zip and expose it as an export."""
        zip_path = create_zip_package(self.batch.work_dir, self.generation_results, self.batch.output_zip_name)
        artifact = create_artifact_for_file(
            self.batch,
            path=zip_path,
            artifact_type=RegulatoryInfoPackageArtifact.ArtifactType.ZIP_PACKAGE,
            file_format=RegulatoryInfoPackageArtifact.FileFormat.ZIP,
            created_by_node=node.node_code,
        )
        export = self._create_export(
            path=str(zip_path),
            export_type=ExportedSummaryFile.ExportType.ZIP,
            export_category="regulatory_info_package",
        )
        # The zip goes first so it leads the list of exports in the summary.
        self.exports.insert(0, export)
        artifact.metadata = {"export_id": export.pk}
        artifact.save(update_fields=["metadata"])

    def _node_notify(self, node: WorkflowNodeRun) -> None:
        """Store a notification record summarising the exports of this run."""
        RegulatoryInfoPackageNotificationRecord.objects.create(
            batch=self.batch,
            recipient=self.batch.user,
            export_ids=[export.pk for export in self.exports],
            message_summary=build_assistant_summary(
                batch_no=self.batch.batch_no,
                exports=self._export_payload(self.exports),
                failed_files=self._failed_files(),
            ),
            send_status=RegulatoryInfoPackageNotificationRecord.SendStatus.SUCCESS,
        )

    @staticmethod
    def _export_payload(exports) -> list[dict]:
        """Describe each export as the dict passed to ``build_assistant_summary``."""
        return [
            {
                "file_name": export.file_name,
                "download_url": f"/api/review-agent/file-summary/exports/{export.pk}/download/",
                "export_type": export.export_type,
            }
            for export in exports
        ]

    def _failed_files(self) -> list[dict]:
        """Return the entries of ``batch.generated_files`` whose status is ``failed``."""
        return [item for item in self.batch.generated_files if item.get("status") == "failed"]

    def _append_completion_message(self) -> None:
        """Post the assistant summary message to the conversation, once.

        Nothing is created when the conversation already holds an assistant
        message containing both the batch number and the output zip name.
        """
        already_posted = (
            Message.objects.filter(
                conversation=self.batch.conversation,
                role=Message.Role.ASSISTANT,
                content__contains=self.batch.batch_no,
            )
            .filter(content__contains=self.batch.output_zip_name)
            .exists()
        )
        if already_posted:
            return
        exports = list(
            ExportedSummaryFile.objects.filter(
                workflow_type=WORKFLOW_TYPE,
                workflow_batch_id=self.batch.pk,
            )
        )
        # Stable sort: zip exports first, everything else in query order.
        exports = sorted(exports, key=lambda export: 0 if export.export_type == ExportedSummaryFile.ExportType.ZIP else 1)
        content = build_assistant_summary(
            batch_no=self.batch.batch_no,
            exports=self._export_payload(exports),
            failed_files=self._failed_files(),
        )
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=content,
        )

    def _create_export(self, *, path: str, export_type: str, export_category: str) -> ExportedSummaryFile:
        """Create an ``ExportedSummaryFile`` row for ``path`` tied to this batch.

        The row's ``batch`` field is left as ``None``; the link to this
        workflow is carried by ``workflow_type`` and ``workflow_batch_id``.
        """
        from pathlib import Path

        resolved = Path(path)
        return ExportedSummaryFile.objects.create(
            batch=None,
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
            export_category=export_category,
            export_type=export_type,
            file_name=resolved.name,
            storage_path=str(resolved),
        )
def start_regulatory_info_package_workflow(
    batch: RegulatoryInfoPackageBatch,
    *,
    async_run: bool | None = None,
) -> None:
    """Run the workflow for ``batch``, inline or on a background thread.

    Args:
        batch: The batch to process.
        async_run: ``True`` runs the workflow on a daemon thread and returns
            immediately; ``False`` runs it in the calling thread.  ``None``
            defers to ``settings.REGULATORY_INFO_PACKAGE_ASYNC`` (default
            ``True`` when the setting is absent).
    """
    if async_run is None:
        async_run = getattr(settings, "REGULATORY_INFO_PACKAGE_ASYNC", True)
    executor = RegulatoryInfoPackageWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_and_release_connections() -> None:
        # Database connections opened on a manually started thread are not
        # cleaned up by the request cycle, so close them when the run ends.
        from django.db import connections

        try:
            executor.run()
        finally:
            connections.close_all()

    Thread(target=_run_and_release_connections, daemon=True).start()