feat(regulatory): 增加条件字段LLM复核
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
from review_agent.models import FileSummaryBatch
|
from review_agent.models import FileSummaryBatch
|
||||||
|
from review_agent.regulatory_review.services.llm_review import review_condition_fields
|
||||||
from review_agent.regulatory_review.services.text_extract import extract_text
|
from review_agent.regulatory_review.services.text_extract import extract_text
|
||||||
|
|
||||||
|
|
||||||
@@ -20,10 +21,14 @@ def detect_regulatory_condition_candidates(summary_batch: FileSummaryBatch) -> d
|
|||||||
|
|
||||||
corpus_parts = [summary_batch.product_name or ""]
|
corpus_parts = [summary_batch.product_name or ""]
|
||||||
field_candidates: dict[str, str] = {}
|
field_candidates: dict[str, str] = {}
|
||||||
|
field_sources: dict[str, str] = {}
|
||||||
for item in summary_batch.items.order_by("file_index"):
|
for item in summary_batch.items.order_by("file_index"):
|
||||||
corpus_parts.extend([item.directory_level, item.file_name, item.relative_path])
|
corpus_parts.extend([item.directory_level, item.file_name, item.relative_path])
|
||||||
extracted = _extract_item_fields(item)
|
review = _extract_item_fields(item)
|
||||||
|
extracted = review.get("selected_fields", {})
|
||||||
|
sources = review.get("selected_sources", {})
|
||||||
field_candidates.update({key: value for key, value in extracted.items() if value and key not in field_candidates})
|
field_candidates.update({key: value for key, value in extracted.items() if value and key not in field_candidates})
|
||||||
|
field_sources.update({key: value for key, value in sources.items() if value and key not in field_sources})
|
||||||
corpus_parts.extend(extracted.values())
|
corpus_parts.extend(extracted.values())
|
||||||
corpus = "\n".join(part for part in corpus_parts if part)
|
corpus = "\n".join(part for part in corpus_parts if part)
|
||||||
product_name = field_candidates.get("产品名称") or _safe_summary_product_name(summary_batch.product_name)
|
product_name = field_candidates.get("产品名称") or _safe_summary_product_name(summary_batch.product_name)
|
||||||
@@ -51,21 +56,24 @@ def detect_regulatory_condition_candidates(summary_batch: FileSummaryBatch) -> d
|
|||||||
"label": "产品名称",
|
"label": "产品名称",
|
||||||
"input_type": "text",
|
"input_type": "text",
|
||||||
"suggested": product_name,
|
"suggested": product_name,
|
||||||
|
"source": field_sources.get("产品名称", "summary" if product_name else ""),
|
||||||
},
|
},
|
||||||
"model_spec": {
|
"model_spec": {
|
||||||
"label": "型号规格",
|
"label": "型号规格",
|
||||||
"input_type": "text",
|
"input_type": "text",
|
||||||
"suggested": field_candidates.get("型号规格", ""),
|
"suggested": field_candidates.get("型号规格", ""),
|
||||||
|
"source": field_sources.get("型号规格", ""),
|
||||||
},
|
},
|
||||||
"intended_use": {
|
"intended_use": {
|
||||||
"label": "预期用途",
|
"label": "预期用途",
|
||||||
"input_type": "text",
|
"input_type": "text",
|
||||||
"suggested": field_candidates.get("预期用途", ""),
|
"suggested": field_candidates.get("预期用途", ""),
|
||||||
|
"source": field_sources.get("预期用途", ""),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_item_fields(item) -> dict[str, str]:
|
def _extract_item_fields(item) -> dict[str, object]:
|
||||||
path = Path(item.storage_path)
|
path = Path(item.storage_path)
|
||||||
if not path.is_absolute():
|
if not path.is_absolute():
|
||||||
path = Path(settings.MEDIA_ROOT) / item.storage_path
|
path = Path(settings.MEDIA_ROOT) / item.storage_path
|
||||||
@@ -74,7 +82,11 @@ def _extract_item_fields(item) -> dict[str, str]:
|
|||||||
result = extract_text(path)
|
result = extract_text(path)
|
||||||
if result.status != "success" or not result.field_candidates:
|
if result.status != "success" or not result.field_candidates:
|
||||||
return {}
|
return {}
|
||||||
return result.field_candidates
|
return review_condition_fields(
|
||||||
|
text=result.front_text or result.text,
|
||||||
|
rule_fields=result.field_candidates,
|
||||||
|
file_context=f"{item.directory_level}\n{item.file_name}\n{item.relative_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _safe_summary_product_name(product_name: str) -> str:
|
def _safe_summary_product_name(product_name: str) -> str:
|
||||||
|
|||||||
175
review_agent/regulatory_review/services/llm_review.py
Normal file
175
review_agent/regulatory_review/services/llm_review.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from review_agent.llm import LLMConfigurationError, LLMRequestError, generate_completion
|
||||||
|
|
||||||
|
|
||||||
|
# Field labels handled by the condition-field review. _clean_fields and
# _select_fields iterate this list (keys outside it are dropped), and
# _condition_messages sends it to the LLM as "allowed_fields".
FIELD_LABELS = ["产品名称", "型号规格", "预期用途", "管理类别", "分类编码", "注册类型", "临床评价路径"]
|
||||||
|
# Alias for a completion callable that takes a list of chat messages and
# returns the raw reply text.
# NOTE(review): the functions visible in this file annotate completion_func as
# Callable[..., str] rather than using this alias - confirm it is still needed.
CompletionFunc = Callable[[list[dict[str, str]]], str]
|
||||||
|
|
||||||
|
|
||||||
|
def review_condition_fields(
    *,
    text: str,
    rule_fields: dict[str, str],
    file_context: str = "",
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Cross-check rule-extracted condition fields against an LLM extraction.

    The LLM is asked to extract the fields from ``text``; its answer is then
    merged with ``rule_fields`` field by field. Any LLM failure is recorded in
    the result instead of being raised, and the rule values are used as-is.

    Args:
        text: Document text handed to the LLM.
        rule_fields: Field values already found by the rule-based extractor.
        file_context: Extra context (e.g. directory / file name) for the LLM.
        completion_func: Optional replacement for ``generate_completion``;
            called as ``completion_func(messages, temperature=0.0)``.

    Returns:
        A dict with ``status`` ("success" or "failed"), ``error_message``,
        the cleaned ``rule_fields`` and ``llm_fields``, and the merged
        ``selected_fields`` / ``selected_sources`` mappings.
    """
    llm_fields: dict[str, str] = {}
    status = "skipped"
    error_message = ""
    try:
        complete = completion_func or generate_completion
        raw = complete(_condition_messages(text, rule_fields, file_context), temperature=0.0)
        payload = _parse_json_object(raw)
        # The model may answer {"fields": {...}} or return the mapping directly.
        candidate_fields = payload.get("fields") or payload
        if not isinstance(candidate_fields, dict):
            # A string/list under "fields" would otherwise surface as an
            # uncaught AttributeError; treat it as a failed review instead.
            raise ValueError("LLM复核结果的fields不是JSON对象。")
        llm_fields = _clean_fields(candidate_fields)
        status = "success"
    except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
        status = "failed"
        error_message = str(exc)

    # With no usable LLM fields this simply keeps the rule values.
    selected_fields, selected_sources = _select_fields(rule_fields, llm_fields)
    return {
        "status": status,
        "error_message": error_message,
        "rule_fields": _clean_fields(rule_fields),
        "llm_fields": llm_fields,
        "selected_fields": selected_fields,
        "selected_sources": selected_sources,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def review_workflow_payload(
    *,
    stage: str,
    payload: dict[str, Any],
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Ask the LLM to review the rule results of one workflow stage.

    Args:
        stage: Name of the workflow stage being reviewed.
        payload: Stage data serialised into the LLM request.
        completion_func: Optional replacement for ``generate_completion``;
            called as ``completion_func(messages, temperature=0.0)``.

    Returns:
        A dict with ``status``, ``stage``, ``result`` and ``error_message``.
        On failure ``result`` is empty and ``error_message`` carries the
        exception text; the exception itself is not propagated.
    """
    try:
        complete = completion_func or generate_completion
        reply = complete(_workflow_messages(stage, payload), temperature=0.0)
        review_result = _parse_json_object(reply)
    except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
        return {
            "status": "failed",
            "stage": stage,
            "result": {},
            "error_message": str(exc),
        }
    return {
        "status": "success",
        "stage": stage,
        "result": review_result,
        "error_message": "",
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _condition_messages(text: str, rule_fields: dict[str, str], file_context: str) -> list[dict[str, str]]:
    """Build the chat messages for the condition-field review request.

    The user message is a JSON document holding the file context, the rule
    results, the first 4000 characters of ``text`` and the allowed labels.
    """
    system_prompt = (
        "你是NMPA注册资料字段复核助手。请从附件文本中提取最合理的字段值,"
        '只返回JSON,格式为 {"fields": {"产品名称": "..."}}。'
        "产品名称应包含完整名称、检测对象和方法学括号;不要把章节标题当产品名称。"
    )
    request_body = {
        "file_context": file_context,
        "rule_fields": rule_fields,
        "text": text[:4000],
        "allowed_fields": FIELD_LABELS,
    }
    user_prompt = json.dumps(request_body, ensure_ascii=False)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _workflow_messages(stage: str, payload: dict[str, Any]) -> list[dict[str, str]]:
    """Build the chat messages for a workflow-stage review request.

    The user message is the JSON serialisation of ``stage`` and ``payload``,
    cut to at most 6000 characters.
    """
    system_prompt = (
        "你是NMPA法规核查复核助手。请复核当前流程节点的规则结果,"
        "指出可能误判、漏判和更合理的建议。只返回JSON。"
    )
    serialized = json.dumps({"stage": stage, "payload": payload}, ensure_ascii=False)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": serialized[:6000]},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(raw: str) -> dict[str, Any]:
    """Decode an LLM reply into a JSON object.

    A leading/trailing Markdown code fence is removed, and when the reply
    contains braces only the span from the first ``{`` to the last ``}`` is
    decoded, so surrounding chatter is ignored.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
        ValueError: If the decoded value is not a JSON object.
    """
    candidate = (raw or "").strip()
    if candidate.startswith("```"):
        candidate = re.sub(r"^```(?:json)?\s*", "", candidate)
        candidate = re.sub(r"\s*```$", "", candidate)

    first_brace, last_brace = candidate.find("{"), candidate.rfind("}")
    if 0 <= first_brace <= last_brace:
        candidate = candidate[first_brace : last_brace + 1]

    decoded = json.loads(candidate)
    if isinstance(decoded, dict):
        return decoded
    raise ValueError("LLM复核结果不是JSON对象。")
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_fields(fields: dict[str, Any]) -> dict[str, str]:
    """Keep only known, non-empty string fields with whitespace collapsed.

    Args:
        fields: Mapping of label to raw value. It may come straight from an
            LLM reply, so anything that is not a dict is treated as empty
            rather than raising ``AttributeError``.

    Returns:
        A new dict holding, for each label in ``FIELD_LABELS`` whose value is
        a non-blank string, that value with runs of whitespace reduced to a
        single space and the ends trimmed.
    """
    if not isinstance(fields, dict):
        return {}
    clean: dict[str, str] = {}
    for label in FIELD_LABELS:
        value = fields.get(label)
        if not isinstance(value, str):
            continue
        # split() with no argument also drops leading/trailing whitespace.
        normalized = " ".join(value.split())
        if normalized:
            clean[label] = normalized
    return clean
|
||||||
|
|
||||||
|
|
||||||
|
def _select_fields(rule_fields: dict[str, str], llm_fields: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]:
    """Merge rule and LLM values label by label.

    Returns:
        ``(values, sources)``: the chosen value per label and the source tag
        reported by ``_select_field`` for it. Labels for which no value was
        chosen appear in neither dict.
    """
    normalized_rules = _clean_fields(rule_fields)
    chosen_values: dict[str, str] = {}
    chosen_sources: dict[str, str] = {}
    for label in FIELD_LABELS:
        picked, origin = _select_field(
            label,
            normalized_rules.get(label, ""),
            llm_fields.get(label, ""),
        )
        if not picked:
            continue
        chosen_values[label] = picked
        chosen_sources[label] = origin
    return chosen_values, chosen_sources
|
||||||
|
|
||||||
|
|
||||||
|
def _select_field(label: str, rule_value: str, llm_value: str) -> tuple[str, str]:
    """Choose between the rule-extracted and the LLM-extracted value of a field.

    Args:
        label: Field label; "产品名称" gets an extra product-name heuristic.
        rule_value: Value from the rule-based extractor ("" if none).
        llm_value: Value from the LLM ("" if none).

    Returns:
        ``(value, source)`` where ``source`` is "rule", "llm", or "" when
        there is no value at all.
    """
    if _invalid_field_value(llm_value):
        # Nothing usable from the LLM: keep whatever the rules found.
        return rule_value, ("rule" if rule_value else "")

    # From here on llm_value is non-empty and passed the validity check.
    if _invalid_field_value(rule_value):
        # The rule value is missing or fails the same check (e.g. a chapter
        # heading), so the valid LLM value wins.
        return llm_value, "llm"

    if label == "产品名称" and _better_product_name(llm_value, rule_value):
        return llm_value, "llm"

    # Accept a clearly longer LLM value only when it extends the rule value.
    if len(llm_value) > len(rule_value) * 1.35 and rule_value in llm_value:
        return llm_value, "llm"

    return rule_value, "rule"
|
||||||
|
|
||||||
|
|
||||||
|
def _better_product_name(candidate: str, current: str) -> bool:
    """Tell whether ``candidate`` is a more complete product name than ``current``.

    A candidate only qualifies when it is strictly longer than ``current``
    and either contains ``current`` or mentions one of the product keywords.
    """
    if len(candidate) <= len(current):
        return False
    if current and current in candidate:
        return True
    for keyword in ("试剂盒", "检测试剂", "荧光PCR法", "PCR法", "核酸检测"):
        if keyword in candidate:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _invalid_field_value(value: str) -> bool:
    """Tell whether ``value`` is unusable as a field value.

    A value is rejected when it is empty, contains a chapter marker such as
    "第1章", "第 4 章" or "第六章", or contains one of the section titles
    listed below.
    """
    if not value:
        return True
    # Any chapter number counts, not only chapters 1-3; \d also matches
    # full-width digits, and optional spaces around the number are allowed.
    if re.search(r"第\s*[\d一二三四五六七八九十百]+\s*章", value):
        return True
    return any(keyword in value for keyword in ("监管信息", "综述资料", "章节目录"))
|
||||||
@@ -117,6 +117,49 @@ def test_detect_regulatory_condition_keeps_wrapped_product_name(settings, tmp_pa
|
|||||||
assert candidates["model_spec"]["suggested"] == "24人份/盒"
|
assert candidates["model_spec"]["suggested"] == "24人份/盒"
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_regulatory_condition_uses_llm_review_for_better_product_name(
    monkeypatch, settings, tmp_path, django_user_model
):
    """The product name returned by the stubbed LLM is suggested and tagged as source "llm"."""
    llm_product_name = "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)"

    def fake_completion(messages, temperature=0.0):
        # Stand-in for generate_completion: always answers with the fuller name.
        return json.dumps({"fields": {"产品名称": llm_product_name}}, ensure_ascii=False)

    settings.MEDIA_ROOT = tmp_path
    owner = django_user_model.objects.create_user(username="owner", password="pass")
    conversation = Conversation.objects.create(user=owner, title="会话")
    summary = FileSummaryBatch.objects.create(
        conversation=conversation,
        user=owner,
        batch_no="FS-COND",
        status=FileSummaryBatch.Status.SUCCESS,
        product_name="第1章 监管信息",
    )

    application = tmp_path / "application.txt"
    file_body = (
        "产品名称:呼吸道合胞病毒、肺炎支原体核酸检测试剂盒\n"
        "型号规格:24人份/盒\n"
    )
    application.write_text(file_body, encoding="utf-8")
    FileSummaryItem.objects.create(
        batch=summary,
        file_index=1,
        directory_level="1. 监管信息 / 1.2 申请表",
        file_name="申请表.txt",
        file_type="txt",
        relative_path="1.监管信息/申请表.txt",
        storage_path=str(application),
    )

    monkeypatch.setattr(
        "review_agent.regulatory_review.services.llm_review.generate_completion",
        fake_completion,
    )

    product_candidate = detect_regulatory_condition_candidates(summary)["product_name"]

    assert product_candidate["suggested"] == llm_product_name
    assert product_candidate["source"] == "llm"
|
||||||
|
|
||||||
|
|
||||||
def test_workflow_pauses_before_rule_scope_until_conditions_confirmed(settings, tmp_path, django_user_model):
|
def test_workflow_pauses_before_rule_scope_until_conditions_confirmed(settings, tmp_path, django_user_model):
|
||||||
settings.MEDIA_ROOT = tmp_path
|
settings.MEDIA_ROOT = tmp_path
|
||||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
|
|||||||
42
tests/test_regulatory_llm_review.py
Normal file
42
tests/test_regulatory_llm_review.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from review_agent.regulatory_review.services.llm_review import review_condition_fields
|
||||||
|
|
||||||
|
|
||||||
|
def test_review_condition_fields_selects_more_complete_llm_product_name():
    """An LLM product name that extends the rule value wins; an identical model spec stays "rule"."""
    rule_name = "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒"
    llm_name = "呼吸道合胞病毒、肺炎支原体核酸检测试剂盒 (荧光PCR法)"
    model_spec = "24人份/盒"

    def fake_completion(messages, temperature=0.0):
        llm_answer = {"fields": {"产品名称": llm_name, "型号规格": model_spec}}
        return json.dumps(llm_answer, ensure_ascii=False)

    result = review_condition_fields(
        text="产品名称:呼吸道合胞病毒、肺炎支原体核酸检测试剂盒\n(荧光PCR法)\n型号规格:24人份/盒",
        rule_fields={"产品名称": rule_name, "型号规格": model_spec},
        file_context="申请表.txt",
        completion_func=fake_completion,
    )

    sources = result["selected_sources"]
    assert result["selected_fields"]["产品名称"] == llm_name
    assert sources["产品名称"] == "llm"
    assert sources["型号规格"] == "rule"
|
||||||
|
|
||||||
|
|
||||||
|
def test_review_condition_fields_falls_back_when_llm_returns_chapter_title():
    """A chapter heading returned by the LLM is rejected and the rule value is kept."""
    rule_name = "甲胎蛋白检测试剂盒"

    def fake_completion(messages, temperature=0.0):
        llm_answer = {"fields": {"产品名称": "第1章 监管信息"}}
        return json.dumps(llm_answer, ensure_ascii=False)

    result = review_condition_fields(
        text="产品名称:甲胎蛋白检测试剂盒",
        rule_fields={"产品名称": rule_name},
        file_context="申请表.txt",
        completion_func=fake_completion,
    )

    assert result["selected_fields"]["产品名称"] == rule_name
    assert result["selected_sources"]["产品名称"] == "rule"
|
||||||
Reference in New Issue
Block a user