77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import re
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from .rag_index import extract_text_from_path
|
||
|
||
|
||
@dataclass(frozen=True)
class ExtractedText:
    """Immutable result record produced by ``extract_text``.

    Only ``path``, ``text`` and ``status`` are required; every other field
    has a default so that partial results (unsupported or failed files) can
    be built without supplying derived data.
    """

    # File the extraction was attempted on.
    path: Path
    # Extracted text; ``extract_text`` stores "" when nothing was extracted.
    text: str
    # Outcome label; ``extract_text`` uses "unsupported", "failed", "success".
    status: str
    # Hex digest of the text; stays "" when no hash was computed.
    content_hash: str = ""
    # Stringified exception for failed extractions, "" otherwise.
    error_message: str = ""
    # Leading portion of ``text``.
    front_text: str = ""
    # Heading-like lines found in the text; None when never computed.
    section_candidates: list[str] | None = None
    # Label -> value pairs found in the text; None when never computed.
    field_candidates: dict[str, str] | None = None
|
||
|
||
|
||
# Lower-cased file suffixes that extract_text will attempt to process;
# anything else is reported as "unsupported" without touching the file.
SUPPORTED_EXTENSIONS = {
    ".txt",
    ".md",
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
    ".doc",
}
|
||
|
||
|
||
def extract_text(path: str | Path) -> ExtractedText:
    """Extract text from *path* and wrap the outcome in an ``ExtractedText``.

    The result's ``status`` is one of:

    * ``"unsupported"`` - the lower-cased suffix is not in
      ``SUPPORTED_EXTENSIONS``; the file is not read.
    * ``"failed"`` - ``extract_text_from_path`` raised; the exception text is
      stored in ``error_message`` and the candidate containers are empty.
    * ``"success"`` - extraction returned; the hash, front text, section
      candidates and field candidates are derived from the extracted text.

    This function does not raise for extraction errors; they are reported
    through the returned record instead.
    """
    source = Path(path)

    suffix = source.suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        return ExtractedText(path=source, text="", status="unsupported")

    try:
        body = extract_text_from_path(source)
    except Exception as exc:
        # Deliberately broad: any extractor failure becomes a "failed" record.
        failure = ExtractedText(
            path=source,
            text="",
            status="failed",
            error_message=str(exc),
            section_candidates=[],
            field_candidates={},
        )
        return failure

    # Empty text is still a success, but carries no hash.
    digest = ""
    if body:
        digest = hashlib.sha256(body.encode("utf-8")).hexdigest()

    return ExtractedText(
        path=source,
        text=body,
        status="success",
        content_hash=digest,
        front_text=_front_text(body),
        section_candidates=_section_candidates(body),
        field_candidates=_field_candidates(body),
    )
|
||
|
||
|
||
def _front_text(text: str, limit: int = 1200) -> str:
|
||
return text[:limit]
|
||
|
||
|
||
def _section_candidates(text: str) -> list[str]:
|
||
candidates = []
|
||
for line in text.splitlines():
|
||
normalized = line.strip()
|
||
if not normalized:
|
||
continue
|
||
if re.match(r"^([一二三四五六七八九十]+[、..]|[0-9]+(\.[0-9]+)*[、..\s])", normalized):
|
||
candidates.append(normalized[:120])
|
||
elif any(keyword in normalized for keyword in ["章节目录", "监管信息", "综述资料", "非临床资料", "临床评价资料", "质量管理体系"]):
|
||
candidates.append(normalized[:120])
|
||
return candidates[:80]
|
||
|
||
|
||
def _field_candidates(text: str) -> dict[str, str]:
|
||
fields = {}
|
||
for label in ["产品名称", "型号规格", "预期用途", "管理类别", "分类编码", "注册类型", "临床评价路径"]:
|
||
match = re.search(rf"{label}[::]\s*([^\n\r]+)", text)
|
||
if match:
|
||
fields[label] = " ".join(match.group(1).strip().split())
|
||
return fields
|