206 lines
7.0 KiB
Python
206 lines
7.0 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import re
|
||
from collections.abc import Callable
|
||
from typing import Any
|
||
|
||
from django.conf import settings
|
||
|
||
from review_agent.llm import LLMConfigurationError, LLMRequestError, generate_completion
|
||
|
||
|
||
# Labels of the registration fields this module reviews. The order is the
# iteration order used when cleaning and selecting field values.
FIELD_LABELS = [
    "产品名称",
    "型号规格",
    "预期用途",
    "管理类别",
    "分类编码",
    "注册类型",
    "临床评价路径",
]

# Shape of a chat-completion callable that takes only the message list and
# returns the raw model text.
# NOTE(review): the public functions in this module annotate their hook as
# ``Callable[..., str]`` (they also pass ``temperature`` by keyword), so this
# alias is not referenced by the code visible here - confirm it is still used.
CompletionFunc = Callable[[list[dict[str, str]]], str]
|
||
|
||
|
||
def review_condition_fields(
    *,
    text: str,
    rule_fields: dict[str, str],
    file_context: str = "",
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Cross-check rule-extracted fields against an LLM extraction.

    Args:
        text: Attachment text the fields were extracted from.
        rule_fields: Field values produced by the rule-based extractor.
        file_context: Optional description of the source file, forwarded to
            the model unchanged.
        completion_func: Optional replacement for ``generate_completion``. It
            is called with the message list and ``temperature=0.0`` and must
            return the raw model text.

    Returns:
        A dict with the keys ``status`` (``"skipped"``, ``"success"`` or
        ``"failed"``), ``error_message``, ``rule_fields`` (cleaned),
        ``llm_fields`` (cleaned, empty unless the call succeeded),
        ``selected_fields`` and ``selected_sources``.
    """
    llm_fields: dict[str, str] = {}
    status = "skipped"
    error_message = ""

    if _should_call_llm(completion_func):
        complete = completion_func or generate_completion
        try:
            raw = complete(_condition_messages(text, rule_fields, file_context), temperature=0.0)
            payload = _parse_json_object(raw)
            # The model may answer {"fields": {...}} or return the field
            # mapping directly at the top level.
            candidate = payload.get("fields") or payload
            if not isinstance(candidate, dict):
                # A list/string under "fields" used to escape as an uncaught
                # AttributeError; report it as a failed review instead.
                raise ValueError("LLM复核结果的fields不是JSON对象。")
            llm_fields = _clean_fields(candidate)
            status = "success"
        except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
            status = "failed"
            error_message = str(exc)

    # With empty llm_fields (skipped or failed) the selection falls back to
    # the rule values alone.
    selected_fields, selected_sources = _select_fields(rule_fields, llm_fields)
    return {
        "status": status,
        "error_message": error_message,
        "rule_fields": _clean_fields(rule_fields),
        "llm_fields": llm_fields,
        "selected_fields": selected_fields,
        "selected_sources": selected_sources,
    }
|
||
|
||
|
||
def review_workflow_payload(
    *,
    stage: str,
    payload: dict[str, Any],
    completion_func: Callable[..., str] | None = None,
) -> dict[str, Any]:
    """Ask the LLM to review the rule results of one workflow stage.

    Args:
        stage: Identifier of the workflow node being reviewed.
        payload: Rule results for that node; serialised to JSON for the model.
        completion_func: Optional replacement for ``generate_completion``,
            called with the message list and ``temperature=0.0``.

    Returns:
        A dict with the keys ``status`` (``"skipped"``, ``"success"`` or
        ``"failed"``), ``stage``, ``result`` (the parsed JSON object, or an
        empty dict) and ``error_message``.
    """
    outcome: dict[str, Any] = {
        "status": "skipped",
        "stage": stage,
        "result": {},
        "error_message": "",
    }
    if not _should_call_llm(completion_func):
        return outcome

    complete = completion_func or generate_completion
    try:
        raw = complete(_workflow_messages(stage, payload), temperature=0.0)
        parsed = _parse_json_object(raw)
    except (LLMConfigurationError, LLMRequestError, json.JSONDecodeError, TypeError, ValueError) as exc:
        outcome["status"] = "failed"
        outcome["error_message"] = str(exc)
        return outcome

    outcome["status"] = "success"
    outcome["result"] = parsed
    return outcome
|
||
|
||
|
||
def _condition_messages(text: str, rule_fields: dict[str, str], file_context: str) -> list[dict[str, str]]:
    """Build the chat messages for the field-review request.

    The system message carries the extraction instructions; the user message
    is a JSON document holding the file context, the rule-extracted fields,
    the first 4000 characters of ``text`` and the list of allowed labels.
    """
    system_prompt = (
        "你是NMPA注册资料字段复核助手。请从附件文本中提取最合理的字段值,"
        "只返回JSON,格式为 {\"fields\": {\"产品名称\": \"...\"}}。"
        "产品名称应包含完整名称、检测对象和方法学括号;不要把章节标题当产品名称。"
    )
    request_body = {
        "file_context": file_context,
        "rule_fields": rule_fields,
        # Only the head of the document is sent, which bounds prompt size.
        "text": text[:4000],
        "allowed_fields": FIELD_LABELS,
    }
    user_prompt = json.dumps(request_body, ensure_ascii=False)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
||
|
||
|
||
def _workflow_messages(stage: str, payload: dict[str, Any]) -> list[dict[str, str]]:
|
||
return [
|
||
{
|
||
"role": "system",
|
||
"content": (
|
||
"你是NMPA法规核查复核助手。请复核当前流程节点的规则结果,"
|
||
"指出可能误判、漏判和更合理的建议。只返回JSON。"
|
||
),
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": json.dumps({"stage": stage, "payload": payload}, ensure_ascii=False)[:6000],
|
||
},
|
||
]
|
||
|
||
|
||
def _parse_json_object(raw: str) -> dict[str, Any]:
|
||
value = (raw or "").strip()
|
||
if value.startswith("```"):
|
||
value = re.sub(r"^```(?:json)?\s*", "", value)
|
||
value = re.sub(r"\s*```$", "", value)
|
||
start = value.find("{")
|
||
end = value.rfind("}")
|
||
if start >= 0 and end >= start:
|
||
value = value[start : end + 1]
|
||
parsed = json.loads(value)
|
||
if not isinstance(parsed, dict):
|
||
raise ValueError("LLM复核结果不是JSON对象。")
|
||
return parsed
|
||
|
||
|
||
def _should_call_llm(completion_func: Callable[..., str] | None) -> bool:
|
||
if completion_func is not None:
|
||
return True
|
||
if os.environ.get("PYTEST_CURRENT_TEST") and not getattr(settings, "REGULATORY_LLM_REVIEW_ALLOW_TEST_CALLS", False):
|
||
return False
|
||
return bool(settings.LLM_API_KEY and settings.LLM_MODEL)
|
||
|
||
|
||
def _clean_fields(fields: dict[str, Any]) -> dict[str, str]:
    """Keep only known labels with non-empty string values, normalised.

    Args:
        fields: Candidate mapping of label to value. Anything that is not a
            dict (``None``, a list, a string) yields an empty result instead
            of raising ``AttributeError``.

    Returns:
        A new dict, in ``FIELD_LABELS`` order, holding the labels whose value
        is a string that is non-empty after trimming and collapsing runs of
        whitespace to single spaces.
    """
    if not isinstance(fields, dict):
        return {}

    clean: dict[str, str] = {}
    for label in FIELD_LABELS:
        value = fields.get(label)
        if not isinstance(value, str):
            continue
        # NOTE(review): as written both replace() calls map a character to
        # itself and change nothing. They presumably converted between
        # full-width and ASCII parentheses before the file was re-encoded -
        # confirm the intended direction against the upstream source.
        normalized = " ".join(value.strip().split()).replace("(", "(").replace(")", ")")
        if normalized:
            clean[label] = normalized
    return clean
|
||
|
||
|
||
def _select_fields(rule_fields: dict[str, str], llm_fields: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]:
    """Pick, per label, the preferred value between rule and LLM output.

    ``rule_fields`` is cleaned first; ``llm_fields`` is used as given.

    Returns:
        A pair ``(selected, sources)``: the chosen value per label and the
        origin reported for it by ``_select_field``. Labels for which no
        value was chosen appear in neither dict.
    """
    cleaned_rules = _clean_fields(rule_fields)
    chosen_values: dict[str, str] = {}
    chosen_sources: dict[str, str] = {}
    for label in FIELD_LABELS:
        picked, origin = _select_field(
            label,
            cleaned_rules.get(label, ""),
            llm_fields.get(label, ""),
        )
        if not picked:
            continue
        chosen_values[label] = picked
        chosen_sources[label] = origin
    return chosen_values, chosen_sources
|
||
|
||
|
||
def _select_field(label: str, rule_value: str, llm_value: str) -> tuple[str, str]:
    """Choose between the rule value and the LLM value for a single label.

    Returns:
        ``(value, source)`` where ``source`` is ``"rule"``, ``"llm"`` or
        ``""`` when the returned value is empty.
    """
    # An unusable LLM value can never win; fall back to the rule value.
    if _invalid_field_value(llm_value):
        fallback_source = "rule" if rule_value else ""
        return rule_value, fallback_source

    if not rule_value:
        only_source = "llm" if llm_value else ""
        return llm_value, only_source

    if not llm_value:
        return rule_value, "rule"

    # Both sides have a value: the LLM wins for a better product name, or
    # when its value contains the rule value and is more than 1.35x as long.
    prefer_llm = (label == "产品名称" and _better_product_name(llm_value, rule_value)) or (
        len(llm_value) > len(rule_value) * 1.35 and rule_value in llm_value
    )
    if prefer_llm:
        return llm_value, "llm"
    return rule_value, "rule"
|
||
|
||
|
||
def _better_product_name(candidate: str, current: str) -> bool:
|
||
if current and current in candidate and len(candidate) > len(current):
|
||
return True
|
||
product_keywords = ["试剂盒", "检测试剂", "荧光PCR法", "PCR法", "核酸检测"]
|
||
return len(candidate) > len(current) and any(keyword in candidate for keyword in product_keywords)
|
||
|
||
|
||
def _invalid_field_value(value: str) -> bool:
|
||
if not value:
|
||
return True
|
||
if "<EFBFBD>" in value:
|
||
return True
|
||
return any(keyword in value for keyword in ["第1章", "第2章", "第3章", "监管信息", "综述资料", "章节目录"])
|