Files

172 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable
from review_agent.llm import generate_completion
from review_agent.regulatory_info_package.schemas import InstructionExtractResult
# Characters allowed between a field label and its value: whitespace, ASCII
# or fullwidth colons, and the closing bracket of a "【label】" / "[label]"
# heading. Without the fullwidth colon and "】", text such as "产品名称：X" or
# "【产品名称】X" would capture "：X" / "】X" instead of just "X".
_LABEL_SEPARATOR = r"[:：\s】\]]*"

# field key -> (Chinese display label, regex). Group 1 of each regex is the
# raw value; extract_fields_by_rules passes it through _clean_value.
FIELD_PATTERNS = {
    "product_name": ("产品名称", rf"产品名称{_LABEL_SEPARATOR}([^\n\r]+)"),
    "applicant_name": (
        "申请人名称",
        rf"(?:申请人名称|注册人/售后服务单位名称|注册人名称|售后服务单位名称|生产企业名称){_LABEL_SEPARATOR}([^\n\r]+)",
    ),
    "manufacturer_name": ("生产企业名称", rf"生产企业名称{_LABEL_SEPARATOR}([^\n\r]+)"),
    "applicant_address": (
        "申请人住所",
        rf"(?:申请人住所|注册人住所|生产企业住所){_LABEL_SEPARATOR}([^\n\r]+)",
    ),
    "applicant_contact": (
        "申请人联系方式",
        rf"(?:联系方式|联系电话|电话){_LABEL_SEPARATOR}([^\n\r]+)",
    ),
    "production_address": ("生产地址", rf"生产地址{_LABEL_SEPARATOR}([^\n\r]+)"),
    "storage_condition": (
        "储存条件",
        rf"(?:储存条件|贮存条件|保存条件){_LABEL_SEPARATOR}([^\n\r]+)",
    ),
    "intended_use": ("预期用途", rf"预期用途{_LABEL_SEPARATOR}([^\n\r]+)"),
    "package_specification": (
        "包装规格",
        rf"(?:包装规格|规格){_LABEL_SEPARATOR}([^\n\r]+)",
    ),
    "sample_type": ("样本类型", rf"样本类型{_LABEL_SEPARATOR}([^\n\r]+)"),
    "applicable_instrument": ("适用仪器", rf"适用仪器{_LABEL_SEPARATOR}([^\n\r]+)"),
    # Standard numbers such as "YY/T 1234-2020"; matched anywhere, no label needed.
    "standard_no": ("标准号", r"((?:GB|YY|WS|T/C[A-Z0-9]*)[ /T0-9.\-—]+)"),
}
def extract_fields_by_rules(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Extract instruction fields with deterministic rules (no LLM involved).

    For every entry in ``FIELD_PATTERNS`` the value is taken, in order of
    preference, from:

    1. the paragraph that follows a standalone heading paragraph for the
       label (confidence 0.82), or
    2. the first regex match over the front text, all paragraphs and all
       section bodies joined by newlines (confidence 0.75). The first 240
       characters of the match are kept as evidence.

    A field is omitted when neither source yields a non-empty value. The
    largest component table (serialized as JSON) and the "主要组成" section
    text are added under ``component_table`` / ``component_notes`` when
    present.

    Every hit is a dict with keys ``label``, ``value``, ``evidence``,
    ``confidence`` and ``source`` (always ``"rule"``).
    """
    corpus = "\n".join(
        [instruction.front_text, *instruction.paragraphs, *instruction.sections.values()]
    )

    def _hit(label: str, value: str, evidence: str, confidence: float) -> dict:
        # Key order matters for the JSON written by save_field_extract_result.
        return {
            "label": label,
            "value": value,
            "evidence": evidence,
            "confidence": confidence,
            "source": "rule",
        }

    hits: dict[str, dict] = {}
    for field_key, (label, pattern) in FIELD_PATTERNS.items():
        heading_value = _value_after_label_paragraph(instruction.paragraphs, label)
        if heading_value:
            hits[field_key] = _hit(label, heading_value, f"{label}\n{heading_value}", 0.82)
            continue
        found = re.search(pattern, corpus, flags=re.IGNORECASE)
        if found is None:
            continue
        cleaned = _clean_value(found.group(1))
        if cleaned:
            hits[field_key] = _hit(label, cleaned, found.group(0)[:240], 0.75)

    table = _best_component_table(instruction.component_tables)
    if table:
        hits["component_table"] = _hit(
            "主要组成成分",
            json.dumps(table, ensure_ascii=False),
            "说明书【主要组成成分】表格",
            0.86,
        )
    notes = _component_notes(instruction.sections)
    if notes:
        hits["component_notes"] = _hit(
            "主要组成成分备注",
            notes,
            "说明书【主要组成成分】段落",
            0.8,
        )
    return hits
def extract_fields_with_llm(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Ask the LLM to extract instruction fields from the document front text.

    Only the first 6000 characters of ``instruction.front_text`` are sent,
    in a single user message, with ``temperature=0.0``. The reply is parsed
    as a JSON object keyed by field name. Entries whose value is not a JSON
    object are dropped.
    """
    task = (
        "请从体外诊断试剂产品说明书中抽取字段,输出 JSON 对象,字段包括 "
        "product_name、storage_condition、intended_use、package_specification、sample_type、applicable_instrument、standard_no。"
        "每个字段值为 {label,value,evidence,confidence}。\n\n"
    )
    messages = [{"role": "user", "content": task + instruction.front_text[:6000]}]
    reply = generate_completion(messages, temperature=0.0)
    parsed = _parse_json_object(reply)
    fields: dict[str, dict] = {}
    for name, entry in parsed.items():
        if isinstance(entry, dict):
            fields[name] = entry
    return fields
def run_llm_extract_with_retry(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
    sleep_func: Callable[[float], None] = time.sleep,
    delays: tuple[float, ...] = (0, 1, 2),
) -> dict[str, dict]:
    """Run an LLM field extraction, retrying on any exception.

    Args:
        instruction: Parsed instruction document handed to the extractor.
        llm_extract_func: Extractor to call; defaults to
            ``extract_fields_with_llm``.
        sleep_func: Called with each non-zero delay before the matching
            attempt; injectable so callers and tests can avoid real sleeps.
        delays: One entry per attempt, giving the seconds to wait before that
            attempt (``0`` means no wait). The default makes three attempts
            with 0 s, 1 s and 2 s waits, the previous hard-coded schedule.

    Returns:
        The first successful extractor result, or ``{}`` when ``delays`` is
        empty (no attempt is made).

    Raises:
        Exception: the exception raised by the final attempt when every
            attempt fails.
    """
    func = llm_extract_func or extract_fields_with_llm
    last_exc: Exception | None = None
    for delay in delays:
        if delay:
            sleep_func(delay)
        try:
            return func(instruction)
        except Exception as exc:  # any failure (LLM call, JSON parsing, ...) is retried
            last_exc = exc
    if last_exc is not None:
        raise last_exc
    return {}
def run_parallel_extract(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
) -> dict:
    """Run rule-based and LLM extraction concurrently on two worker threads.

    Returns a dict with ``regex_results``, ``llm_results`` and ``llm_error``.
    An exception from rule extraction propagates to the caller. An exception
    from the LLM path (after its retries) is not raised: its message is
    stored in ``llm_error`` and ``llm_results`` is left empty.
    """
    llm_results: dict = {}
    llm_error = ""
    with ThreadPoolExecutor(max_workers=2) as pool:
        rules_job = pool.submit(extract_fields_by_rules, instruction)
        llm_job = pool.submit(
            run_llm_extract_with_retry, instruction, llm_extract_func=llm_extract_func
        )
        regex_results = rules_job.result()
        try:
            outcome = llm_job.result()
        except Exception as exc:
            llm_error = str(exc)
        else:
            llm_results = outcome
    return {
        "regex_results": regex_results,
        "llm_results": llm_results,
        "llm_error": llm_error,
    }
def save_field_extract_result(path: str | Path, payload: dict) -> Path:
    """Write *payload* to *path* as UTF-8, two-space-indented JSON.

    Missing parent directories are created first. Non-ASCII text is written
    verbatim (``ensure_ascii=False``). Returns the target as a ``Path``.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    return destination
def _clean_value(value: str) -> str:
cleaned = value.strip()
if cleaned in {"", "】】", "】:"}:
return ""
return re.split(r"[。;;]", cleaned)[0].strip()
def _value_after_label_paragraph(paragraphs: list[str], label: str) -> str:
    """Return the cleaned paragraph that follows a standalone heading for *label*.

    A heading is a paragraph whose stripped text is exactly ``【label】``,
    ``[label]`` or the bare ``label``. The previous set repeated the bare
    label (``f"{label}"``) and so never recognised the fullwidth-bracket
    heading. Only the first heading that has a following paragraph is used.

    Returns:
        ``_clean_value`` of the next paragraph, or ``""`` when no such
        heading exists.
    """
    headings = {f"【{label}】", f"[{label}]", label}
    # The last paragraph has no successor, so it can never yield a value.
    for index, text in enumerate(paragraphs[:-1]):
        if text.strip() in headings:
            return _clean_value(paragraphs[index + 1])
    return ""
def _parse_json_object(raw: str) -> dict:
text = (raw or "").strip()
if text.startswith("```"):
text = text.strip("`").strip()
if text.lower().startswith("json"):
text = text[4:].strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1:
return {}
return json.loads(text[start : end + 1])
def _best_component_table(component_tables: list[dict]) -> dict:
if not component_tables:
return {}
return max(component_tables, key=lambda table: len(table.get("rows") or []))
def _component_notes(sections: dict[str, str]) -> str:
for key, value in sections.items():
if "主要组成" in key:
return value.strip()
return ""