diff --git a/review_agent/regulatory_review/services/completeness_check.py b/review_agent/regulatory_review/services/completeness_check.py index c30e11c..47e317f 100644 --- a/review_agent/regulatory_review/services/completeness_check.py +++ b/review_agent/regulatory_review/services/completeness_check.py @@ -1,13 +1,25 @@ from __future__ import annotations +from collections.abc import Callable + from review_agent.models import FileSummaryBatch from review_agent.regulatory_review.schemas import Finding -def run_completeness_check(batch: FileSummaryBatch, rule_set: dict) -> list[Finding]: +def run_completeness_check( + batch: FileSummaryBatch, + rule_set: dict, + progress_callback: Callable[[dict[str, object]], None] | None = None, +) -> list[Finding]: items = list(batch.items.order_by("file_index")) findings: list[Finding] = [] - for requirement in rule_set.get("requirements", []): + requirements = [ + requirement + for requirement in rule_set.get("requirements", []) + if requirement.get("type") in {"required", "conditional", "recommended", "chapter", "directory"} + ] + total = len(requirements) + for index, requirement in enumerate(requirements, start=1): if requirement.get("type") not in {"required", "conditional", "recommended", "chapter", "directory"}: continue matched = [ @@ -20,24 +32,32 @@ def run_completeness_check(batch: FileSummaryBatch, rule_set: dict) -> list[Find [*requirement.get("file_keywords", []), *requirement.get("aliases", [])], ) ] - if matched: - continue - findings.append( - Finding( - rule_code=requirement["code"], - category=requirement.get("category", "completeness"), - severity=requirement.get("severity", "medium"), - title=f"缺少{_numbered_title(requirement)}", - detail=f"当前文件汇总批次未发现{_numbered_title(requirement)}。", - suggestion=requirement.get("suggestion", ""), - evidence={ - "requirement_type": requirement.get("type"), - "matched_files": [], - "searched_keywords": requirement.get("file_keywords", []), - "searched_fields": ["file_name", "relative_path", "directory_level"], - }, + if not matched: + findings.append( + Finding( + rule_code=requirement["code"], + category=requirement.get("category", "completeness"), + severity=requirement.get("severity", "medium"), + title=f"缺少{_numbered_title(requirement)}", + detail=f"当前文件汇总批次未发现{_numbered_title(requirement)}。", + suggestion=requirement.get("suggestion", ""), + evidence={ + "requirement_type": requirement.get("type"), + "matched_files": [], + "searched_keywords": requirement.get("file_keywords", []), + "searched_fields": ["file_name", "relative_path", "directory_level"], + }, + ) + ) + if progress_callback: + progress_callback( + { + "processed": index, + "total": total, + "label": _numbered_title(requirement), + "finding_count": len(findings), + } ) - ) return findings diff --git a/review_agent/regulatory_review/services/consistency_check.py b/review_agent/regulatory_review/services/consistency_check.py index 1f24e17..19193aa 100644 --- a/review_agent/regulatory_review/services/consistency_check.py +++ b/review_agent/regulatory_review/services/consistency_check.py @@ -2,6 +2,7 @@ from __future__ import annotations import re from collections import defaultdict +from collections.abc import Callable from review_agent.regulatory_review.schemas import Finding @@ -17,27 +18,40 @@ FIELDS = { } -def run_consistency_check(document_texts: dict[str, str]) -> list[Finding]: +def run_consistency_check( + document_texts: dict[str, str], + progress_callback: Callable[[dict[str, object]], None] | None = None, +) -> list[Finding]: findings: list[Finding] = [] - for label, pattern in FIELDS.items(): + fields = list(FIELDS.items()) + total = len(fields) + for index, (label, pattern) in enumerate(fields, start=1): values: dict[str, list[str]] = defaultdict(list) for file_name, text in document_texts.items(): match = re.search(pattern, text) if match: values[_normalize(match.group(1))].append(file_name) - if len(values) <= 1: - continue - findings.append( - Finding( - rule_code=f"consistency:{label}", - category="consistency", - severity="high", - title=f"{label}在不同文件中不一致", - detail=f"发现 {len(values)} 个不同的{label}取值。", - suggestion=f"请统一各注册资料中的{label}。", - evidence={"field": label, "values": dict(values)}, + if len(values) > 1: + findings.append( + Finding( + rule_code=f"consistency:{label}", + category="consistency", + severity="high", + title=f"{label}在不同文件中不一致", + detail=f"发现 {len(values)} 个不同的{label}取值。", + suggestion=f"请统一各注册资料中的{label}。", + evidence={"field": label, "values": dict(values)}, + ) + ) + if progress_callback: + progress_callback( + { + "processed": index, + "total": total, + "label": label, + "finding_count": len(findings), + } ) - ) return findings diff --git a/review_agent/regulatory_review/services/structure_check.py b/review_agent/regulatory_review/services/structure_check.py index 85f5b27..efe8a40 100644 --- a/review_agent/regulatory_review/services/structure_check.py +++ b/review_agent/regulatory_review/services/structure_check.py @@ -1,12 +1,20 @@ from __future__ import annotations +from collections.abc import Callable + from review_agent.regulatory_review.schemas import Finding -def run_structure_check(document_texts: dict[str, str], rule_set: dict) -> list[Finding]: +def run_structure_check( + document_texts: dict[str, str], + rule_set: dict, + progress_callback: Callable[[dict[str, object]], None] | None = None, +) -> list[Finding]: findings: list[Finding] = [] combined_all_text = "\n".join(document_texts.values()) - for requirement in rule_set.get("requirements", []): + requirements = list(rule_set.get("requirements", [])) + total = len(requirements) + for index, requirement in enumerate(requirements, start=1): if requirement.get("structure_required") and not _contains_any( combined_all_text, [requirement.get("title", ""), *requirement.get("aliases", [])], @@ -27,25 +35,32 @@ def run_structure_check(document_texts: dict[str, str], rule_set: dict) -> list[ ) ) required_sections = requirement.get("required_sections") or [] - if not required_sections: - continue - matching_docs = _matching_documents(document_texts, requirement.get("file_keywords", [])) - if not matching_docs: - continue - combined_text = "\n".join(matching_docs.values()) - for section in required_sections: - if _contains_any(combined_text, [section]): - continue - findings.append( - Finding( - rule_code=f"{requirement['code']}:{section}", - category="structure", - severity=requirement.get("severity", "medium"), - title=f"{requirement['title']}缺少{section}章节", - detail=f"已匹配{requirement['title']}文件,但未发现{section}相关内容。", - suggestion=requirement.get("suggestion", ""), - evidence={"section": section, "files": list(matching_docs)}, - ) + if required_sections: + matching_docs = _matching_documents(document_texts, requirement.get("file_keywords", [])) + if matching_docs: + combined_text = "\n".join(matching_docs.values()) + for section in required_sections: + if _contains_any(combined_text, [section]): + continue + findings.append( + Finding( + rule_code=f"{requirement['code']}:{section}", + category="structure", + severity=requirement.get("severity", "medium"), + title=f"{requirement['title']}缺少{section}章节", + detail=f"已匹配{requirement['title']}文件,但未发现{section}相关内容。", + suggestion=requirement.get("suggestion", ""), + evidence={"section": section, "files": list(matching_docs)}, + ) + ) + if progress_callback: + progress_callback( + { + "processed": index, + "total": total, + "label": _numbered_title(requirement), + "finding_count": len(findings), + } ) return findings diff --git a/review_agent/regulatory_review/workflow.py b/review_agent/regulatory_review/workflow.py index 8e3c62c..3b4edbd 100644 --- a/review_agent/regulatory_review/workflow.py +++ b/review_agent/regulatory_review/workflow.py @@ -178,7 +178,7 @@ class RegulatoryWorkflowExecutor: {"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message}, ) - self._execute_node(node.node_code) + self._execute_node(node) node.status = WorkflowNodeRun.Status.SUCCESS node.progress = 100 @@ -198,7 +198,44 @@ class RegulatoryWorkflowExecutor: node.progress, ) - def _execute_node(self, node_code: str) -> None: + def _update_node_progress( + self, + node: WorkflowNodeRun, + *, + processed: int, + total: int, + message: str, + ) -> None: + if total <= 0: + return + progress = min(95, 10 + int((max(processed, 0) / total) * 85)) + node.progress = progress + node.message = message + node.save(update_fields=["progress", "message"]) + record_event( + self.batch, + "node_progress", + { + "node_code": node.node_code, + "status": node.status, + "progress": node.progress, + "message": node.message, + "processed": processed, + "total": total, + }, + ) + logger.info( + "节点进度 batch_no=%s node=%s progress=%s processed=%s total=%s message=%s", + self.batch.batch_no, + node.node_code, + progress, + processed, + total, + message, + ) + + def _execute_node(self, node: WorkflowNodeRun) -> None: + node_code = node.node_code if node_code == "condition_confirm": self._pause_for_condition_confirmation() return @@ -212,7 +249,19 @@ class RegulatoryWorkflowExecutor: ) return if node_code == "completeness_check": - findings = run_completeness_check(self.batch.source_summary_batch, self._rules()) + findings = run_completeness_check( + self.batch.source_summary_batch, + self._rules(), + progress_callback=lambda update: self._update_node_progress( + node, + processed=int(update.get("processed") or 0), + total=int(update.get("total") or 0), + message=( + f"完整性核查 {update.get('processed')}/{update.get('total')}:" + f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题" + ), + ), + ) self.findings.extend(findings) logger.info( "方法执行 batch_no=%s method=run_completeness_check findings=%s source_summary=%s", @@ -229,7 +278,7 @@ class RegulatoryWorkflowExecutor: ) return if node_code == "text_extract": - self.document_texts = self._extract_source_texts() + self.document_texts = self._extract_source_texts(node) logger.info( "方法执行 batch_no=%s method=_extract_source_texts success_docs=%s total_files=%s", self.batch.batch_no, @@ -246,7 +295,19 @@ class RegulatoryWorkflowExecutor: ) return if node_code == "structure_check": - findings = run_structure_check(self.document_texts, self._rules()) + findings = run_structure_check( + self.document_texts, + self._rules(), + progress_callback=lambda update: self._update_node_progress( + node, + processed=int(update.get("processed") or 0), + total=int(update.get("total") or 0), + message=( + f"章节核查 {update.get('processed')}/{update.get('total')}:" + f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题" + ), + ), + ) self.findings.extend(findings) logger.info( "方法执行 batch_no=%s method=run_structure_check findings=%s docs=%s", @@ -257,7 +318,18 @@ class RegulatoryWorkflowExecutor: self._save_llm_review("structure_check", {"findings": [finding.to_dict() for finding in findings]}) return if node_code == "consistency_check": - findings = run_consistency_check(self.document_texts) + findings = run_consistency_check( + self.document_texts, + progress_callback=lambda update: self._update_node_progress( + node, + processed=int(update.get("processed") or 0), + total=int(update.get("total") or 0), + message=( + f"一致性核查 {update.get('processed')}/{update.get('total')}:" + f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题" + ), + ), + ) self.findings.extend(findings) logger.info( "方法执行 batch_no=%s method=run_consistency_check findings=%s docs=%s", @@ -353,9 +425,11 @@ class RegulatoryWorkflowExecutor: self.rule_set = apply_rule_scope(load_rule_file(), self.batch.condition_json.get("rule_scope") or {}) return self.rule_set - def _extract_source_texts(self) -> dict[str, str]: + def _extract_source_texts(self, node: WorkflowNodeRun | None = None) -> dict[str, str]: texts = {} - for item in self.batch.source_summary_batch.items.order_by("file_index"): + items = list(self.batch.source_summary_batch.items.order_by("file_index")) + total = len(items) + for index, item in enumerate(items, start=1): path = Path(item.storage_path) if not path.is_absolute(): path = Path(settings.MEDIA_ROOT) / item.storage_path @@ -369,6 +443,13 @@ class RegulatoryWorkflowExecutor: "field_candidates": {}, "front_text": "", } + if node: + self._update_node_progress( + node, + processed=index, + total=total, + message=f"文本抽取 {index}/{total}:{item.file_name}(文件不存在)", + ) continue result = extract_text(path) field_review = review_condition_fields( @@ -396,6 +477,13 @@ class RegulatoryWorkflowExecutor: len((field_review.get("selected_fields") or {})), len(result.text or ""), ) + if node: + self._update_node_progress( + node, + processed=index, + total=total, + message=f"文本抽取 {index}/{total}:{item.file_name}({result.status})", + ) return texts def _save_llm_review(self, stage: str, payload: dict[str, object]) -> dict[str, object]: diff --git a/tests/test_regulatory_workflow.py b/tests/test_regulatory_workflow.py index 9230357..18da71b 100644 --- a/tests/test_regulatory_workflow.py +++ b/tests/test_regulatory_workflow.py @@ -17,6 +17,7 @@ from review_agent.models import ( ) from review_agent.regulatory_review.workflow import ( NODE_DEFINITIONS, + RegulatoryWorkflowExecutor, create_regulatory_review_batch, find_latest_successful_summary_batch, start_regulatory_review_workflow, @@ -413,3 +414,89 @@ def test_workflow_records_llm_review_artifacts_for_review_nodes( assert "llm_review_structure_check.json" in artifact_names assert "llm_review_consistency_check.json" in artifact_names assert "llm_review_risk_assess.json" in artifact_names + + +def test_workflow_progress_uses_processed_file_counts(settings, tmp_path, django_user_model): + settings.MEDIA_ROOT = tmp_path + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + summary = FileSummaryBatch.objects.create( + conversation=conversation, + user=user, + batch_no="FS-OK", + status=FileSummaryBatch.Status.SUCCESS, + ) + for index, name in enumerate(["注册信息.txt", "说明书.txt", "综述.txt"], start=1): + path = tmp_path / name + path.write_text(f"产品名称:甲胎蛋白检测试剂盒\n文件:{name}", encoding="utf-8") + FileSummaryItem.objects.create( + batch=summary, + file_index=index, + file_name=name, + file_type="txt", + relative_path=name, + storage_path=str(path), + ) + batch = create_regulatory_review_batch( + conversation=conversation, + user=user, + source_summary_batch=summary, + ) + node = WorkflowNodeRun.objects.get( + workflow_type="regulatory_review", + workflow_batch_id=batch.pk, + node_code="text_extract", + ) + executor = RegulatoryWorkflowExecutor(batch) + + texts = executor._extract_source_texts(node) + + node.refresh_from_db() + assert len(texts) == 3 + assert node.progress == 95 + assert "文本抽取 3/3" in node.message + assert "综述.txt" in node.message + assert WorkflowEvent.objects.filter( + workflow_type="regulatory_review", + workflow_batch_id=batch.pk, + event_type="node_progress", + payload__node_code="text_extract", + payload__processed=3, + payload__total=3, + ).exists() + + +def test_review_services_emit_actual_workload_progress_callbacks(django_user_model): + from review_agent.regulatory_review.services.completeness_check import run_completeness_check + from review_agent.regulatory_review.services.consistency_check import FIELDS, run_consistency_check + from review_agent.regulatory_review.services.structure_check import run_structure_check + + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + summary = FileSummaryBatch.objects.create( + conversation=conversation, + user=user, + batch_no="FS-OK", + status=FileSummaryBatch.Status.SUCCESS, + ) + rule_set = { + "requirements": [ + {"code": "r1", "title": "注册信息", "type": "required", "file_keywords": ["注册信息"]}, + {"code": "r2", "title": "说明书", "type": "required", "file_keywords": ["说明书"]}, + ] + } + completeness_updates = [] + structure_updates = [] + consistency_updates = [] + + run_completeness_check(summary, rule_set, progress_callback=completeness_updates.append) + run_structure_check({"注册信息.txt": "注册信息"}, rule_set, progress_callback=structure_updates.append) + run_consistency_check({"注册信息.txt": "产品名称:A"}, progress_callback=consistency_updates.append) + + assert completeness_updates[-1]["processed"] == 2 + assert completeness_updates[-1]["total"] == 2 + assert completeness_updates[-1]["label"] == "说明书" + assert structure_updates[-1]["processed"] == 2 + assert structure_updates[-1]["total"] == 2 + assert consistency_updates[-1]["processed"] == len(FIELDS) + assert consistency_updates[-1]["total"] == len(FIELDS)