from __future__ import annotations import hashlib import logging import subprocess import tempfile from dataclasses import dataclass from pathlib import Path from django.conf import settings from docx import Document from openpyxl import load_workbook from pypdf import PdfReader from pptx import Presentation from .rag_embedding import EmbeddingFunction logger = logging.getLogger("review_agent.regulatory_review.rag_index") @dataclass(frozen=True) class TextChunk: text: str metadata: dict[str, object] def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]: normalized = "\n".join(line.strip() for line in text.splitlines() if line.strip()) if not normalized: return [] chunks = [] start = 0 index = 0 step = max(1, chunk_size - overlap) while start < len(normalized): part = normalized[start : start + chunk_size].strip() if part: chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": index})) index += 1 start += step return chunks def extract_text_from_path(path: Path) -> str: suffix = path.suffix.lower() if suffix in {".txt", ".md"}: return path.read_text(encoding="utf-8", errors="ignore") if suffix == ".pdf": return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages) if suffix == ".docx": return "\n".join(paragraph.text for paragraph in Document(str(path)).paragraphs) if suffix == ".pptx": presentation = Presentation(str(path)) lines = [] for slide in presentation.slides: for shape in slide.shapes: if hasattr(shape, "text"): lines.append(shape.text) return "\n".join(lines) if suffix == ".xlsx": workbook = load_workbook(path, data_only=True, read_only=True) lines = [] for sheet in workbook.worksheets: for row in sheet.iter_rows(values_only=True): values = [str(cell) for cell in row if cell not in {None, ""}] if values: lines.append("\t".join(values)) return "\n".join(lines) if suffix == ".doc": return _extract_legacy_doc_with_libreoffice(path) return "" def _extract_legacy_doc_with_libreoffice(path: Path) -> str: with tempfile.TemporaryDirectory() as tmp_dir: target_dir = Path(tmp_dir) try: subprocess.run( [ "soffice", "--headless", "--convert-to", "docx", "--outdir", str(target_dir), str(path), ], check=True, capture_output=True, text=True, timeout=60, ) except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc converted = target_dir / f"{path.stem}.docx" if not converted.exists(): raise RuntimeError(f"LibreOffice 未生成 docx:{path.name}") return extract_text_from_path(converted) def collect_source_chunks(source_dir: Path) -> list[TextChunk]: chunks: list[TextChunk] = [] for path in sorted(source_dir.rglob("*")): if not path.is_file(): continue try: text = extract_text_from_path(path) except RuntimeError as exc: if _is_attachment4(path): raise RuntimeError(f"附件 4 核心法规材料抽取失败:{path.name}") from exc logger.warning("Regulatory source extraction skipped", extra={"path": str(path), "error": str(exc)}) continue chunks.extend(chunk_text(text, source=str(path.relative_to(source_dir)))) return chunks def _is_attachment4(path: Path) -> bool: normalized = path.name.replace(" ", "") return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized def build_chroma_index( *, source_dir: Path, embedding_provider: EmbeddingFunction, persist_path: Path | None = None, collection_name: str | None = None, ) -> int: try: import chromadb except ImportError as exc: raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH) collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION persist_path.mkdir(parents=True, exist_ok=True) chunks = collect_source_chunks(source_dir) client = chromadb.PersistentClient(path=str(persist_path)) collection = client.get_or_create_collection(collection_name) if not chunks: return 0 texts = [chunk.text for chunk in chunks] embeddings = embedding_provider(texts) ids = [ hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest() for chunk in chunks ] collection.upsert( ids=ids, documents=texts, metadatas=[chunk.metadata for chunk in chunks], embeddings=embeddings, ) return len(chunks)