102 lines
3.4 KiB
Python
102 lines
3.4 KiB
Python
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import re
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from .rag_index import extract_text_from_path
|
||
|
||
|
||
@dataclass(frozen=True)
class ExtractedText:
    """Immutable record describing one text-extraction attempt on a file.

    Only ``path``, ``text`` and ``status`` are required; every other field
    has a default so callers can build partial results (for example for
    unsupported or failed files).

    ``frozen=True`` prevents attribute reassignment, but the list/dict held
    in ``section_candidates`` / ``field_candidates`` can still be mutated
    in place.
    """

    # File the extraction was attempted on.
    path: Path
    # Extracted text; empty when nothing was extracted.
    text: str
    # Outcome marker. extract_text() in this module produces "success",
    # "failed" or "unsupported".
    status: str
    # Digest of ``text``; extract_text() stores a SHA-256 hex digest here,
    # or "" when the text is empty.
    content_hash: str = ""
    # Stringified exception for failed extractions, otherwise "".
    error_message: str = ""
    # Leading portion of ``text``.
    front_text: str = ""
    # Lines that look like section headings; None when never computed.
    section_candidates: list[str] | None = None
    # Mapping of field label -> captured value; None when never computed.
    field_candidates: dict[str, str] | None = None
|
||
|
||
|
||
# Lower-case file suffixes (including the leading dot) that are eligible for
# text extraction.
SUPPORTED_EXTENSIONS = {
    ".txt",
    ".md",
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
    ".doc",
}

# Field labels looked for at the start of a line, in lookup order.
FIELD_LABELS = [
    "产品名称",  # product name
    "型号规格",  # model / specification
    "预期用途",  # intended use
    "管理类别",  # management category
    "分类编码",  # classification code
    "注册类型",  # registration type
    "临床评价路径",  # clinical evaluation pathway
]
|
||
|
||
|
||
def extract_text(path: str | Path) -> ExtractedText:
    """Extract text from ``path`` and wrap the outcome in an ExtractedText.

    Args:
        path: Location of the file, as a string or ``Path``.

    Returns:
        An ``ExtractedText`` whose ``status`` is one of:

        * ``"unsupported"`` - the lower-cased suffix is not in
          ``SUPPORTED_EXTENSIONS``; no extraction is attempted.
        * ``"failed"`` - ``extract_text_from_path`` raised; the exception
          text is stored in ``error_message`` and the candidate fields are
          empty containers.
        * ``"success"`` - extraction returned; the result carries the text,
          its SHA-256 hex digest (``""`` for empty text), the front slice,
          and the section/field candidates.

    This function converts extraction errors into a ``"failed"`` result
    instead of propagating them.
    """
    source = Path(path)
    suffix = source.suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        return ExtractedText(path=source, text="", status="unsupported")

    try:
        body = extract_text_from_path(source)
    except Exception as exc:
        # Deliberately broad: any extractor error becomes a "failed" record.
        return ExtractedText(
            path=source,
            text="",
            status="failed",
            error_message=str(exc),
            section_candidates=[],
            field_candidates={},
        )

    # Empty text keeps an empty hash rather than the digest of b"".
    digest = ""
    if body:
        digest = hashlib.sha256(body.encode("utf-8")).hexdigest()

    return ExtractedText(
        path=source,
        text=body,
        status="success",
        content_hash=digest,
        front_text=_front_text(body),
        section_candidates=_section_candidates(body),
        field_candidates=_field_candidates(body),
    )
|
||
|
||
|
||
def _front_text(text: str, limit: int = 1200) -> str:
|
||
return text[:limit]
|
||
|
||
|
||
def _section_candidates(text: str) -> list[str]:
|
||
candidates = []
|
||
for line in text.splitlines():
|
||
normalized = line.strip()
|
||
if not normalized:
|
||
continue
|
||
if re.match(r"^([一二三四五六七八九十]+[、..]|[0-9]+(\.[0-9]+)*[、..\s])", normalized):
|
||
candidates.append(normalized[:120])
|
||
elif any(keyword in normalized for keyword in ["章节目录", "监管信息", "综述资料", "非临床资料", "临床评价资料", "质量管理体系"]):
|
||
candidates.append(normalized[:120])
|
||
return candidates[:80]
|
||
|
||
|
||
def _field_candidates(text: str) -> dict[str, str]:
    """Extract ``label: value`` pairs for the labels in ``FIELD_LABELS``.

    A field starts on a stripped line of the form ``<label>:<value>`` where
    the separator is an ASCII or full-width colon. The value continues over
    the following lines until a blank line, a line that starts another
    field (``_starts_field_line``), or a line that looks like a section
    heading (``_looks_like_section_heading``). Only the first non-empty
    value found for each label is kept.

    Args:
        text: Document text; split with ``str.splitlines``.

    Returns:
        Mapping of label to its value with all whitespace runs collapsed to
        single spaces. Labels with no (or only empty) values are absent.
    """
    # Compile once per call instead of once per (line, label) pair.
    # \uff1a is the full-width colon; it is written as an escape so the
    # character class cannot silently collapse into a duplicated ASCII ":"
    # (which would miss the colon normally used in Chinese documents).
    label_patterns = [
        (label, re.compile(rf"^{re.escape(label)}[:\uff1a]\s*(.*)$"))
        for label in FIELD_LABELS
    ]

    fields: dict[str, str] = {}
    lines = text.splitlines()
    for index, line in enumerate(lines):
        normalized = line.strip()
        if not normalized:
            continue
        for label, pattern in label_patterns:
            # First captured value wins; skip the regex for known labels.
            if label in fields:
                continue
            match = pattern.match(normalized)
            if match is None:
                continue

            value_parts = [match.group(1).strip()]
            for next_line in lines[index + 1:]:
                continuation = next_line.strip()
                if (
                    not continuation
                    or _starts_field_line(continuation)
                    or _looks_like_section_heading(continuation)
                ):
                    break
                value_parts.append(continuation)

            # Join the pieces and collapse every whitespace run to one space.
            value = " ".join(" ".join(value_parts).split())
            if value:
                fields[label] = value
    return fields
|
||
|
||
|
||
def _starts_field_line(line: str) -> bool:
    """Return True when ``line`` appears to begin a ``label: value`` field.

    Two checks are applied in order:

    1. the line starts with one of ``FIELD_LABELS`` immediately followed by
       a colon;
    2. the line starts with any run of 2-24 characters that are neither
       whitespace nor colons, immediately followed by a colon.

    Both the ASCII colon and the full-width colon are accepted.

    Args:
        line: A single line; callers in this module pass it already
            stripped.
    """
    # \uff1a is the full-width colon, written as an escape so the character
    # classes cannot silently collapse into a duplicated ASCII ":".
    if any(re.match(rf"^{re.escape(label)}[:\uff1a]", line) for label in FIELD_LABELS):
        return True
    # Generic "short label + colon" shape for fields outside FIELD_LABELS.
    return bool(re.match(r"^[^\s:\uff1a]{2,24}[:\uff1a]", line))
|
||
|
||
|
||
def _looks_like_section_heading(line: str) -> bool:
|
||
return bool(re.match(r"^([一二三四五六七八九十]+[、..]|[0-9]+(\.[0-9]+)*[、..\s])", line))
|