"""Field extraction for in-vitro diagnostic reagent instruction documents.

Two extraction strategies are provided and can be run side by side:

* ``extract_fields_by_rules`` -- label/regex based extraction.
* ``extract_fields_with_llm`` -- prompt based extraction through
  ``generate_completion``.

``run_parallel_extract`` runs both and returns their results in one payload.
"""

from __future__ import annotations

import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable

from review_agent.llm import generate_completion
from review_agent.regulatory_info_package.schemas import InstructionExtractResult

# Maps a result key to ``(label, regex)``.  Group 1 of each regex is the raw
# field value.  The separator class accepts ASCII ":" as well as the fullwidth
# colon U+FF1A (written as ``\uff1a``) that Chinese documents normally use, so
# the colon is never captured as part of the value.
FIELD_PATTERNS = {
    "product_name": ("产品名称", r"产品名称[:\uff1a\s]*([^\n\r]+)"),
    "applicant_name": (
        "申请人名称",
        r"(?:申请人名称|注册人/售后服务单位名称|注册人名称|售后服务单位名称|生产企业名称)[:\uff1a\s]*([^\n\r]+)",
    ),
    "manufacturer_name": ("生产企业名称", r"生产企业名称[:\uff1a\s]*([^\n\r]+)"),
    "applicant_address": (
        "申请人住所",
        r"(?:申请人住所|注册人住所|生产企业住所)[:\uff1a\s]*([^\n\r]+)",
    ),
    "applicant_contact": (
        "申请人联系方式",
        r"(?:联系方式|联系电话|电话)[:\uff1a\s]*([^\n\r]+)",
    ),
    "production_address": ("生产地址", r"生产地址[:\uff1a\s]*([^\n\r]+)"),
    "storage_condition": (
        "储存条件",
        r"(?:储存条件|贮存条件|保存条件)[:\uff1a\s]*([^\n\r]+)",
    ),
    "intended_use": ("预期用途", r"预期用途[:\uff1a\s]*([^\n\r]+)"),
    "package_specification": ("包装规格", r"(?:包装规格|规格)[:\uff1a\s]*([^\n\r]+)"),
    "sample_type": ("样本类型", r"样本类型[:\uff1a\s]*([^\n\r]+)"),
    "applicable_instrument": ("适用仪器", r"适用仪器[:\uff1a\s]*([^\n\r]+)"),
    "standard_no": ("标准号", r"((?:GB|YY|WS|T/C[A-Z0-9]*)[ /T0-9.\-—]+)"),
}

# Captured "values" that are only the tail of a bracketed heading (the label
# matched inside a heading and nothing but the closing bracket followed).
_HEADING_RESIDUE = frozenset({"】", "】】", "】:", "】\uff1a"})

# A value ends at the first ideographic full stop or semicolon
# (ASCII ";" or fullwidth U+FF1B).
_VALUE_TERMINATOR_RE = re.compile(r"[。;\uff1b]")


def extract_fields_by_rules(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Extract fields using label paragraphs and the ``FIELD_PATTERNS`` regexes.

    For every entry in ``FIELD_PATTERNS`` the paragraph following a standalone
    label paragraph is preferred (confidence 0.82); otherwise the regex is
    searched, case-insensitively, over the front text, paragraphs and section
    bodies joined by newlines (confidence 0.75).  Fields with no non-empty
    value are omitted.  The component table and component notes are added
    under ``component_table`` / ``component_notes`` when present.

    Returns:
        Mapping of field key to a dict with ``label``, ``value``, ``evidence``,
        ``confidence`` and ``source`` (always ``"rule"``).
    """
    text = "\n".join(
        [instruction.front_text, *instruction.paragraphs, *instruction.sections.values()]
    )
    results: dict[str, dict] = {}

    for key, (label, pattern) in FIELD_PATTERNS.items():
        section_value = _value_after_label_paragraph(instruction.paragraphs, label)
        if section_value:
            results[key] = {
                "label": label,
                "value": section_value,
                "evidence": f"【{label}】\n{section_value}",
                "confidence": 0.82,
                "source": "rule",
            }
            continue

        match = re.search(pattern, text, flags=re.IGNORECASE)
        if not match:
            continue
        value = _clean_value(match.group(1))
        if value:
            results[key] = {
                "label": label,
                "value": value,
                # Evidence is the matched text, capped at 240 characters.
                "evidence": match.group(0)[:240],
                "confidence": 0.75,
                "source": "rule",
            }

    component_table = _best_component_table(instruction.component_tables)
    if component_table:
        results["component_table"] = {
            "label": "主要组成成分",
            "value": json.dumps(component_table, ensure_ascii=False),
            "evidence": "说明书【主要组成成分】表格",
            "confidence": 0.86,
            "source": "rule",
        }

    component_notes = _component_notes(instruction.sections)
    if component_notes:
        results["component_notes"] = {
            "label": "主要组成成分备注",
            "value": component_notes,
            "evidence": "说明书【主要组成成分】段落",
            "confidence": 0.8,
            "source": "rule",
        }

    return results


def extract_fields_with_llm(instruction: InstructionExtractResult) -> dict[str, dict]:
    """Ask the LLM to extract fields from the first 6000 characters of front text.

    Returns:
        The top-level entries of the model's JSON object whose values are
        dicts; any other entries are dropped.

    Raises:
        json.JSONDecodeError: If the brace-delimited part of the reply is not
            valid JSON.  Exceptions from ``generate_completion`` propagate.
    """
    prompt = (
        "请从体外诊断试剂产品说明书中抽取字段,输出 JSON 对象,字段包括 "
        "product_name、storage_condition、intended_use、package_specification、sample_type、applicable_instrument、standard_no。"
        "每个字段值为 {label,value,evidence,confidence}。\n\n"
        + instruction.front_text[:6000]
    )
    raw = generate_completion([{"role": "user", "content": prompt}], temperature=0.0)
    payload = _parse_json_object(raw)
    return {key: value for key, value in payload.items() if isinstance(value, dict)}


def run_llm_extract_with_retry(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
    sleep_func: Callable[[float], None] = time.sleep,
) -> dict[str, dict]:
    """Run LLM extraction with up to three attempts.

    The first attempt runs immediately; the second and third are preceded by
    ``sleep_func(1)`` and ``sleep_func(2)`` respectively.

    Args:
        instruction: The parsed instruction document.
        llm_extract_func: Extraction callable; defaults to
            ``extract_fields_with_llm``.
        sleep_func: Callable used to wait between attempts (injectable so
            callers can avoid real sleeping).

    Returns:
        The result of the first attempt that does not raise.

    Raises:
        Exception: The exception of the last attempt when all attempts fail.
    """
    func = llm_extract_func or extract_fields_with_llm
    last_exc: Exception | None = None
    for delay in [0, 1, 2]:
        if delay:
            sleep_func(delay)
        try:
            return func(instruction)
        except Exception as exc:
            last_exc = exc
    if last_exc:
        raise last_exc
    return {}


def run_parallel_extract(
    instruction: InstructionExtractResult,
    *,
    llm_extract_func: Callable[[InstructionExtractResult], dict[str, dict]] | None = None,
) -> dict:
    """Run rule-based and LLM-based extraction concurrently in two threads.

    Returns:
        A dict with ``regex_results``, ``llm_results`` and ``llm_error``.
        A failure of the LLM extraction (after retries) is reported as text in
        ``llm_error`` and leaves ``llm_results`` empty; a failure of the
        rule-based extraction propagates to the caller.
    """
    payload = {"regex_results": {}, "llm_results": {}, "llm_error": ""}
    with ThreadPoolExecutor(max_workers=2) as executor:
        rule_future = executor.submit(extract_fields_by_rules, instruction)
        llm_future = executor.submit(
            run_llm_extract_with_retry, instruction, llm_extract_func=llm_extract_func
        )
        payload["regex_results"] = rule_future.result()
        try:
            payload["llm_results"] = llm_future.result()
        except Exception as exc:
            payload["llm_error"] = str(exc)
    return payload


def save_field_extract_result(path: str | Path, payload: dict) -> Path:
    """Write ``payload`` as indented UTF-8 JSON to ``path``, creating parent dirs.

    Returns:
        The target path as a ``Path``.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    return target


def _clean_value(value: str) -> str:
    """Trim a captured value and cut it at the first sentence terminator.

    Returns ``""`` when the capture is only the residue of a bracketed heading.
    """
    cleaned = value.strip()
    if cleaned in _HEADING_RESIDUE:
        return ""
    return _VALUE_TERMINATOR_RE.split(cleaned, maxsplit=1)[0].strip()


def _value_after_label_paragraph(paragraphs: list[str], label: str) -> str:
    """Return the cleaned paragraph that follows a standalone label paragraph.

    A label paragraph is one whose stripped text is exactly ``【label】``,
    ``[label]`` or ``label``.  Only the first such paragraph that has a
    successor is considered; ``""`` is returned when there is none.
    """
    bracketed = {f"【{label}】", f"[{label}]", label}
    for index, text in enumerate(paragraphs):
        stripped = text.strip()
        if stripped in bracketed and index + 1 < len(paragraphs):
            return _clean_value(paragraphs[index + 1])
    return ""


def _parse_json_object(raw: str) -> dict:
    """Parse the outermost ``{...}`` span of an LLM reply as JSON.

    Code fences and a leading ``json`` tag are tolerated.  Returns ``{}`` when
    the reply contains no brace-delimited span (including the case where the
    last ``}`` precedes the first ``{``).

    Raises:
        json.JSONDecodeError: If the span is present but not valid JSON.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[4:].strip()
    start = text.find("{")
    end = text.rfind("}")
    # ``end < start`` also covers ``end == -1`` once ``start`` is known to be >= 0.
    if start == -1 or end < start:
        return {}
    return json.loads(text[start : end + 1])


def _best_component_table(component_tables: list[dict]) -> dict:
    """Return the table with the most ``rows`` entries, or ``{}`` if there are none."""
    if not component_tables:
        return {}
    return max(component_tables, key=lambda table: len(table.get("rows") or []))


def _component_notes(sections: dict[str, str]) -> str:
    """Return the stripped body of the first section whose title contains "主要组成"."""
    for key, value in sections.items():
        if "主要组成" in key:
            return value.strip()
    return ""