Files
DEMO-AGENT/review_agent/regulatory_review/services/rag_index.py

149 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import hashlib
import logging
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from django.conf import settings
from docx import Document
from openpyxl import load_workbook
from pypdf import PdfReader
from pptx import Presentation
from .rag_embedding import EmbeddingFunction
# Module logger. The name is spelled out explicitly rather than derived from
# __name__, so logging configuration must target exactly this string.
_LOGGER_NAME = "review_agent.regulatory_review.rag_index"
logger = logging.getLogger(_LOGGER_NAME)
@dataclass(frozen=True)
class TextChunk:
    """One retrievable slice of a source document.

    The instance is frozen (fields cannot be reassigned), but ``metadata`` is a
    plain dict and can still be mutated in place.
    """

    # The chunk's text, stripped of leading/trailing whitespace by chunk_text.
    text: str
    # Per-chunk metadata; chunk_text stores "source" and "chunk_index" here.
    metadata: dict[str, object]


def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
    """Split *text* into overlapping fixed-size character windows.

    Every line is stripped and blank lines are dropped before chunking, so the
    windows are taken over that normalized, newline-joined text.

    Args:
        text: Raw extracted document text.
        source: Identifier stored in each chunk's metadata under ``"source"``.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared by consecutive windows; must be
            non-negative. Values >= ``chunk_size`` degrade to a stride of 1.

    Returns:
        Chunks in document order. Each carries ``{"source": source,
        "chunk_index": i}`` where ``i`` counts emitted chunks from 0. The list
        is empty when *text* has no non-blank lines.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is negative.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap < 0:
        # A negative overlap would leave gaps of unindexed text between windows.
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    normalized = "\n".join(line.strip() for line in text.splitlines() if line.strip())
    if not normalized:
        return []
    total = len(normalized)
    step = max(1, chunk_size - overlap)
    chunks: list[TextChunk] = []
    start = 0
    while start < total:
        end = start + chunk_size
        part = normalized[start:end].strip()
        if part:
            chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": len(chunks)}))
        if end >= total:
            # This window already reaches the end of the text; stepping further
            # would only emit a tail fully contained in the current chunk.
            break
        start += step
    return chunks
def extract_text_from_path(path: Path) -> str:
    """Return the plain text of a supported document, or ``""`` for other types.

    Dispatch is on the lower-cased file suffix:

    * ``.txt`` / ``.md``: read as UTF-8, silently dropping undecodable bytes.
    * ``.pdf``: text of every page joined by newlines; pages without
      extractable text contribute an empty line.
    * ``.docx``: text of ``Document.paragraphs`` joined by newlines (table cell
      text is not part of ``paragraphs``).
    * ``.pptx``: text of every shape that exposes a ``text`` attribute.
    * ``.xlsx``: one tab-separated line per row of non-empty cached cell
      values, over all worksheets; fully empty rows are skipped.
    * ``.doc``: converted to ``.docx`` with LibreOffice, then extracted.

    Raises:
        RuntimeError: If a ``.doc`` file cannot be converted by LibreOffice.
        Exceptions from the underlying parsers (pypdf, python-docx,
        python-pptx, openpyxl) propagate unchanged for unreadable files.
    """
    suffix = path.suffix.lower()
    if suffix in {".txt", ".md"}:
        return path.read_text(encoding="utf-8", errors="ignore")
    if suffix == ".pdf":
        # extract_text() may return None for pages without a text layer.
        return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages)
    if suffix == ".docx":
        return "\n".join(paragraph.text for paragraph in Document(str(path)).paragraphs)
    if suffix == ".pptx":
        presentation = Presentation(str(path))
        return "\n".join(
            shape.text
            for slide in presentation.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        )
    if suffix == ".xlsx":
        # data_only=True yields cached formula results instead of formulas.
        # read_only=True keeps the file open until close() is called, so the
        # workbook must be closed explicitly even if iteration fails.
        workbook = load_workbook(path, data_only=True, read_only=True)
        try:
            lines = []
            for sheet in workbook.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    values = [str(cell) for cell in row if cell is not None and cell != ""]
                    if values:
                        lines.append("\t".join(values))
        finally:
            workbook.close()
        return "\n".join(lines)
    if suffix == ".doc":
        return _extract_legacy_doc_with_libreoffice(path)
    return ""
def _extract_legacy_doc_with_libreoffice(path: Path) -> str:
    """Convert a legacy Word ``.doc`` to ``.docx`` via LibreOffice and extract it.

    The conversion runs ``soffice --headless --convert-to docx`` into a
    temporary directory, which is removed together with the converted file
    once the text has been extracted.

    Raises:
        RuntimeError: If ``soffice`` is not on PATH, exits non-zero, exceeds
            the 60-second timeout, or does not produce ``<stem>.docx``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        target_dir = Path(tmp_dir)
        # Use a private, throwaway LibreOffice user profile for this run. With
        # the default shared profile, a second soffice process (another
        # conversion, or a desktop session) can fail or exit without writing
        # any output.
        profile_uri = (target_dir / "lo_profile").as_uri()
        try:
            subprocess.run(
                [
                    "soffice",
                    f"-env:UserInstallation={profile_uri}",
                    "--headless",
                    "--convert-to",
                    "docx",
                    "--outdir",
                    str(target_dir),
                    str(path),
                ],
                check=True,
                capture_output=True,
                text=True,
                timeout=60,
            )
        except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc:
            raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc
        # LibreOffice names the output after the input stem.
        converted = target_dir / f"{path.stem}.docx"
        if not converted.exists():
            raise RuntimeError(f"LibreOffice 未生成 docx:{path.name}")
        return extract_text_from_path(converted)
def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
    """Extract and chunk every file under *source_dir*, recursively.

    Files are visited in sorted path order, so the chunk order is
    deterministic. Each chunk's ``source`` metadata is the file's path relative
    to *source_dir*. Files with unsupported suffixes yield empty text and so
    contribute no chunks.

    Extraction failures are logged and the file is skipped, so a single
    corrupt or unreadable document (e.g. a truncated PDF, or a non-zip file
    carrying a ``.docx`` suffix) does not abort the whole collection.
    """
    chunks: list[TextChunk] = []
    for path in sorted(source_dir.rglob("*")):
        if not path.is_file():
            continue
        try:
            text = extract_text_from_path(path)
        except Exception as exc:  # noqa: BLE001 - per-file boundary; parsers raise many unrelated types
            logger.warning(
                "Regulatory source extraction skipped",
                extra={"path": str(path), "error": str(exc), "error_type": type(exc).__name__},
            )
            continue
        chunks.extend(chunk_text(text, source=str(path.relative_to(source_dir))))
    return chunks
def build_chroma_index(
    *,
    source_dir: Path,
    embedding_provider: EmbeddingFunction,
    persist_path: Path | None = None,
    collection_name: str | None = None,
    batch_size: int = 256,
) -> int:
    """Embed every chunk of the documents under *source_dir* into a Chroma collection.

    Args:
        source_dir: Directory scanned recursively by ``collect_source_chunks``.
        embedding_provider: Called with a list of chunk texts; its return value
            is passed to Chroma as the embeddings for those texts.
        persist_path: Chroma storage directory; defaults to
            ``settings.REGULATORY_RAG_CHROMA_PATH`` and is created if missing.
        collection_name: Defaults to ``settings.REGULATORY_RAG_COLLECTION``;
            the collection is created if it does not exist.
        batch_size: Maximum number of chunks embedded and upserted per call.
            Batching keeps each request within Chroma's per-call batch limit
            and the embedding backend's per-request input limits.

    Returns:
        The number of chunks written. This is 0 when no chunks were produced;
        the collection is still created in that case.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        RuntimeError: If chromadb is not installed.

    IDs are the SHA-256 of ``"<source>:<chunk_index>"``, so re-running
    overwrites the same records. If embedding fails partway, the earlier
    batches stay persisted, and a re-run converges.
    NOTE(review): upsert never deletes. Chunks from files that were removed or
    shortened since the last build remain in the collection.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    try:
        import chromadb
    except ImportError as exc:
        raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc
    persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH)
    collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
    persist_path.mkdir(parents=True, exist_ok=True)
    chunks = collect_source_chunks(source_dir)
    client = chromadb.PersistentClient(path=str(persist_path))
    collection = client.get_or_create_collection(collection_name)
    if not chunks:
        return 0
    for offset in range(0, len(chunks), batch_size):
        batch = chunks[offset : offset + batch_size]
        texts = [chunk.text for chunk in batch]
        ids = [
            hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
            for chunk in batch
        ]
        collection.upsert(
            ids=ids,
            documents=texts,
            metadatas=[chunk.metadata for chunk in batch],
            embeddings=embedding_provider(texts),
        )
    return len(chunks)