Files
DEMO-AGENT/review_agent/application_form_fill/workflow.py

329 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import logging
from threading import Thread
from uuid import uuid4
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from review_agent.application_form_fill.constants import DEFAULT_OUTPUT_TYPES, FORM_FILL_NODE_DEFINITIONS, WORKFLOW_TYPE
from review_agent.application_form_fill.events import record_event
from review_agent.application_form_fill.services.field_extract import (
collect_document_texts,
run_parallel_extract,
save_field_extract_result,
)
from review_agent.application_form_fill.services.field_merge import merge_fields
from review_agent.application_form_fill.services.notifier import notify_completion
from review_agent.application_form_fill.services.summary import build_assistant_summary
from review_agent.application_form_fill.services.template_config import (
compute_config_hash,
load_template_config,
validate_template_config,
)
from review_agent.application_form_fill.services.template_repository import (
TemplateUnavailableError,
copy_template_to_batch,
)
from review_agent.application_form_fill.services.template_select import (
detect_registration_type,
parse_requested_templates,
select_templates,
)
from review_agent.application_form_fill.services.traceability_export import save_traceability_exports
from review_agent.application_form_fill.services.word_fill import create_word_export
from review_agent.application_form_fill.schemas import MergedField, TemplateSpec
from review_agent.application_form_fill.storage import build_batch_work_dir
from review_agent.models import ApplicationFormFillBatch, Conversation, FileSummaryBatch, Message, WorkflowNodeRun
# Module logger. The dotted name is spelled out (instead of using __name__) so it
# does not depend on where this module is imported from.
logger = logging.getLogger(name="review_agent.application_form_fill.workflow")
def build_batch_no() -> str:
    """Return a new batch number: ``AFF-<local YYYYmmddHHMMSS>-<first 6 hex chars of a uuid4>``."""
    stamp = timezone.localtime().strftime("%Y%m%d%H%M%S")
    suffix = uuid4().hex[:6]
    return "-".join(("AFF", stamp, suffix))
def find_latest_successful_summary_batch(conversation: Conversation) -> FileSummaryBatch | None:
    """Return the newest SUCCESS file-summary batch of ``conversation``, or ``None`` if there is none.

    "Newest" is decided by ``finished_at``, then ``created_at``, then ``id``, all descending.
    """
    successful = FileSummaryBatch.objects.filter(
        conversation=conversation,
        status=FileSummaryBatch.Status.SUCCESS,
    )
    newest_first = successful.order_by("-finished_at", "-created_at", "-id")
    return newest_first.first()
@transaction.atomic
def create_application_form_fill_batch(
    *,
    conversation: Conversation,
    user,
    source_summary_batch: FileSummaryBatch,
    trigger_message: Message | None = None,
    requested_templates: list[str] | None = None,
    output_types: list[str] | None = None,
) -> ApplicationFormFillBatch:
    """Create a form-fill batch, its work directory and one node row per workflow node.

    Args:
        conversation: Conversation the batch belongs to.
        user: Owner recorded on the batch.
        source_summary_batch: File-summary batch whose results feed the workflow.
        trigger_message: Message that triggered the fill, if any.
        requested_templates: Template codes requested by the caller; defaults to empty.
        output_types: Output types to produce; defaults to ``DEFAULT_OUTPUT_TYPES``.

    Returns:
        The newly created ``ApplicationFormFillBatch``. All DB writes happen in one
        transaction. The work directory is created on disk before the DB rows and is
        not removed if the transaction rolls back.
    """
    batch_no = build_batch_no()
    work_dir = build_batch_work_dir(batch_no=batch_no)
    work_dir.mkdir(parents=True, exist_ok=True)
    batch = ApplicationFormFillBatch.objects.create(
        conversation=conversation,
        user=user,
        trigger_message=trigger_message,
        source_summary_batch=source_summary_batch,
        batch_no=batch_no,
        # Copy both lists so the instance never aliases the caller's list or the shared
        # module-level DEFAULT_OUTPUT_TYPES constant. Otherwise an in-place edit of
        # either field would leak into the caller or into every later batch.
        requested_templates=list(requested_templates or []),
        output_types=list(output_types or DEFAULT_OUTPUT_TYPES),
        work_dir=str(work_dir),
    )
    # One node row per (code, name, group) definition. The executor later runs them in id order.
    for code, name, group in FORM_FILL_NODE_DEFINITIONS:
        WorkflowNodeRun.objects.create(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=batch.pk,
            node_group=group,
            node_code=code,
            node_name=name,
        )
    record_event(batch, "workflow_created", {"batch_id": batch.pk, "batch_no": batch.batch_no})
    return batch
class FormFillWorkflowExecutor:
    """Run the application-form auto-fill workflow for a single batch.

    The executor walks the batch's ``WorkflowNodeRun`` rows in ``id`` order. It skips
    nodes already in SUCCESS or SKIPPED state and dispatches every other node to a
    handler keyed by ``node_code``.

    Intermediate results (selected templates, template copies, extracted and merged
    fields, exports) live only on this instance. They are filled in by handlers that
    ran in the same ``run()`` call.
    NOTE(review): a resumed run therefore starts later nodes with empty in-memory
    inputs. Confirm whether resume is expected to work.
    """

    def __init__(self, batch: ApplicationFormFillBatch):
        self.batch = batch
        # Template configuration loaded by the ``template_select`` node.
        self.template_config: dict = {}
        self.selected_templates: list[TemplateSpec] = []
        # Template code -> storage path of the batch-local template copy.
        self.template_paths: dict[str, str] = {}
        self.document_texts: dict[str, str] = {}
        self.extract_payload: dict = {}
        self.merged_fields: dict[str, MergedField] = {}
        self.conflicts: list[dict] = []
        # Objects returned by create_word_export / save_traceability_exports.
        self.exports: list = []
        self.generation_results: list[dict] = []
        # Errors that downgrade the final status to PARTIAL_SUCCESS without aborting the run.
        self.non_blocking_errors: list[str] = []

    def run(self) -> None:
        """Execute every pending node and record the batch's terminal status.

        An exception from any node marks the batch FAILED, with the exception text as
        ``error_message``. It is logged, emitted as a ``workflow_failed`` event and not
        re-raised.
        """
        logger.info("自动填表工作流开始 batch_no=%s batch_id=%s", self.batch.batch_no, self.batch.pk)
        self.batch.status = ApplicationFormFillBatch.Status.RUNNING
        self.batch.started_at = timezone.now()
        self.batch.save(update_fields=["status", "started_at"])
        record_event(self.batch, "workflow_started", {"batch_id": self.batch.pk})
        try:
            for node in self._nodes():
                if node.status in {WorkflowNodeRun.Status.SUCCESS, WorkflowNodeRun.Status.SKIPPED}:
                    continue
                self._run_node(node)
        except Exception as exc:
            logger.exception("Application form fill workflow failed", extra={"batch_id": self.batch.pk})
            self.batch.status = ApplicationFormFillBatch.Status.FAILED
            self.batch.error_message = str(exc)
            self.batch.finished_at = timezone.now()
            self.batch.save(update_fields=["status", "error_message", "finished_at"])
            record_event(self.batch, "workflow_failed", {"message": str(exc)})
            return
        # Reload so a PARTIAL_SUCCESS written by the ``completed`` node is respected.
        self.batch.refresh_from_db()
        if self.batch.status != ApplicationFormFillBatch.Status.PARTIAL_SUCCESS:
            self.batch.status = ApplicationFormFillBatch.Status.SUCCESS
        self.batch.finished_at = timezone.now()
        self.batch.save(update_fields=["status", "finished_at"])
        record_event(self.batch, "workflow_completed", {"batch_id": self.batch.pk})
        logger.info("自动填表工作流完成 batch_no=%s", self.batch.batch_no)

    def _nodes(self):
        """Return this batch's workflow node rows in execution (id) order."""
        return WorkflowNodeRun.objects.filter(
            workflow_type=WORKFLOW_TYPE,
            workflow_batch_id=self.batch.pk,
        ).order_by("id")

    def _run_node(self, node: WorkflowNodeRun) -> None:
        """Run one node, persisting and broadcasting its RUNNING and terminal states.

        If the node's handler raises, the node is marked failed (best effort) and the
        exception is re-raised so ``run()`` can fail the batch.
        """
        node.status = WorkflowNodeRun.Status.RUNNING
        node.progress = 10
        node.started_at = timezone.now()
        node.message = f"{node.node_name}处理中"
        self._save_node(node, ["status", "progress", "started_at", "message"])
        if node.node_code == "pdf_convert":
            # PDF output is not implemented yet: leave a risk note and skip the node.
            self._append_risk_note(
                {
                    "type": "pdf_pending",
                    "message": "PDF 转换为后续增强项,本次优先生成 Word。",
                }
            )
            self._finish_node(node, WorkflowNodeRun.Status.SKIPPED, "PDF 转换为后续增强项,本次跳过")
            return
        try:
            self._execute_node(node)
        except Exception as exc:
            # Without this the node row would stay RUNNING at 10% forever.
            self._mark_node_failed(node, exc)
            raise
        self._finish_node(node, WorkflowNodeRun.Status.SUCCESS, f"{node.node_name}完成")

    def _save_node(self, node: WorkflowNodeRun, update_fields: list[str]) -> None:
        """Persist ``update_fields`` of ``node`` and emit a ``node_progress`` event."""
        node.save(update_fields=update_fields)
        record_event(
            self.batch,
            "node_progress",
            {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
        )

    def _finish_node(self, node: WorkflowNodeRun, status, message: str) -> None:
        """Move ``node`` to terminal ``status`` at 100% progress and broadcast it."""
        node.status = status
        node.progress = 100
        node.finished_at = timezone.now()
        node.message = message
        self._save_node(node, ["status", "progress", "finished_at", "message"])

    def _mark_node_failed(self, node: WorkflowNodeRun, exc: Exception) -> None:
        """Best effort: record that ``node`` failed so it is not left in RUNNING state.

        Any error raised while recording is logged and suppressed, so it can never
        replace the node's original exception.
        """
        try:
            # NOTE(review): assumes WorkflowNodeRun.Status defines FAILED — confirm against the model.
            node.status = WorkflowNodeRun.Status.FAILED
            node.finished_at = timezone.now()
            node.message = f"{node.node_name}失败: {exc}"
            self._save_node(node, ["status", "finished_at", "message"])
        except Exception:
            logger.exception(
                "Failed to record workflow node failure",
                extra={"batch_id": self.batch.pk, "node_code": node.node_code},
            )

    def _execute_node(self, node: WorkflowNodeRun) -> None:
        """Dispatch ``node`` to the handler for its ``node_code``. Unknown codes are a no-op."""
        handlers = {
            "prepare": self._node_prepare,
            "template_select": self._node_template_select,
            "template_copy": self._node_template_copy,
            "field_extract": self._node_field_extract,
            "conflict_merge": self._node_conflict_merge,
            "word_fill": self._node_word_fill,
            "trace_export": self._node_trace_export,
            "output_export": self._node_output_export,
            "notify": self._node_notify,
            "completed": self._mark_final_status,
        }
        handler = handlers.get(node.node_code)
        if handler is not None:
            handler()

    def _node_prepare(self) -> None:
        """Require the source file-summary batch to be in SUCCESS state."""
        if self.batch.source_summary_batch.status != FileSummaryBatch.Status.SUCCESS:
            raise ValueError("自动填表需要成功的文件汇总批次。")

    def _node_template_select(self) -> None:
        """Load and validate the template config, choose templates, and persist the choice on the batch."""
        self.template_config = load_template_config()
        errors = validate_template_config(self.template_config)
        if errors:
            raise ValueError("".join(errors))
        message_text = self.batch.trigger_message.content if self.batch.trigger_message else ""
        requested = parse_requested_templates(message_text)
        registration_type, source = detect_registration_type(batch=self.batch, message=message_text)
        specs, risk_notes = select_templates(self.template_config, requested, registration_type)
        if not specs:
            raise ValueError("未选择到可用申报模板。")
        self.selected_templates = specs
        self.batch.requested_templates = requested
        self.batch.selected_templates = [spec.code for spec in specs]
        self.batch.registration_type = registration_type
        self.batch.registration_type_source = source
        self.batch.template_config_version = str(self.template_config.get("version") or "")
        self.batch.template_config_hash = compute_config_hash()
        self.batch.risk_notes = list(self.batch.risk_notes or []) + risk_notes
        self.batch.save(
            update_fields=[
                "requested_templates",
                "selected_templates",
                "registration_type",
                "registration_type_source",
                "template_config_version",
                "template_config_hash",
                "risk_notes",
            ]
        )

    def _node_template_copy(self) -> None:
        """Copy each selected template into the batch. An unavailable template is non-blocking.

        Raises ValueError only when no template copy could be made at all.
        """
        for spec in self.selected_templates:
            try:
                artifact = copy_template_to_batch(spec, self.batch, self.template_config)
                self.template_paths[spec.code] = artifact.storage_path
            except TemplateUnavailableError as exc:
                self.non_blocking_errors.append(str(exc))
                self._append_risk_note({"type": "template_unavailable", "message": str(exc), "template_code": spec.code})
        if not self.template_paths:
            raise ValueError("没有可用的 Word 模板副本。")

    def _node_field_extract(self) -> None:
        """Collect source document texts, run field extraction, and persist the raw payload."""
        self.document_texts = collect_document_texts(self.batch.source_summary_batch)
        self.extract_payload = run_parallel_extract(self.document_texts, self.selected_templates)
        save_field_extract_result(self.batch, self.extract_payload)

    def _node_conflict_merge(self) -> None:
        """Merge regex and LLM extraction results, then store product name and conflicts on the batch."""
        self.merged_fields, self.conflicts = merge_fields(
            self.extract_payload.get("regex_results") or {},
            self.extract_payload.get("llm_results") or {},
        )
        product = self.merged_fields.get("product_name")
        if product and product.value:
            self.batch.product_name = product.value
        self.batch.conflict_summary = self.conflicts
        self.batch.save(update_fields=["product_name", "conflict_summary"])

    def _node_word_fill(self) -> None:
        """Fill a Word export for each selected template and record per-template results.

        Templates without a batch copy are recorded as failed. Raises ValueError if
        no template succeeded.
        """
        for spec in self.selected_templates:
            template_path = self.template_paths.get(spec.code)
            if not template_path:
                self.generation_results.append(
                    {
                        "template_code": spec.code,
                        "template_label": spec.output_label,
                        "word_status": "failed",
                        "pdf_status": "待增强",
                        "error_message": "模板不可用",
                    }
                )
                continue
            export = create_word_export(self.batch, spec, template_path, self.merged_fields, self.conflicts)
            self.exports.append(export)
            self.generation_results.append(
                {
                    "template_code": spec.code,
                    "template_label": spec.output_label,
                    "word_status": "success",
                    "pdf_status": "待增强",
                    "error_message": "",
                }
            )
        if not any(item["word_status"] == "success" for item in self.generation_results):
            raise ValueError("所有目标 Word 模板均生成失败。")

    def _node_trace_export(self) -> None:
        """Save traceability exports and add them to the export list."""
        self.exports.extend(
            save_traceability_exports(
                self.batch,
                self.merged_fields,
                self.conflicts,
                self.selected_templates,
                self.generation_results,
            )
        )

    def _node_output_export(self) -> None:
        """Post an assistant message summarising the batch and its exports."""
        Message.objects.create(
            conversation=self.batch.conversation,
            role=Message.Role.ASSISTANT,
            content=build_assistant_summary(self.batch, self.exports),
        )

    def _node_notify(self) -> None:
        """Send the completion notification. A failed send is recorded as a non-blocking error."""
        notification = notify_completion(
            self.batch,
            self.exports,
            fail=getattr(settings, "APPLICATION_FORM_FILL_MOCK_NOTIFY_FAIL", False),
        )
        if notification.send_status == notification.SendStatus.FAILED:
            self.non_blocking_errors.append(notification.error_message or "通知失败")

    def _mark_final_status(self) -> None:
        """Set PARTIAL_SUCCESS if any non-blocking error or failed Word output occurred, else SUCCESS."""
        failed_word = any(item.get("word_status") == "failed" for item in self.generation_results)
        if self.non_blocking_errors or failed_word:
            self.batch.status = ApplicationFormFillBatch.Status.PARTIAL_SUCCESS
        else:
            self.batch.status = ApplicationFormFillBatch.Status.SUCCESS
        self.batch.save(update_fields=["status"])

    def _append_risk_note(self, note: dict) -> None:
        """Append ``note`` to the batch's risk notes and persist them immediately."""
        self.batch.risk_notes = list(self.batch.risk_notes or []) + [note]
        self.batch.save(update_fields=["risk_notes"])
def start_application_form_fill_workflow(batch: ApplicationFormFillBatch, *, async_run: bool = True) -> None:
    """Run the form-fill workflow for ``batch``, either inline or on a daemon thread.

    Args:
        batch: The batch to execute.
        async_run: When False, run synchronously in the caller's thread.

    With ``async_run=True`` the thread is started through ``transaction.on_commit``.
    When the caller is inside an atomic block, the thread starts only after the batch
    and its node rows are committed; the worker uses its own DB connection and could
    not see them earlier. If that transaction rolls back, the thread never starts.
    Outside an atomic block the callback runs immediately. The worker closes its own
    DB connections when it finishes, because Django only cleans up connections for
    request-handling threads.
    """
    executor = FormFillWorkflowExecutor(batch)
    if not async_run:
        executor.run()
        return

    def _run_in_thread() -> None:
        # Local import keeps this self-contained; django.db is already a file dependency.
        from django.db import connections

        try:
            executor.run()
        finally:
            connections.close_all()

    transaction.on_commit(lambda: Thread(target=_run_in_thread, daemon=True).start())