From 945669b9c251d970cc504bb0722eb6dce5d2a8ff Mon Sep 17 00:00:00 2001 From: bruce Date: Sun, 7 Jun 2026 11:46:55 +0800 Subject: [PATCH] =?UTF-8?q?feat(regulatory):=20=E5=A2=9E=E5=8A=A0=E6=9D=A1?= =?UTF-8?q?=E4=BB=B6=E5=AD=97=E6=AE=B5LLM=E5=A4=8D=E6=A0=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/info_extract.py | 18 +- .../regulatory_review/services/llm_review.py | 175 ++++++++++++++++++ tests/test_regulatory_condition.py | 43 +++++ tests/test_regulatory_llm_review.py | 42 +++++ 4 files changed, 275 insertions(+), 3 deletions(-) create mode 100644 review_agent/regulatory_review/services/llm_review.py create mode 100644 tests/test_regulatory_llm_review.py diff --git a/review_agent/regulatory_review/services/info_extract.py b/review_agent/regulatory_review/services/info_extract.py index 29ebbe5..1a48820 100644 --- a/review_agent/regulatory_review/services/info_extract.py +++ b/review_agent/regulatory_review/services/info_extract.py @@ -5,6 +5,7 @@ from pathlib import Path from django.conf import settings from review_agent.models import FileSummaryBatch +from review_agent.regulatory_review.services.llm_review import review_condition_fields from review_agent.regulatory_review.services.text_extract import extract_text @@ -20,10 +21,14 @@ def detect_regulatory_condition_candidates(summary_batch: FileSummaryBatch) -> d corpus_parts = [summary_batch.product_name or ""] field_candidates: dict[str, str] = {} + field_sources: dict[str, str] = {} for item in summary_batch.items.order_by("file_index"): corpus_parts.extend([item.directory_level, item.file_name, item.relative_path]) - extracted = _extract_item_fields(item) + review = _extract_item_fields(item) + extracted = review.get("selected_fields", {}) + sources = review.get("selected_sources", {}) field_candidates.update({key: value for key, value in extracted.items() if value and key not in field_candidates}) + field_sources.update({key: value for key, value in sources.items() if value and key not in field_sources}) corpus_parts.extend(extracted.values()) corpus = "\n".join(part for part in corpus_parts if part) product_name = field_candidates.get("产品名称") or _safe_summary_product_name(summary_batch.product_name) @@ -51,21 +56,24 @@ def detect_regulatory_condition_candidates(summary_batch: FileSummaryBatch) -> d "label": "产品名称", "input_type": "text", "suggested": product_name, + "source": field_sources.get("产品名称", "summary" if product_name else ""), }, "model_spec": { "label": "型号规格", "input_type": "text", "suggested": field_candidates.get("型号规格", ""), + "source": field_sources.get("型号规格", ""), }, "intended_use": { "label": "预期用途", "input_type": "text", "suggested": field_candidates.get("预期用途", ""), + "source": field_sources.get("预期用途", ""), }, } -def _extract_item_fields(item) -> dict[str, str]: +def _extract_item_fields(item) -> dict[str, object]: path = Path(item.storage_path) if not path.is_absolute(): path = Path(settings.MEDIA_ROOT) / item.storage_path @@ -74,7 +82,11 @@ def _extract_item_fields(item) -> dict[str, str]: result = extract_text(path) if result.status != "success" or not result.field_candidates: return {} - return result.field_candidates + return review_condition_fields( + text=result.front_text or result.text, + rule_fields=result.field_candidates, + file_context=f"{item.directory_level}\n{item.file_name}\n{item.relative_path}", + ) def _safe_summary_product_name(product_name: str) -> str: diff --git a/review_agent/regulatory_review/services/llm_review.py b/review_agent/regulatory_review/services/llm_review.py new file mode 100644 index 0000000..f712ec9 --- /dev/null +++ b/review_agent/regulatory_review/services/llm_review.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import json +import re +from collections.abc import Callable +from typing import Any + +from review_agent.llm import LLMConfigurationError, LLMRequestError, generate_completion + + +FIELD_LABELS = ["产品名称", "型号规格", "预期用途", "管理类别", "分类编码", "注册类型", "临床评价路径"] +CompletionFunc = Callable[[list[dict[str, str]]], str] + + +def review_condition_fields( + *, + text: str, + rule_fields: dict[str, str], + file_context: str = "", + completion_func: Callable[..., str] | None = None, +) -> dict[str, Any]: + llm_fields: dict[str, str] = {} + status = "skipped" + error_message = "" + try: + raw = (completion_func or generate_completion)(_condition_messages(text, rule_fields, file_context), temperature=0.0) + payload = _parse_json_object(raw) + llm_fields = _clean_fields(payload.get("fields") or payload) + status = "success" + except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc: + status = "failed" + error_message = str(exc) + + selected_fields, selected_sources = _select_fields(rule_fields, llm_fields) + return { + "status": status, + "error_message": error_message, + "rule_fields": _clean_fields(rule_fields), + "llm_fields": llm_fields, + "selected_fields": selected_fields, + "selected_sources": selected_sources, + } + + +def review_workflow_payload( + *, + stage: str, + payload: dict[str, Any], + completion_func: Callable[..., str] | None = None, +) -> dict[str, Any]: + try: + raw = (completion_func or generate_completion)(_workflow_messages(stage, payload), temperature=0.0) + parsed = _parse_json_object(raw) + return { + "status": "success", + "stage": stage, + "result": parsed, + "error_message": "", + } + except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc: + return { + "status": "failed", + "stage": stage, + "result": {}, + "error_message": str(exc), + } + + +def _condition_messages(text: str, rule_fields: dict[str, str], file_context: str) -> list[dict[str, str]]: + return [ + { + "role": "system", + "content": ( + "你是NMPA注册资料字段复核助手。请从附件文本中提取最合理的字段值," + "只返回JSON,格式为 {\"fields\": {\"产品名称\": \"...\"}}。" + "产品名称应包含完整名称、检测对象和方法学括号;不要把章节标题当产品名称。" + ), + }, + { + "role": "user", + "content": json.dumps( + { + "file_context": file_context, + "rule_fields": rule_fields, + "text": text[:4000], + "allowed_fields": FIELD_LABELS, + }, + ensure_ascii=False, + ), + }, + ] + + +def _workflow_messages(stage: str, payload: dict[str, Any]) -> list[dict[str, str]]: + return [ + { + "role": "system", + "content": ( + "你是NMPA法规核查复核助手。请复核当前流程节点的规则结果," + "指出可能误判、漏判和更合理的建议。只返回JSON。" + ), + }, + { + "role": "user", + "content": json.dumps({"stage": stage, "payload": payload}, ensure_ascii=False)[:6000], + }, + ] + + +def _parse_json_object(raw: str) -> dict[str, Any]: + value = (raw or "").strip() + if value.startswith("```"): + value = re.sub(r"^```(?:json)?\s*", "", value) + value = re.sub(r"\s*```$", "", value) + start = value.find("{") + end = value.rfind("}") + if start >= 0 and end >= start: + value = value[start : end + 1] + parsed = json.loads(value) + if not isinstance(parsed, dict): + raise ValueError("LLM复核结果不是JSON对象。") + return parsed + + +def _clean_fields(fields: dict[str, Any]) -> dict[str, str]: + clean = {} + for label in FIELD_LABELS: + value = fields.get(label) + if not isinstance(value, str): + continue + normalized = " ".join(value.strip().split()) + if normalized: + clean[label] = normalized + return clean + + +def _select_fields(rule_fields: dict[str, str], llm_fields: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]: + rule_clean = _clean_fields(rule_fields) + selected = {} + sources = {} + for label in FIELD_LABELS: + rule_value = rule_clean.get(label, "") + llm_value = llm_fields.get(label, "") + value, source = _select_field(label, rule_value, llm_value) + if value: + selected[label] = value + sources[label] = source + return selected, sources + + +def _select_field(label: str, rule_value: str, llm_value: str) -> tuple[str, str]: + if _invalid_field_value(llm_value): + return rule_value, "rule" if rule_value else "" + if not rule_value: + return llm_value, "llm" if llm_value else "" + if not llm_value: + return rule_value, "rule" + if label == "产品名称" and _better_product_name(llm_value, rule_value): + return llm_value, "llm" + if len(llm_value) > len(rule_value) * 1.35 and rule_value in llm_value: + return llm_value, "llm" + return rule_value, "rule" + + +def _better_product_name(candidate: str, current: str) -> bool: + if current and current in candidate and len(candidate) > len(current): + return True + product_keywords = ["试剂盒", "检测试剂", "荧光PCR法", "PCR法", "核酸检测"] + return len(candidate) > len(current) and any(keyword in candidate for keyword in product_keywords) + + +def _invalid_field_value(value: str) -> bool: + if not value: + return True + return any(keyword in value for keyword in ["第1章", "第2章", "第3章", "监管信息", "综述资料", "章节目录"]) diff --git a/tests/test_regulatory_condition.py b/tests/test_regulatory_condition.py index dfcbd28..e8bb232 100644 --- a/tests/test_regulatory_condition.py +++ b/tests/test_regulatory_condition.py @@ -117,6 +117,49 @@ def test_detect_regulatory_condition_keeps_wrapped_product_name(settings, tmp_pa assert candidates["model_spec"]["suggested"] == "24人份/盒" +def test_detect_regulatory_condition_uses_llm_review_for_better_product_name( + monkeypatch, settings, tmp_path, django_user_model +): + settings.MEDIA_ROOT = tmp_path + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + summary = FileSummaryBatch.objects.create( + conversation=conversation, + user=user, + batch_no="FS-COND", + status=FileSummaryBatch.Status.SUCCESS, + product_name="第1章 监管信息", + ) + application = tmp_path / "application.txt" + application.write_text( + "产品名称:呼吸道合胞病毒、肺炎支原体核酸检测试剂盒\n" + "型号规格:24人份/盒\n", + encoding="utf-8", + ) + FileSummaryItem.objects.create( + batch=summary, + file_index=1, + directory_level="1. 监管信息 / 1.2 申请表", + file_name="申请表.txt", + file_type="txt", + relative_path="1.监管信息/申请表.txt", + storage_path=str(application), + ) + + monkeypatch.setattr( + "review_agent.regulatory_review.services.llm_review.generate_completion", + lambda messages, temperature=0.0: json.dumps( + {"fields": {"产品名称": "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)"}}, + ensure_ascii=False, + ), + ) + + candidates = detect_regulatory_condition_candidates(summary) + + assert candidates["product_name"]["suggested"] == "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)" + assert candidates["product_name"]["source"] == "llm" + + def test_workflow_pauses_before_rule_scope_until_conditions_confirmed(settings, tmp_path, django_user_model): settings.MEDIA_ROOT = tmp_path user = django_user_model.objects.create_user(username="owner", password="pass") diff --git a/tests/test_regulatory_llm_review.py b/tests/test_regulatory_llm_review.py new file mode 100644 index 0000000..0d5ad6e --- /dev/null +++ b/tests/test_regulatory_llm_review.py @@ -0,0 +1,42 @@ +import json + +from review_agent.regulatory_review.services.llm_review import review_condition_fields + + +def test_review_condition_fields_selects_more_complete_llm_product_name(): + def completion(messages, temperature=0.0): + return json.dumps( + { + "fields": { + "产品名称": "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)", + "型号规格": "24人份/盒", + } + }, + ensure_ascii=False, + ) + + result = review_condition_fields( + text="产品名称:呼吸道合胞病毒、肺炎支原体核酸检测试剂盒\n(荧光PCR法)\n型号规格:24人份/盒", + rule_fields={"产品名称": "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒", "型号规格": "24人份/盒"}, + file_context="申请表.txt", + completion_func=completion, + ) + + assert result["selected_fields"]["产品名称"] == "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)" + assert result["selected_sources"]["产品名称"] == "llm" + assert result["selected_sources"]["型号规格"] == "rule" + + +def test_review_condition_fields_falls_back_when_llm_returns_chapter_title(): + def completion(messages, temperature=0.0): + return json.dumps({"fields": {"产品名称": "第1章 监管信息"}}, ensure_ascii=False) + + result = review_condition_fields( + text="产品名称:甲胎蛋白检测试剂盒", + rule_fields={"产品名称": "甲胎蛋白检测试剂盒"}, + file_context="申请表.txt", + completion_func=completion, + ) + + assert result["selected_fields"]["产品名称"] == "甲胎蛋白检测试剂盒" + assert result["selected_sources"]["产品名称"] == "rule"