from __future__ import annotations

from pathlib import Path

from django.conf import settings

from .rag_embedding import EmbeddingFunction, get_embedding_provider


class RagIndexUnavailable(RuntimeError):
    """Raised when the persisted regulatory RAG index cannot be loaded.

    Covers a missing index directory, a missing ``chromadb`` install, and a
    missing collection inside the index (see ``_load_collection``).
    """


def retrieve_citations(
    query: str,
    *,
    embedding_provider: EmbeddingFunction | None = None,
    collection=None,
    n_results: int = 3,
) -> list[dict[str, object]]:
    """Embed *query* and return the closest regulatory passages as citations.

    Args:
        query: Free-text query to embed and search for.
        embedding_provider: Callable that maps a list of strings to a list of
            embeddings. Defaults to ``get_embedding_provider()``.
        collection: Object exposing ``query(query_embeddings=..., n_results=...)``
            (presumably a Chroma collection). Defaults to the persisted
            collection returned by ``_load_collection()``.
        n_results: Maximum number of passages requested from the collection.

    Returns:
        One dict per retrieved passage with keys ``source`` (metadata
        ``"source"`` or ``"法规材料"``), ``text`` (the document), ``score``
        (the reported distance, or ``None`` if absent) and ``metadata``.
        If nothing is retrieved, a single placeholder dict with ``source``,
        ``text`` and ``score=None`` is returned instead.

    Raises:
        RagIndexUnavailable: If *collection* is not given and the persisted
            index cannot be loaded.
    """
    provider = embedding_provider or get_embedding_provider()
    if collection is None:
        collection = _load_collection()

    embeddings = provider([query])
    result = collection.query(query_embeddings=embeddings, n_results=n_results)

    # The query result holds one inner row per query embedding; exactly one
    # query was sent, so only the first row matters.
    documents = _first_query_row(result, "documents")
    metadatas = _first_query_row(result, "metadatas")
    distances = _first_query_row(result, "distances")

    if not documents:
        return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]

    citations: list[dict[str, object]] = []
    for index, document in enumerate(documents):
        # Metadata entries can be ``None`` for documents stored without
        # metadata; normalise to an empty dict so ``.get`` is always safe.
        metadata = (metadatas[index] if index < len(metadatas) else None) or {}
        distance = distances[index] if index < len(distances) else None
        citations.append(
            {
                "source": metadata.get("source", "法规材料"),
                "text": document,
                "score": distance,
                "metadata": metadata,
            }
        )
    return citations


def _first_query_row(result, key: str) -> list:
    """Return the first per-query row stored under *key*, or ``[]`` if missing.

    Treats an absent key, an empty outer list, and a ``None`` row as empty.
    """
    return (result.get(key) or [[]])[0] or []


def _load_collection():
    """Open the persisted regulatory RAG collection.

    Reads ``settings.REGULATORY_RAG_CHROMA_PATH`` and
    ``settings.REGULATORY_RAG_COLLECTION``.

    Raises:
        RagIndexUnavailable: If the index directory does not exist, if
            ``chromadb`` is not importable, or if the collection lookup fails.
    """
    persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not persist_path.exists():
        raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")

    # Imported lazily so this module can be imported without chromadb installed.
    try:
        import chromadb
    except ImportError as exc:
        raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from exc

    client = chromadb.PersistentClient(path=str(persist_path))
    try:
        return client.get_collection(settings.REGULATORY_RAG_COLLECTION)
    except Exception as exc:
        # The exception type for a missing collection varies across chromadb
        # versions, so any lookup failure is reported as an unavailable index.
        raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from exc