172 lines
6.7 KiB
Python
172 lines
6.7 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from pathlib import Path
|
||
from typing import Callable
|
||
|
||
from review_agent.llm import generate_completion
|
||
from review_agent.regulatory_info_package.schemas import InstructionExtractResult
|
||
|
||
|
||
# Maps each extracted field key to ``(display label, regex)``.
#
# For the labelled fields the regex matches the label (or one of its listed
# synonyms), skips any run of separators, and captures the rest of the line
# in group 1.  The separator class accepts the ASCII colon, the full-width
# colon (written as the escape ``\uff1a`` so it survives text normalisation)
# and whitespace; without the full-width colon a line such as
# "产品名称:X" would capture the colon as part of the value.
#
# ``standard_no`` is different: it has no label and captures a standard
# number that starts with one of the listed prefixes.
FIELD_PATTERNS = {
    "product_name": ("产品名称", r"产品名称[:\uff1a\s]*([^\n\r]+)"),
    "applicant_name": (
        "申请人名称",
        r"(?:申请人名称|注册人/售后服务单位名称|注册人名称|售后服务单位名称|生产企业名称)[:\uff1a\s]*([^\n\r]+)",
    ),
    "manufacturer_name": ("生产企业名称", r"生产企业名称[:\uff1a\s]*([^\n\r]+)"),
    "applicant_address": (
        "申请人住所",
        r"(?:申请人住所|注册人住所|生产企业住所)[:\uff1a\s]*([^\n\r]+)",
    ),
    "applicant_contact": (
        "申请人联系方式",
        r"(?:联系方式|联系电话|电话)[:\uff1a\s]*([^\n\r]+)",
    ),
    "production_address": ("生产地址", r"生产地址[:\uff1a\s]*([^\n\r]+)"),
    "storage_condition": (
        "储存条件",
        r"(?:储存条件|贮存条件|保存条件)[:\uff1a\s]*([^\n\r]+)",
    ),
    "intended_use": ("预期用途", r"预期用途[:\uff1a\s]*([^\n\r]+)"),
    "package_specification": (
        "包装规格",
        r"(?:包装规格|规格)[:\uff1a\s]*([^\n\r]+)",
    ),
    "sample_type": ("样本类型", r"样本类型[:\uff1a\s]*([^\n\r]+)"),
    "applicable_instrument": ("适用仪器", r"适用仪器[:\uff1a\s]*([^\n\r]+)"),
    "standard_no": ("标准号", r"((?:GB|YY|WS|T/C[A-Z0-9]*)[ /T0-9.\-—]+)"),
}
|
||
|
||
|
||
def extract_fields_by_rules(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Extract fields from an instruction document with fixed rules only.

    For every entry of ``FIELD_PATTERNS`` two strategies are tried in order:

    1. the paragraph that directly follows a paragraph consisting solely of
       the field label (confidence 0.82);
    2. the field regex, searched case-insensitively over the front text,
       all paragraphs and all section bodies joined by newlines
       (confidence 0.75, evidence truncated to 240 characters).

    Afterwards the component table with the most rows and the body of the
    first section whose title contains "主要组成" are added when present.

    Returns:
        A dict keyed by field name; each value holds ``label``, ``value``,
        ``evidence``, ``confidence`` and ``source`` (always ``"rule"``).
        Fields for which nothing was found are omitted.
    """

    def _entry(label: str, value: str, evidence: str, confidence: float) -> dict:
        # Key order is kept stable because the result is serialised to JSON.
        return {
            "label": label,
            "value": value,
            "evidence": evidence,
            "confidence": confidence,
            "source": "rule",
        }

    corpus_parts = [instruction.front_text]
    corpus_parts.extend(instruction.paragraphs)
    corpus_parts.extend(instruction.sections.values())
    corpus = "\n".join(corpus_parts)

    extracted: dict[str, dict] = {}
    for field_key, (label, pattern) in FIELD_PATTERNS.items():
        # Strategy 1: a heading-style paragraph followed by its value.
        heading_value = _value_after_label_paragraph(instruction.paragraphs, label)
        if heading_value:
            extracted[field_key] = _entry(
                label, heading_value, f"【{label}】\n{heading_value}", 0.82
            )
            continue

        # Strategy 2: inline "label: value" anywhere in the combined text.
        hit = re.search(pattern, corpus, flags=re.IGNORECASE)
        if hit is None:
            continue
        inline_value = _clean_value(hit.group(1))
        if inline_value:
            extracted[field_key] = _entry(
                label, inline_value, hit.group(0)[:240], 0.75
            )

    best_table = _best_component_table(instruction.component_tables)
    if best_table:
        extracted["component_table"] = _entry(
            "主要组成成分",
            json.dumps(best_table, ensure_ascii=False),
            "说明书【主要组成成分】表格",
            0.86,
        )

    notes = _component_notes(instruction.sections)
    if notes:
        extracted["component_notes"] = _entry(
            "主要组成成分备注", notes, "说明书【主要组成成分】段落", 0.8
        )

    return extracted
|
||
|
||
|
||
def extract_fields_with_llm(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Ask the LLM to extract fields from the document's front text.

    The prompt is the fixed extraction request followed by the first 6000
    characters of ``instruction.front_text``; it is sent as a single user
    message with ``temperature=0.0``.

    Returns:
        The entries of the parsed JSON reply whose values are dicts; any
        other top-level entries are dropped.

    Raises:
        Whatever ``generate_completion`` or ``_parse_json_object`` raise.
    """
    task_description = (
        "请从体外诊断试剂产品说明书中抽取字段,输出 JSON 对象,字段包括 "
        "product_name、storage_condition、intended_use、package_specification、sample_type、applicable_instrument、standard_no。"
        "每个字段值为 {label,value,evidence,confidence}。\n\n"
    )
    excerpt = instruction.front_text[:6000]
    messages = [{"role": "user", "content": task_description + excerpt}]

    reply = generate_completion(messages, temperature=0.0)

    fields: dict[str, dict] = {}
    for name, entry in _parse_json_object(reply).items():
        if isinstance(entry, dict):
            fields[name] = entry
    return fields
|
||
|
||
|
||
def run_llm_extract_with_retry(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
    sleep_func: Callable[[float], None] = time.sleep,
    delays: tuple[float, ...] = (0, 1, 2),
) -> dict[str, dict]:
    """Run the LLM extraction, retrying after a failure.

    Args:
        instruction: The document passed unchanged to the extraction callable.
        llm_extract_func: Extraction callable; ``extract_fields_with_llm`` is
            used when this is ``None``.
        sleep_func: Called with each non-zero delay before the corresponding
            attempt (injectable so tests do not have to wait).
        delays: One entry per attempt: the pause before that attempt. The
            default ``(0, 1, 2)`` gives three attempts with pauses of 1 and 2
            before the second and third.

    Returns:
        The result of the first attempt that succeeds, or ``{}`` when
        ``delays`` is empty and no attempt is made.

    Raises:
        Exception: The exception from the last attempt when all attempts fail.
    """
    func = extract_fields_with_llm if llm_extract_func is None else llm_extract_func
    last_exc: Exception | None = None
    for delay in delays:
        if delay:
            sleep_func(delay)
        try:
            return func(instruction)
        except Exception as exc:  # best effort: any failure triggers the next attempt
            last_exc = exc
    # Identity check rather than truthiness: an exception object that defines
    # __bool__/__len__ could otherwise be swallowed silently.
    if last_exc is not None:
        raise last_exc
    return {}
|
||
|
||
|
||
def run_parallel_extract(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
) -> dict:
    """Run rule-based and LLM-based extraction concurrently.

    Both extractions are submitted to a two-worker thread pool. A failure of
    the LLM extraction (after its retries) is recorded rather than raised;
    a failure of the rule-based extraction propagates to the caller.

    Args:
        instruction: The document handed to both extractors.
        llm_extract_func: Optional replacement for the LLM extraction
            callable, forwarded to ``run_llm_extract_with_retry``.

    Returns:
        A dict with ``regex_results``, ``llm_results`` and ``llm_error``.
        ``llm_error`` is ``""`` on success and a non-empty description when
        the LLM extraction failed.
    """
    payload = {"regex_results": {}, "llm_results": {}, "llm_error": ""}
    with ThreadPoolExecutor(max_workers=2) as executor:
        rule_future = executor.submit(extract_fields_by_rules, instruction)
        llm_future = executor.submit(
            run_llm_extract_with_retry, instruction, llm_extract_func=llm_extract_func
        )
        payload["regex_results"] = rule_future.result()
        try:
            payload["llm_results"] = llm_future.result()
        except Exception as exc:
            # str(exc) is empty for exceptions raised without a message;
            # fall back to the class name so a failure is never reported
            # as an empty (i.e. "no error") string.
            payload["llm_error"] = str(exc) or type(exc).__name__
    return payload
|
||
|
||
|
||
def save_field_extract_result(path: str | Path, payload: dict) -> Path:
    """Write ``payload`` as indented UTF-8 JSON to ``path``.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    return destination
|
||
|
||
|
||
def _clean_value(value: str) -> str:
    """Normalise a raw captured field value.

    The value is stripped of surrounding whitespace. If all that remains is
    closing-bracket residue of a ``【label】`` heading (optionally followed by
    a colon), the empty string is returned. Otherwise only the text before
    the first sentence or clause terminator is kept.

    Both the ASCII and the full-width forms of the colon and semicolon are
    recognised; the full-width forms are written as escapes (``\\uff1a``,
    ``\\uff1b``) so they survive text normalisation of this file.
    """
    cleaned = value.strip()
    if cleaned in {"】", "】】", "】:", "】\uff1a"}:
        return ""
    # Terminators: ideographic full stop, ASCII semicolon, full-width semicolon.
    return re.split(r"[。;\uff1b]", cleaned)[0].strip()
|
||
|
||
|
||
def _value_after_label_paragraph(paragraphs: list[str], label: str) -> str:
    """Return the cleaned paragraph that follows a label-only paragraph.

    A paragraph counts as a label heading when, after stripping whitespace,
    it is exactly the label, or the label wrapped in ``【】`` or ``[]``.
    The first such heading that has a following paragraph decides the
    result; a heading in last position is ignored.

    Returns:
        The following paragraph passed through ``_clean_value``, or ``""``
        when no heading with a successor exists.
    """
    headings = (f"【{label}】", f"[{label}]", label)
    last_index = len(paragraphs) - 1
    for position, paragraph in enumerate(paragraphs):
        if paragraph.strip() not in headings:
            continue
        if position < last_index:
            return _clean_value(paragraphs[position + 1])
    return ""
|
||
|
||
|
||
def _parse_json_object(raw: str | None) -> dict:
    """Parse the JSON object embedded in an LLM reply.

    Surrounding whitespace, a Markdown code fence and a leading ``json``
    language tag are removed, then the text from the first ``{`` to the last
    ``}`` is decoded.

    Args:
        raw: The reply text; ``None`` and ``""`` are treated as empty.

    Returns:
        The decoded object, or ``{}`` when the text contains no
        ``{`` ... ``}`` span (including the case where the only ``}`` comes
        before the first ``{``).

    Raises:
        json.JSONDecodeError: If the delimited span is not valid JSON.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[4:].strip()
    start = text.find("{")
    end = text.rfind("}")
    # ``end < start`` also covers ``end == -1``; without it a reply such as
    # "} {" would hand an empty string to json.loads and raise.
    if start == -1 or end < start:
        return {}
    return json.loads(text[start : end + 1])
|
||
|
||
|
||
def _best_component_table(component_tables: list[dict]) -> dict:
    """Pick the component table with the most rows.

    A table whose ``"rows"`` entry is missing or falsy counts as having zero
    rows. When several tables tie, the earliest one wins.

    Returns:
        The chosen table, or ``{}`` when ``component_tables`` is empty.
    """

    def _row_count(table: dict) -> int:
        return len(table.get("rows") or ())

    return max(component_tables, key=_row_count) if component_tables else {}
|
||
|
||
|
||
def _component_notes(sections: dict[str, str]) -> str:
    """Return the stripped body of the first section about main components.

    Sections are scanned in dict order; the first one whose title contains
    "主要组成" is used. Returns ``""`` when no title matches.
    """
    matching_bodies = (
        body.strip() for title, body in sections.items() if "主要组成" in title
    )
    return next(matching_bodies, "")
|