329 lines
14 KiB
Python
329 lines
14 KiB
Python
from __future__ import annotations
|
||
|
||
import logging
|
||
from threading import Thread
|
||
from uuid import uuid4
|
||
|
||
from django.conf import settings
|
||
from django.db import transaction
|
||
from django.utils import timezone
|
||
|
||
from review_agent.application_form_fill.constants import DEFAULT_OUTPUT_TYPES, FORM_FILL_NODE_DEFINITIONS, WORKFLOW_TYPE
|
||
from review_agent.application_form_fill.events import record_event
|
||
from review_agent.application_form_fill.services.field_extract import (
|
||
collect_document_texts,
|
||
run_parallel_extract,
|
||
save_field_extract_result,
|
||
)
|
||
from review_agent.application_form_fill.services.field_merge import merge_fields
|
||
from review_agent.application_form_fill.services.notifier import notify_completion
|
||
from review_agent.application_form_fill.services.summary import build_assistant_summary
|
||
from review_agent.application_form_fill.services.template_config import (
|
||
compute_config_hash,
|
||
load_template_config,
|
||
validate_template_config,
|
||
)
|
||
from review_agent.application_form_fill.services.template_repository import (
|
||
TemplateUnavailableError,
|
||
copy_template_to_batch,
|
||
)
|
||
from review_agent.application_form_fill.services.template_select import (
|
||
detect_registration_type,
|
||
parse_requested_templates,
|
||
select_templates,
|
||
)
|
||
from review_agent.application_form_fill.services.traceability_export import save_traceability_exports
|
||
from review_agent.application_form_fill.services.word_fill import create_word_export
|
||
from review_agent.application_form_fill.schemas import MergedField, TemplateSpec
|
||
from review_agent.application_form_fill.storage import build_batch_work_dir
|
||
from review_agent.models import ApplicationFormFillBatch, Conversation, FileSummaryBatch, Message, WorkflowNodeRun
|
||
|
||
|
||
# Logger for the application form-fill workflow (used for start/finish and failure logging below).
logger = logging.getLogger("review_agent.application_form_fill.workflow")
|
||
|
||
|
||
def build_batch_no() -> str:
    """Return a fresh batch number shaped like ``AFF-<YYYYmmddHHMMSS>-<6 hex chars>``.

    The timestamp is the current local time (``timezone.localtime()``); the suffix
    is the first six hex characters of a random UUID4, so two batches created in
    the same second still get different numbers with high probability.
    """
    stamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return "AFF-" + stamp + "-" + suffix
|
||
|
||
|
||
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest SUCCESS file-summary batch of ``conversation``, or ``None`` if there is none.

    "Newest" sorts by ``finished_at``, then ``created_at``, then ``id``, all
    descending; ``id`` is the final tie-breaker.
    """
    successful = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
|
||
|
||
|
||
@transaction.atomic
def create_application_form_fill_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
    requested_templates: list[str] | None = None,
    output_types: list[str] | None = None,
) -> ApplicationFormFillBatch:
    """Create a form-fill batch, its work directory, and one node row per workflow node.

    The batch row, the ``WorkflowNodeRun`` rows (one per entry of
    ``FORM_FILL_NODE_DEFINITIONS``) and the ``workflow_created`` event are written
    inside one database transaction. The work directory is created on disk before
    that and is NOT removed if the transaction rolls back.

    Args:
        conversation: Conversation the batch belongs to.
        user: User who triggered the batch.
        source_summary_batch: File-summary batch whose documents feed the fill.
        trigger_message: Message that triggered the batch, if any.
        requested_templates: Template codes explicitly requested; defaults to ``[]``.
        output_types: Output types to produce; defaults to ``DEFAULT_OUTPUT_TYPES``.

    Returns:
        The newly created ``ApplicationFormFillBatch``.
    """
    batch_no = build_batch_no()
    work_dir = build_batch_work_dir(batch_no=batch_no)
    work_dir.mkdir(parents=True, exist_ok=True)
    batch = ApplicationFormFillBatch.objects.create(
        conversation=conversation,
        user=user,
        trigger_message=trigger_message,
        source_summary_batch=source_summary_batch,
        batch_no=batch_no,
        # Copy both lists so the batch never aliases the caller's list or the
        # shared module-level default (in-place edits would otherwise leak out).
        requested_templates=list(requested_templates or []),
        output_types=list(output_types or DEFAULT_OUTPUT_TYPES),
        work_dir=str(work_dir),
    )
    for code, name, group in FORM_FILL_NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=batch.pk,
            node_group=group,
            node_code=code,
            node_name=name,
        )
    record_event(batch, "workflow_created", {"batch_id": batch.pk, "batch_no": batch.batch_no})
    return batch
|
||
|
||
|
||
class FormFillWorkflowExecutor:
    """Run the application form auto-fill workflow for a single batch.

    Each workflow step is a ``WorkflowNodeRun`` row (created by
    ``create_application_form_fill_batch``). Rows are processed in ``id`` order;
    rows already in SUCCESS or SKIPPED status are not executed again. Results are
    handed from one node to the next through attributes on this executor.

    NOTE(review): those intermediate results live only in memory, so running a
    fresh executor over a partially completed batch skips the finished nodes
    without restoring their outputs, and later nodes then see empty state.
    """

    def __init__(self, batch: ApplicationFormFillBatch):
        self.batch = batch
        # Loaded and validated template configuration (set by "template_select").
        self.template_config: dict = {}
        # Templates chosen for this batch (set by "template_select").
        self.selected_templates: list[TemplateSpec] = []
        # Template code -> storage path of this batch's template copy ("template_copy").
        self.template_paths: dict[str, str] = {}
        # Document texts collected from the source summary batch ("field_extract").
        self.document_texts: dict[str, str] = {}
        # Raw extraction output; its "regex_results"/"llm_results" feed "conflict_merge".
        self.extract_payload: dict = {}
        # Field name -> merged field ("conflict_merge").
        self.merged_fields: dict[str, MergedField] = {}
        # Conflicts reported by merge_fields ("conflict_merge").
        self.conflicts: list[dict] = []
        # Export artifacts produced by "word_fill" and "trace_export".
        self.exports: list = []
        # One outcome entry per selected template ("word_fill").
        self.generation_results: list[dict] = []
        # Problems that do not abort the run but downgrade it to PARTIAL_SUCCESS.
        self.non_blocking_errors: list[str] = []

    def run(self) -> None:
        """Execute every pending node in order and record the batch's final status.

        The batch is set to RUNNING first. If any node raises, the batch is set to
        FAILED with the exception text as ``error_message``, a ``workflow_failed``
        event is recorded, and the method returns without re-raising. Otherwise the
        batch ends as SUCCESS, or PARTIAL_SUCCESS when the ``completed`` node stored it.
        """
        logger.info("自动填表工作流开始 batch_no=%s batch_id=%s", self.batch.batch_no, self.batch.pk)
        self.batch.status = ApplicationFormFillBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})

        finished_states = {WorkflowNodeRun.Status.SUCCESS, WorkflowNodeRun.Status.SKIPPED}
        try:
            for node in self._nodes():
                if node.status in finished_states:
                    continue
                self._run_node(node)
        except Exception as exc:
            logger.exception("Application form fill workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = ApplicationFormFillBatch.Status.FAILED
            self.batch.error_message = str(exc)
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": str(exc)})
            return

        # Re-read the batch; keep PARTIAL_SUCCESS if the "completed" node stored it,
        # otherwise the run finished cleanly and is marked SUCCESS.
        self.batch.refresh_from_db()
        if self.batch.status != ApplicationFormFillBatch.Status.PARTIAL_SUCCESS:
            self.batch.status = ApplicationFormFillBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})
        logger.info("自动填表工作流完成 batch_no=%s", self.batch.batch_no)

    def _nodes(self):
        """Return this batch's form-fill node rows in ``id`` (creation) order."""
        return WorkflowNodeRun.objects.filter(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Execute one node and persist its RUNNING -> SUCCESS/SKIPPED/FAILED transitions.

        Raises:
            Exception: whatever the node handler raised, after the node row has
                been marked as failed.
        """
        self._set_node_state(
            node,
            WorkflowNodeRun.Status.RUNNING,
            progress=10,
            message=f"{node.node_name}处理中",
            time_field="started_at",
        )

        if node.node_code == "pdf_convert":
            # PDF output is not implemented yet: leave a risk note and skip the node.
            self._append_risk_note(
                {
                    "type": "pdf_pending",
                    "message": "PDF 转换为后续增强项,本次优先生成 Word。",
                }
            )
            self._set_node_state(
                node,
                WorkflowNodeRun.Status.SKIPPED,
                progress=100,
                message="PDF 转换为后续增强项,本次跳过",
                time_field="finished_at",
            )
            return

        try:
            self._execute_node(node)
        except Exception:
            # Close out the node row so it does not stay RUNNING after run() marks
            # the batch FAILED; the exception text itself is stored on the batch.
            # NOTE(review): assumes WorkflowNodeRun.Status defines FAILED — confirm.
            self._set_node_state(
                node,
                WorkflowNodeRun.Status.FAILED,
                progress=node.progress,
                message=f"{node.node_name}失败",
                time_field="finished_at",
            )
            raise

        self._set_node_state(
            node,
            WorkflowNodeRun.Status.SUCCESS,
            progress=100,
            message=f"{node.node_name}完成",
            time_field="finished_at",
        )

    def _set_node_state(self, node: WorkflowNodeRun, status, *, progress: int, message: str, time_field: str) -> None:
        """Save a node status change, stamp ``time_field`` with now, and emit a ``node_progress`` event."""
        node.status = status
        node.progress = progress
        setattr(node, time_field, timezone.now())
        node.message = message
        node.save(update_fields=["status", "progress", time_field, "message"])
        record_event(
            self.batch,
            "node_progress",
            {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
        )

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Dispatch ``node`` to the handler for its ``node_code``; unknown codes do nothing."""
        handlers = {
            "prepare": self._node_prepare,
            "template_select": self._node_template_select,
            "template_copy": self._node_template_copy,
            "field_extract": self._node_field_extract,
            "conflict_merge": self._node_conflict_merge,
            "word_fill": self._node_word_fill,
            "trace_export": self._node_trace_export,
            "output_export": self._node_output_export,
            "notify": self._node_notify,
            "completed": self._mark_final_status,
        }
        handler = handlers.get(node.node_code)
        if handler is not None:
            handler()

    def _node_prepare(self) -> None:
        """Fail unless the source file-summary batch finished successfully."""
        if self.batch.source_summary_batch.status != FileSummaryBatch.Status.SUCCESS:
            raise ValueError("自动填表需要成功的文件汇总批次。")

    def _node_template_select(self) -> None:
        """Load and validate the template config, choose templates, and save the choice on the batch."""
        self.template_config = load_template_config()
        errors = validate_template_config(self.template_config)
        if errors:
            raise ValueError(";".join(errors))

        trigger = self.batch.trigger_message
        message_text = trigger.content if trigger else ""
        # Templates named in the trigger message take precedence; otherwise fall back
        # to the list stored on the batch at creation instead of discarding it.
        requested = parse_requested_templates(message_text) or list(self.batch.requested_templates or [])
        registration_type, source = detect_registration_type(batch=self.batch, message=message_text)
        specs, risk_notes = select_templates(self.template_config, requested, registration_type)
        if not specs:
            raise ValueError("未选择到可用申报模板。")

        self.selected_templates = specs
        self.batch.requested_templates = requested
        self.batch.selected_templates = [spec.code for spec in specs]
        self.batch.registration_type = registration_type
        self.batch.registration_type_source = source
        self.batch.template_config_version = str(self.template_config.get("version") or "")
        self.batch.template_config_hash = compute_config_hash()
        self.batch.risk_notes = list(self.batch.risk_notes or []) + list(risk_notes or [])
        self.batch.save(
            update_fields=[
                "requested_templates",
                "selected_templates",
                "registration_type",
                "registration_type_source",
                "template_config_version",
                "template_config_hash",
                "risk_notes",
            ]
        )

    def _node_template_copy(self) -> None:
        """Copy each selected template into the batch; unavailable ones become non-blocking errors."""
        for spec in self.selected_templates:
            try:
                artifact = copy_template_to_batch(spec, self.batch, self.template_config)
                self.template_paths[spec.code] = artifact.storage_path
            except TemplateUnavailableError as exc:
                self.non_blocking_errors.append(str(exc))
                self._append_risk_note({"type": "template_unavailable", "message": str(exc), "template_code": spec.code})
        if not self.template_paths:
            raise ValueError("没有可用的 Word 模板副本。")

    def _node_field_extract(self) -> None:
        """Collect source document texts, run field extraction, and persist the result."""
        self.document_texts = collect_document_texts(self.batch.source_summary_batch)
        self.extract_payload = run_parallel_extract(self.document_texts, self.selected_templates)
        save_field_extract_result(self.batch, self.extract_payload)

    def _node_conflict_merge(self) -> None:
        """Merge regex and LLM extraction results and store product name and conflicts on the batch."""
        self.merged_fields, self.conflicts = merge_fields(
            self.extract_payload.get("regex_results") or {},
            self.extract_payload.get("llm_results") or {},
        )
        product = self.merged_fields.get("product_name")
        if product and product.value:
            self.batch.product_name = product.value
        self.batch.conflict_summary = self.conflicts
        self.batch.save(update_fields=["product_name", "conflict_summary"])

    def _node_word_fill(self) -> None:
        """Fill every selected template that has a batch copy; fail only if none succeeded."""
        for spec in self.selected_templates:
            template_path = self.template_paths.get(spec.code)
            if not template_path:
                # No batch copy of this template (e.g. it was unavailable in template_copy).
                self.generation_results.append(self._generation_result(spec, "failed", "模板不可用"))
                continue
            export = create_word_export(self.batch, spec, template_path, self.merged_fields, self.conflicts)
            self.exports.append(export)
            self.generation_results.append(self._generation_result(spec, "success", ""))
        if not any(item["word_status"] == "success" for item in self.generation_results):
            raise ValueError("所有目标 Word 模板均生成失败。")

    @staticmethod
    def _generation_result(spec: TemplateSpec, word_status: str, error_message: str) -> dict:
        """Build one ``generation_results`` entry; ``pdf_status`` is always "待增强" (PDF not produced yet)."""
        return {
            "template_code": spec.code,
            "template_label": spec.output_label,
            "word_status": word_status,
            "pdf_status": "待增强",
            "error_message": error_message,
        }

    def _node_trace_export(self) -> None:
        """Save traceability exports and add them to the batch's export list."""
        self.exports.extend(
            save_traceability_exports(
                self.batch,
                self.merged_fields,
                self.conflicts,
                self.selected_templates,
                self.generation_results,
            )
        )

    def _node_output_export(self) -> None:
        """Post an assistant message summarising the batch and its exports to the conversation."""
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, self.exports),
        )

    def _node_notify(self) -> None:
        """Send the completion notification; a failed send is recorded as a non-blocking error."""
        notification = notify_completion(
            self.batch,
            self.exports,
            fail=getattr(settings, "APPLICATION_FORM_FILL_MOCK_NOTIFY_FAIL", False),
        )
        if notification.send_status == notification.SendStatus.FAILED:
            self.non_blocking_errors.append(notification.error_message or "通知失败")

    def _mark_final_status(self) -> None:
        """Store PARTIAL_SUCCESS if any non-blocking error or failed Word output occurred, else SUCCESS."""
        failed_word = any(item.get("word_status") == "failed" for item in self.generation_results)
        if self.non_blocking_errors or failed_word:
            self.batch.status = ApplicationFormFillBatch.Status.PARTIAL_SUCCESS
        else:
            self.batch.status = ApplicationFormFillBatch.Status.SUCCESS
        self.batch.save(update_fields=["status"])

    def _append_risk_note(self, note: dict) -> None:
        """Append ``note`` to the batch's ``risk_notes`` and save that field."""
        self.batch.risk_notes = list(self.batch.risk_notes or []) + [note]
        self.batch.save(update_fields=["risk_notes"])
|
||
|
||
|
||
def start_application_form_fill_workflow(batch: ApplicationFormFillBatch, *, async_run: bool = True) -> None:
    """Run the form-fill workflow for ``batch``.

    With ``async_run=True`` (the default) the workflow runs in a background daemon
    thread and this function returns immediately. With ``async_run=False`` it runs
    synchronously in the caller's thread.
    """
    executor = FormFillWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_in_worker() -> None:
        # Django database connections are per-thread: close the ones this worker
        # opened so every background run does not leak a connection.
        from django.db import connections

        try:
            executor.run()
        finally:
            connections.close_all()

    Thread(target=_run_in_worker, name=f"form-fill-{batch.batch_no}", daemon=True).start()
|