185 lines
6.2 KiB
Python
185 lines
6.2 KiB
Python
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import logging
|
||
import subprocess
|
||
import tempfile
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from django.conf import settings
|
||
from docx import Document
|
||
from docx.oxml.table import CT_Tbl
|
||
from docx.oxml.text.paragraph import CT_P
|
||
from docx.table import Table
|
||
from docx.text.paragraph import Paragraph
|
||
from openpyxl import load_workbook
|
||
from pypdf import PdfReader
|
||
from pptx import Presentation
|
||
|
||
from .rag_embedding import EmbeddingFunction
|
||
|
||
|
||
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||
|
||
|
||
@dataclass(frozen=True)
class TextChunk:
    """One slice of extracted source text together with its provenance.

    Instances are frozen, so ``text`` and ``metadata`` cannot be rebound after
    construction. The ``metadata`` dict object itself remains mutable.
    """

    # The chunk's textual content.
    text: str
    # Provenance key/value pairs describing where the chunk came from.
    metadata: dict[str, object]
|
||
|
||
|
||
def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
    """Split *text* into overlapping, whitespace-normalized chunks.

    Each line is stripped and blank lines are dropped before the text is cut
    into windows of at most ``chunk_size`` characters. Consecutive windows
    start ``chunk_size - overlap`` characters apart (never less than 1).

    Args:
        text: Raw text to split.
        source: Identifier stored under the ``"source"`` metadata key.
        chunk_size: Maximum number of characters per chunk; must be >= 1.
        overlap: Number of characters shared by neighbouring chunks; must be >= 0.

    Returns:
        The chunks in order, each carrying ``source`` and a zero-based
        ``chunk_index`` in its metadata. Empty when *text* has no
        non-whitespace content.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1 or ``overlap`` is negative.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if overlap < 0:
        raise ValueError(f"overlap must be >= 0, got {overlap}")

    normalized = "\n".join(stripped for stripped in map(str.strip, text.splitlines()) if stripped)
    if not normalized:
        return []

    # An overlap >= chunk_size would stall or move backwards; always advance.
    step = max(1, chunk_size - overlap)
    total = len(normalized)
    chunks: list[TextChunk] = []
    for start in range(0, total, step):
        part = normalized[start : start + chunk_size].strip()
        if part:
            chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": len(chunks)}))
        # Once a window reaches the end of the text, every later window would
        # only repeat the tail of this one, so stop instead of emitting a
        # redundant chunk.
        if start + chunk_size >= total:
            break
    return chunks
|
||
|
||
|
||
def extract_text_from_path(path: Path) -> str:
    """Extract plain text from a file, choosing the reader by file suffix.

    The suffix is matched case-insensitively. Supported suffixes are
    ``.txt``, ``.md``, ``.pdf``, ``.docx``, ``.pptx``, ``.xlsx`` and ``.doc``.

    Args:
        path: File to read.

    Returns:
        The extracted text, or an empty string for any other suffix.

    Raises:
        RuntimeError: If a ``.doc`` file cannot be converted by LibreOffice.
    """
    suffix = path.suffix.lower()
    if suffix in {".txt", ".md"}:
        # Undecodable bytes are dropped rather than failing the whole file.
        return path.read_text(encoding="utf-8", errors="ignore")
    if suffix == ".pdf":
        # extract_text() may yield nothing for a page; treat that as empty.
        return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages)
    if suffix == ".docx":
        return _extract_docx_text(path)
    if suffix == ".pptx":
        return _extract_pptx_text(path)
    if suffix == ".xlsx":
        return _extract_xlsx_text(path)
    if suffix == ".doc":
        return _extract_legacy_doc_with_libreoffice(path)
    return ""


def _extract_pptx_text(path: Path) -> str:
    """Return the text of every shape exposing ``text``, one shape per line."""
    presentation = Presentation(str(path))
    return "\n".join(
        shape.text
        for slide in presentation.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )


def _extract_xlsx_text(path: Path) -> str:
    """Return one tab-joined line per worksheet row that has any value."""
    workbook = load_workbook(path, data_only=True, read_only=True)
    try:
        lines = []
        for sheet in workbook.worksheets:
            for row in sheet.iter_rows(values_only=True):
                values = [str(cell) for cell in row if cell is not None and cell != ""]
                if values:
                    lines.append("\t".join(values))
        return "\n".join(lines)
    finally:
        # Read-only workbooks keep the underlying file open until closed.
        workbook.close()
|
||
|
||
|
||
def _extract_docx_text(path: Path) -> str:
    """Return the text of a .docx file with paragraphs and tables in body order.

    Each non-empty paragraph becomes one line. Each table row becomes one line
    made of its non-empty cell texts joined by tabs; rows with no text are
    skipped.
    """
    collected: list[str] = []
    for element in _iter_docx_blocks(Document(str(path))):
        if isinstance(element, Paragraph):
            paragraph_text = element.text.strip()
            if paragraph_text:
                collected.append(paragraph_text)
            continue
        if isinstance(element, Table):
            for table_row in element.rows:
                stripped_cells = (cell.text.strip() for cell in table_row.cells)
                row_text = "\t".join(value for value in stripped_cells if value)
                # Joining only non-empty values is empty exactly when no cell had text.
                if row_text:
                    collected.append(row_text)
    return "\n".join(collected)
|
||
|
||
|
||
def _iter_docx_blocks(document):
    """Yield the direct children of the document body as paragraphs and tables.

    Children are visited in body order. ``CT_P`` elements are wrapped in
    ``Paragraph`` and ``CT_Tbl`` elements in ``Table``; any other child element
    is skipped.
    """
    wrappers = ((CT_P, Paragraph), (CT_Tbl, Table))
    for element in document.element.body.iterchildren():
        for xml_type, proxy_type in wrappers:
            if isinstance(element, xml_type):
                yield proxy_type(element, document)
                break
|
||
|
||
|
||
def _extract_legacy_doc_with_libreoffice(path: Path) -> str:
|
||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||
target_dir = Path(tmp_dir)
|
||
try:
|
||
subprocess.run(
|
||
[
|
||
"soffice",
|
||
"--headless",
|
||
"--convert-to",
|
||
"docx",
|
||
"--outdir",
|
||
str(target_dir),
|
||
str(path),
|
||
],
|
||
check=True,
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=60,
|
||
)
|
||
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc:
|
||
raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc
|
||
converted = target_dir / f"{path.stem}.docx"
|
||
if not converted.exists():
|
||
raise RuntimeError(f"LibreOffice 未生成 docx:{path.name}")
|
||
return extract_text_from_path(converted)
|
||
|
||
|
||
def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
    """Extract and chunk every file found recursively under *source_dir*.

    Files are processed in sorted path order, and each chunk's ``source`` is
    the file path relative to *source_dir*. A ``RuntimeError`` raised during
    extraction is logged and the file skipped, unless the file is the
    Attachment 4 document, in which case the error is re-raised.

    Raises:
        RuntimeError: If extraction of the Attachment 4 document fails.
    """
    collected: list[TextChunk] = []
    candidates = sorted(source_dir.rglob("*"))
    for candidate in candidates:
        if candidate.is_file():
            try:
                extracted = extract_text_from_path(candidate)
            except RuntimeError as exc:
                if _is_attachment4(candidate):
                    raise RuntimeError(f"附件 4 核心法规材料抽取失败:{candidate.name}") from exc
                logger.warning(
                    "Regulatory source extraction skipped",
                    extra={"path": str(candidate), "error": str(exc)},
                )
            else:
                relative_name = str(candidate.relative_to(source_dir))
                collected.extend(chunk_text(extracted, source=relative_name))
    return collected
|
||
|
||
|
||
def _is_attachment4(path: Path) -> bool:
|
||
normalized = path.name.replace(" ", "")
|
||
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
|
||
|
||
|
||
def build_chroma_index(
    *,
    source_dir: Path,
    embedding_provider: EmbeddingFunction,
    persist_path: Path | None = None,
    collection_name: str | None = None,
    batch_size: int = 1000,
) -> int:
    """Chunk every source file under *source_dir* and upsert it into a Chroma collection.

    Args:
        source_dir: Directory scanned recursively for source documents.
        embedding_provider: Callable that receives the list of chunk texts and
            returns one embedding per text, in the same order.
        persist_path: Directory of the persistent Chroma store; falls back to
            ``settings.REGULATORY_RAG_CHROMA_PATH``. Created if missing.
        collection_name: Target collection; falls back to
            ``settings.REGULATORY_RAG_COLLECTION``. Created if missing.
        batch_size: Maximum number of records sent per ``upsert`` call. Values
            below 1 are treated as 1.

    Returns:
        The number of chunks written (0 when no text was found).

    Raises:
        RuntimeError: If ``chromadb`` is not installed, or if extraction of the
            Attachment 4 document fails.
        ValueError: If the embedding provider returns a different number of
            embeddings than texts supplied.
    """
    try:
        import chromadb
    except ImportError as exc:
        raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc

    persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH)
    collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
    persist_path.mkdir(parents=True, exist_ok=True)
    chunks = collect_source_chunks(source_dir)
    client = chromadb.PersistentClient(path=str(persist_path))
    # The collection is created even when there is nothing to index.
    collection = client.get_or_create_collection(collection_name)
    if not chunks:
        return 0

    texts = [chunk.text for chunk in chunks]
    embeddings = embedding_provider(texts)
    if len(embeddings) != len(texts):
        # Fail loudly: a short result would otherwise leave later batches
        # without embeddings.
        raise ValueError(
            f"embedding_provider returned {len(embeddings)} embeddings for {len(texts)} texts"
        )
    # IDs are derived from source path and chunk position, so re-indexing the
    # same file overwrites its earlier records instead of duplicating them.
    ids = [
        hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
        for chunk in chunks
    ]
    metadatas = [chunk.metadata for chunk in chunks]

    # Chroma rejects a single call that exceeds its maximum batch size, so
    # large corpora are written in slices.
    step = max(1, batch_size)
    for start in range(0, len(chunks), step):
        stop = start + step
        collection.upsert(
            ids=ids[start:stop],
            documents=texts[start:stop],
            metadatas=metadatas[start:stop],
            embeddings=embeddings[start:stop],
        )
    return len(chunks)
|