feat(regulatory): 按实际处理数量更新节点进度
This commit is contained in:
@@ -1,13 +1,25 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from review_agent.models import FileSummaryBatch
|
from review_agent.models import FileSummaryBatch
|
||||||
from review_agent.regulatory_review.schemas import Finding
|
from review_agent.regulatory_review.schemas import Finding
|
||||||
|
|
||||||
|
|
||||||
def run_completeness_check(batch: FileSummaryBatch, rule_set: dict) -> list[Finding]:
|
def run_completeness_check(
|
||||||
|
batch: FileSummaryBatch,
|
||||||
|
rule_set: dict,
|
||||||
|
progress_callback: Callable[[dict[str, object]], None] | None = None,
|
||||||
|
) -> list[Finding]:
|
||||||
items = list(batch.items.order_by("file_index"))
|
items = list(batch.items.order_by("file_index"))
|
||||||
findings: list[Finding] = []
|
findings: list[Finding] = []
|
||||||
for requirement in rule_set.get("requirements", []):
|
requirements = [
|
||||||
|
requirement
|
||||||
|
for requirement in rule_set.get("requirements", [])
|
||||||
|
if requirement.get("type") in {"required", "conditional", "recommended", "chapter", "directory"}
|
||||||
|
]
|
||||||
|
total = len(requirements)
|
||||||
|
for index, requirement in enumerate(requirements, start=1):
|
||||||
if requirement.get("type") not in {"required", "conditional", "recommended", "chapter", "directory"}:
|
if requirement.get("type") not in {"required", "conditional", "recommended", "chapter", "directory"}:
|
||||||
continue
|
continue
|
||||||
matched = [
|
matched = [
|
||||||
@@ -20,8 +32,7 @@ def run_completeness_check(batch: FileSummaryBatch, rule_set: dict) -> list[Find
|
|||||||
[*requirement.get("file_keywords", []), *requirement.get("aliases", [])],
|
[*requirement.get("file_keywords", []), *requirement.get("aliases", [])],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
if matched:
|
if not matched:
|
||||||
continue
|
|
||||||
findings.append(
|
findings.append(
|
||||||
Finding(
|
Finding(
|
||||||
rule_code=requirement["code"],
|
rule_code=requirement["code"],
|
||||||
@@ -38,6 +49,15 @@ def run_completeness_check(batch: FileSummaryBatch, rule_set: dict) -> list[Find
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(
|
||||||
|
{
|
||||||
|
"processed": index,
|
||||||
|
"total": total,
|
||||||
|
"label": _numbered_title(requirement),
|
||||||
|
"finding_count": len(findings),
|
||||||
|
}
|
||||||
|
)
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from review_agent.regulatory_review.schemas import Finding
|
from review_agent.regulatory_review.schemas import Finding
|
||||||
|
|
||||||
@@ -17,16 +18,20 @@ FIELDS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_consistency_check(document_texts: dict[str, str]) -> list[Finding]:
|
def run_consistency_check(
|
||||||
|
document_texts: dict[str, str],
|
||||||
|
progress_callback: Callable[[dict[str, object]], None] | None = None,
|
||||||
|
) -> list[Finding]:
|
||||||
findings: list[Finding] = []
|
findings: list[Finding] = []
|
||||||
for label, pattern in FIELDS.items():
|
fields = list(FIELDS.items())
|
||||||
|
total = len(fields)
|
||||||
|
for index, (label, pattern) in enumerate(fields, start=1):
|
||||||
values: dict[str, list[str]] = defaultdict(list)
|
values: dict[str, list[str]] = defaultdict(list)
|
||||||
for file_name, text in document_texts.items():
|
for file_name, text in document_texts.items():
|
||||||
match = re.search(pattern, text)
|
match = re.search(pattern, text)
|
||||||
if match:
|
if match:
|
||||||
values[_normalize(match.group(1))].append(file_name)
|
values[_normalize(match.group(1))].append(file_name)
|
||||||
if len(values) <= 1:
|
if len(values) > 1:
|
||||||
continue
|
|
||||||
findings.append(
|
findings.append(
|
||||||
Finding(
|
Finding(
|
||||||
rule_code=f"consistency:{label}",
|
rule_code=f"consistency:{label}",
|
||||||
@@ -38,6 +43,15 @@ def run_consistency_check(document_texts: dict[str, str]) -> list[Finding]:
|
|||||||
evidence={"field": label, "values": dict(values)},
|
evidence={"field": label, "values": dict(values)},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(
|
||||||
|
{
|
||||||
|
"processed": index,
|
||||||
|
"total": total,
|
||||||
|
"label": label,
|
||||||
|
"finding_count": len(findings),
|
||||||
|
}
|
||||||
|
)
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from review_agent.regulatory_review.schemas import Finding
|
from review_agent.regulatory_review.schemas import Finding
|
||||||
|
|
||||||
|
|
||||||
def run_structure_check(document_texts: dict[str, str], rule_set: dict) -> list[Finding]:
|
def run_structure_check(
|
||||||
|
document_texts: dict[str, str],
|
||||||
|
rule_set: dict,
|
||||||
|
progress_callback: Callable[[dict[str, object]], None] | None = None,
|
||||||
|
) -> list[Finding]:
|
||||||
findings: list[Finding] = []
|
findings: list[Finding] = []
|
||||||
combined_all_text = "\n".join(document_texts.values())
|
combined_all_text = "\n".join(document_texts.values())
|
||||||
for requirement in rule_set.get("requirements", []):
|
requirements = list(rule_set.get("requirements", []))
|
||||||
|
total = len(requirements)
|
||||||
|
for index, requirement in enumerate(requirements, start=1):
|
||||||
if requirement.get("structure_required") and not _contains_any(
|
if requirement.get("structure_required") and not _contains_any(
|
||||||
combined_all_text,
|
combined_all_text,
|
||||||
[requirement.get("title", ""), *requirement.get("aliases", [])],
|
[requirement.get("title", ""), *requirement.get("aliases", [])],
|
||||||
@@ -27,11 +35,9 @@ def run_structure_check(document_texts: dict[str, str], rule_set: dict) -> list[
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
required_sections = requirement.get("required_sections") or []
|
required_sections = requirement.get("required_sections") or []
|
||||||
if not required_sections:
|
if required_sections:
|
||||||
continue
|
|
||||||
matching_docs = _matching_documents(document_texts, requirement.get("file_keywords", []))
|
matching_docs = _matching_documents(document_texts, requirement.get("file_keywords", []))
|
||||||
if not matching_docs:
|
if matching_docs:
|
||||||
continue
|
|
||||||
combined_text = "\n".join(matching_docs.values())
|
combined_text = "\n".join(matching_docs.values())
|
||||||
for section in required_sections:
|
for section in required_sections:
|
||||||
if _contains_any(combined_text, [section]):
|
if _contains_any(combined_text, [section]):
|
||||||
@@ -47,6 +53,15 @@ def run_structure_check(document_texts: dict[str, str], rule_set: dict) -> list[
|
|||||||
evidence={"section": section, "files": list(matching_docs)},
|
evidence={"section": section, "files": list(matching_docs)},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(
|
||||||
|
{
|
||||||
|
"processed": index,
|
||||||
|
"total": total,
|
||||||
|
"label": _numbered_title(requirement),
|
||||||
|
"finding_count": len(findings),
|
||||||
|
}
|
||||||
|
)
|
||||||
return findings
|
return findings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class RegulatoryWorkflowExecutor:
|
|||||||
{"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
|
{"node_code": node.node_code, "status": node.status, "progress": node.progress, "message": node.message},
|
||||||
)
|
)
|
||||||
|
|
||||||
self._execute_node(node.node_code)
|
self._execute_node(node)
|
||||||
|
|
||||||
node.status = WorkflowNodeRun.Status.SUCCESS
|
node.status = WorkflowNodeRun.Status.SUCCESS
|
||||||
node.progress = 100
|
node.progress = 100
|
||||||
@@ -198,7 +198,44 @@ class RegulatoryWorkflowExecutor:
|
|||||||
node.progress,
|
node.progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _execute_node(self, node_code: str) -> None:
|
def _update_node_progress(
|
||||||
|
self,
|
||||||
|
node: WorkflowNodeRun,
|
||||||
|
*,
|
||||||
|
processed: int,
|
||||||
|
total: int,
|
||||||
|
message: str,
|
||||||
|
) -> None:
|
||||||
|
if total <= 0:
|
||||||
|
return
|
||||||
|
progress = min(95, 10 + int((max(processed, 0) / total) * 85))
|
||||||
|
node.progress = progress
|
||||||
|
node.message = message
|
||||||
|
node.save(update_fields=["progress", "message"])
|
||||||
|
record_event(
|
||||||
|
self.batch,
|
||||||
|
"node_progress",
|
||||||
|
{
|
||||||
|
"node_code": node.node_code,
|
||||||
|
"status": node.status,
|
||||||
|
"progress": node.progress,
|
||||||
|
"message": node.message,
|
||||||
|
"processed": processed,
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"节点进度 batch_no=%s node=%s progress=%s processed=%s total=%s message=%s",
|
||||||
|
self.batch.batch_no,
|
||||||
|
node.node_code,
|
||||||
|
progress,
|
||||||
|
processed,
|
||||||
|
total,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _execute_node(self, node: WorkflowNodeRun) -> None:
|
||||||
|
node_code = node.node_code
|
||||||
if node_code == "condition_confirm":
|
if node_code == "condition_confirm":
|
||||||
self._pause_for_condition_confirmation()
|
self._pause_for_condition_confirmation()
|
||||||
return
|
return
|
||||||
@@ -212,7 +249,19 @@ class RegulatoryWorkflowExecutor:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
if node_code == "completeness_check":
|
if node_code == "completeness_check":
|
||||||
findings = run_completeness_check(self.batch.source_summary_batch, self._rules())
|
findings = run_completeness_check(
|
||||||
|
self.batch.source_summary_batch,
|
||||||
|
self._rules(),
|
||||||
|
progress_callback=lambda update: self._update_node_progress(
|
||||||
|
node,
|
||||||
|
processed=int(update.get("processed") or 0),
|
||||||
|
total=int(update.get("total") or 0),
|
||||||
|
message=(
|
||||||
|
f"完整性核查 {update.get('processed')}/{update.get('total')}:"
|
||||||
|
f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
self.findings.extend(findings)
|
self.findings.extend(findings)
|
||||||
logger.info(
|
logger.info(
|
||||||
"方法执行 batch_no=%s method=run_completeness_check findings=%s source_summary=%s",
|
"方法执行 batch_no=%s method=run_completeness_check findings=%s source_summary=%s",
|
||||||
@@ -229,7 +278,7 @@ class RegulatoryWorkflowExecutor:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
if node_code == "text_extract":
|
if node_code == "text_extract":
|
||||||
self.document_texts = self._extract_source_texts()
|
self.document_texts = self._extract_source_texts(node)
|
||||||
logger.info(
|
logger.info(
|
||||||
"方法执行 batch_no=%s method=_extract_source_texts success_docs=%s total_files=%s",
|
"方法执行 batch_no=%s method=_extract_source_texts success_docs=%s total_files=%s",
|
||||||
self.batch.batch_no,
|
self.batch.batch_no,
|
||||||
@@ -246,7 +295,19 @@ class RegulatoryWorkflowExecutor:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
if node_code == "structure_check":
|
if node_code == "structure_check":
|
||||||
findings = run_structure_check(self.document_texts, self._rules())
|
findings = run_structure_check(
|
||||||
|
self.document_texts,
|
||||||
|
self._rules(),
|
||||||
|
progress_callback=lambda update: self._update_node_progress(
|
||||||
|
node,
|
||||||
|
processed=int(update.get("processed") or 0),
|
||||||
|
total=int(update.get("total") or 0),
|
||||||
|
message=(
|
||||||
|
f"章节核查 {update.get('processed')}/{update.get('total')}:"
|
||||||
|
f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
self.findings.extend(findings)
|
self.findings.extend(findings)
|
||||||
logger.info(
|
logger.info(
|
||||||
"方法执行 batch_no=%s method=run_structure_check findings=%s docs=%s",
|
"方法执行 batch_no=%s method=run_structure_check findings=%s docs=%s",
|
||||||
@@ -257,7 +318,18 @@ class RegulatoryWorkflowExecutor:
|
|||||||
self._save_llm_review("structure_check", {"findings": [finding.to_dict() for finding in findings]})
|
self._save_llm_review("structure_check", {"findings": [finding.to_dict() for finding in findings]})
|
||||||
return
|
return
|
||||||
if node_code == "consistency_check":
|
if node_code == "consistency_check":
|
||||||
findings = run_consistency_check(self.document_texts)
|
findings = run_consistency_check(
|
||||||
|
self.document_texts,
|
||||||
|
progress_callback=lambda update: self._update_node_progress(
|
||||||
|
node,
|
||||||
|
processed=int(update.get("processed") or 0),
|
||||||
|
total=int(update.get("total") or 0),
|
||||||
|
message=(
|
||||||
|
f"一致性核查 {update.get('processed')}/{update.get('total')}:"
|
||||||
|
f"{update.get('label') or ''},发现{update.get('finding_count') or 0}项问题"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
self.findings.extend(findings)
|
self.findings.extend(findings)
|
||||||
logger.info(
|
logger.info(
|
||||||
"方法执行 batch_no=%s method=run_consistency_check findings=%s docs=%s",
|
"方法执行 batch_no=%s method=run_consistency_check findings=%s docs=%s",
|
||||||
@@ -353,9 +425,11 @@ class RegulatoryWorkflowExecutor:
|
|||||||
self.rule_set = apply_rule_scope(load_rule_file(), self.batch.condition_json.get("rule_scope") or {})
|
self.rule_set = apply_rule_scope(load_rule_file(), self.batch.condition_json.get("rule_scope") or {})
|
||||||
return self.rule_set
|
return self.rule_set
|
||||||
|
|
||||||
def _extract_source_texts(self) -> dict[str, str]:
|
def _extract_source_texts(self, node: WorkflowNodeRun | None = None) -> dict[str, str]:
|
||||||
texts = {}
|
texts = {}
|
||||||
for item in self.batch.source_summary_batch.items.order_by("file_index"):
|
items = list(self.batch.source_summary_batch.items.order_by("file_index"))
|
||||||
|
total = len(items)
|
||||||
|
for index, item in enumerate(items, start=1):
|
||||||
path = Path(item.storage_path)
|
path = Path(item.storage_path)
|
||||||
if not path.is_absolute():
|
if not path.is_absolute():
|
||||||
path = Path(settings.MEDIA_ROOT) / item.storage_path
|
path = Path(settings.MEDIA_ROOT) / item.storage_path
|
||||||
@@ -369,6 +443,13 @@ class RegulatoryWorkflowExecutor:
|
|||||||
"field_candidates": {},
|
"field_candidates": {},
|
||||||
"front_text": "",
|
"front_text": "",
|
||||||
}
|
}
|
||||||
|
if node:
|
||||||
|
self._update_node_progress(
|
||||||
|
node,
|
||||||
|
processed=index,
|
||||||
|
total=total,
|
||||||
|
message=f"文本抽取 {index}/{total}:{item.file_name}(文件不存在)",
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
result = extract_text(path)
|
result = extract_text(path)
|
||||||
field_review = review_condition_fields(
|
field_review = review_condition_fields(
|
||||||
@@ -396,6 +477,13 @@ class RegulatoryWorkflowExecutor:
|
|||||||
len((field_review.get("selected_fields") or {})),
|
len((field_review.get("selected_fields") or {})),
|
||||||
len(result.text or ""),
|
len(result.text or ""),
|
||||||
)
|
)
|
||||||
|
if node:
|
||||||
|
self._update_node_progress(
|
||||||
|
node,
|
||||||
|
processed=index,
|
||||||
|
total=total,
|
||||||
|
message=f"文本抽取 {index}/{total}:{item.file_name}({result.status})",
|
||||||
|
)
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
def _save_llm_review(self, stage: str, payload: dict[str, object]) -> dict[str, object]:
|
def _save_llm_review(self, stage: str, payload: dict[str, object]) -> dict[str, object]:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from review_agent.models import (
|
|||||||
)
|
)
|
||||||
from review_agent.regulatory_review.workflow import (
|
from review_agent.regulatory_review.workflow import (
|
||||||
NODE_DEFINITIONS,
|
NODE_DEFINITIONS,
|
||||||
|
RegulatoryWorkflowExecutor,
|
||||||
create_regulatory_review_batch,
|
create_regulatory_review_batch,
|
||||||
find_latest_successful_summary_batch,
|
find_latest_successful_summary_batch,
|
||||||
start_regulatory_review_workflow,
|
start_regulatory_review_workflow,
|
||||||
@@ -413,3 +414,89 @@ def test_workflow_records_llm_review_artifacts_for_review_nodes(
|
|||||||
assert "llm_review_structure_check.json" in artifact_names
|
assert "llm_review_structure_check.json" in artifact_names
|
||||||
assert "llm_review_consistency_check.json" in artifact_names
|
assert "llm_review_consistency_check.json" in artifact_names
|
||||||
assert "llm_review_risk_assess.json" in artifact_names
|
assert "llm_review_risk_assess.json" in artifact_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_progress_uses_processed_file_counts(settings, tmp_path, django_user_model):
|
||||||
|
settings.MEDIA_ROOT = tmp_path
|
||||||
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
|
conversation = Conversation.objects.create(user=user, title="会话")
|
||||||
|
summary = FileSummaryBatch.objects.create(
|
||||||
|
conversation=conversation,
|
||||||
|
user=user,
|
||||||
|
batch_no="FS-OK",
|
||||||
|
status=FileSummaryBatch.Status.SUCCESS,
|
||||||
|
)
|
||||||
|
for index, name in enumerate(["注册信息.txt", "说明书.txt", "综述.txt"], start=1):
|
||||||
|
path = tmp_path / name
|
||||||
|
path.write_text(f"产品名称:甲胎蛋白检测试剂盒\n文件:{name}", encoding="utf-8")
|
||||||
|
FileSummaryItem.objects.create(
|
||||||
|
batch=summary,
|
||||||
|
file_index=index,
|
||||||
|
file_name=name,
|
||||||
|
file_type="txt",
|
||||||
|
relative_path=name,
|
||||||
|
storage_path=str(path),
|
||||||
|
)
|
||||||
|
batch = create_regulatory_review_batch(
|
||||||
|
conversation=conversation,
|
||||||
|
user=user,
|
||||||
|
source_summary_batch=summary,
|
||||||
|
)
|
||||||
|
node = WorkflowNodeRun.objects.get(
|
||||||
|
workflow_type="regulatory_review",
|
||||||
|
workflow_batch_id=batch.pk,
|
||||||
|
node_code="text_extract",
|
||||||
|
)
|
||||||
|
executor = RegulatoryWorkflowExecutor(batch)
|
||||||
|
|
||||||
|
texts = executor._extract_source_texts(node)
|
||||||
|
|
||||||
|
node.refresh_from_db()
|
||||||
|
assert len(texts) == 3
|
||||||
|
assert node.progress == 95
|
||||||
|
assert "文本抽取 3/3" in node.message
|
||||||
|
assert "综述.txt" in node.message
|
||||||
|
assert WorkflowEvent.objects.filter(
|
||||||
|
workflow_type="regulatory_review",
|
||||||
|
workflow_batch_id=batch.pk,
|
||||||
|
event_type="node_progress",
|
||||||
|
payload__node_code="text_extract",
|
||||||
|
payload__processed=3,
|
||||||
|
payload__total=3,
|
||||||
|
).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_review_services_emit_actual_workload_progress_callbacks(django_user_model):
|
||||||
|
from review_agent.regulatory_review.services.completeness_check import run_completeness_check
|
||||||
|
from review_agent.regulatory_review.services.consistency_check import FIELDS, run_consistency_check
|
||||||
|
from review_agent.regulatory_review.services.structure_check import run_structure_check
|
||||||
|
|
||||||
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
|
conversation = Conversation.objects.create(user=user, title="会话")
|
||||||
|
summary = FileSummaryBatch.objects.create(
|
||||||
|
conversation=conversation,
|
||||||
|
user=user,
|
||||||
|
batch_no="FS-OK",
|
||||||
|
status=FileSummaryBatch.Status.SUCCESS,
|
||||||
|
)
|
||||||
|
rule_set = {
|
||||||
|
"requirements": [
|
||||||
|
{"code": "r1", "title": "注册信息", "type": "required", "file_keywords": ["注册信息"]},
|
||||||
|
{"code": "r2", "title": "说明书", "type": "required", "file_keywords": ["说明书"]},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
completeness_updates = []
|
||||||
|
structure_updates = []
|
||||||
|
consistency_updates = []
|
||||||
|
|
||||||
|
run_completeness_check(summary, rule_set, progress_callback=completeness_updates.append)
|
||||||
|
run_structure_check({"注册信息.txt": "注册信息"}, rule_set, progress_callback=structure_updates.append)
|
||||||
|
run_consistency_check({"注册信息.txt": "产品名称:A"}, progress_callback=consistency_updates.append)
|
||||||
|
|
||||||
|
assert completeness_updates[-1]["processed"] == 2
|
||||||
|
assert completeness_updates[-1]["total"] == 2
|
||||||
|
assert completeness_updates[-1]["label"] == "说明书"
|
||||||
|
assert structure_updates[-1]["processed"] == 2
|
||||||
|
assert structure_updates[-1]["total"] == 2
|
||||||
|
assert consistency_updates[-1]["processed"] == len(FIELDS)
|
||||||
|
assert consistency_updates[-1]["total"] == len(FIELDS)
|
||||||
|
|||||||
Reference in New Issue
Block a user