58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
from django.conf import settings
|
|
|
|
from .rag_embedding import EmbeddingFunction, get_embedding_provider
|
|
|
|
|
|
class RagIndexUnavailable(RuntimeError):
    """Signal that the regulatory RAG index or its collection cannot be opened."""
|
|
|
|
|
|
def retrieve_citations(
    query: str,
    *,
    embedding_provider: EmbeddingFunction | None = None,
    collection=None,
    n_results: int = 3,
) -> list[dict[str, object]]:
    """Query the RAG collection and return citation dicts for *query*.

    Args:
        query: Text to embed and search for.
        embedding_provider: Callable that receives a list of texts and returns
            their embeddings. Falls back to ``get_embedding_provider()`` when
            not supplied.
        collection: Object exposing
            ``query(query_embeddings=..., n_results=...)`` that returns a
            mapping with ``documents`` / ``metadatas`` / ``distances``.
            Loaded through ``_load_collection()`` when ``None``.
        n_results: Number of hits requested from the collection.

    Returns:
        One dict per returned document with the keys ``source``, ``text`` and
        ``score`` (the distance reported by the collection, or ``None`` when
        no distance is available). When the query yields no documents, a
        single placeholder citation is returned instead of an empty list.

    Raises:
        RagIndexUnavailable: Propagated from ``_load_collection()`` when no
            collection is supplied and the persisted index cannot be opened.
    """
    provider = embedding_provider or get_embedding_provider()
    if collection is None:
        collection = _load_collection()

    result = collection.query(
        query_embeddings=provider([query]),
        n_results=n_results,
    )

    # Results are batched per query; only one query is sent, so take batch 0.
    # A missing key, a None value, or a None inner list all collapse to [].
    documents = (result.get("documents") or [[]])[0] or []
    metadatas = (result.get("metadatas") or [[]])[0] or []
    distances = (result.get("distances") or [[]])[0] or []

    if not documents:
        return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]

    citations: list[dict[str, object]] = []
    for index, document in enumerate(documents):
        # A record stored without metadata can come back as None rather than
        # {}; normalise it so the ``.get`` below cannot raise AttributeError.
        metadata = (metadatas[index] if index < len(metadatas) else None) or {}
        distance = distances[index] if index < len(distances) else None
        citations.append(
            {
                "source": metadata.get("source", "法规材料"),
                "text": document,
                "score": distance,
            }
        )
    return citations
|
|
|
|
|
|
def _load_collection():
    """Open the persisted Chroma collection named in Django settings.

    Reads ``settings.REGULATORY_RAG_CHROMA_PATH`` for the on-disk location and
    ``settings.REGULATORY_RAG_COLLECTION`` for the collection name.

    Raises:
        RagIndexUnavailable: If the persist path does not exist, ``chromadb``
            cannot be imported, or fetching the collection fails for any
            reason.
    """
    index_dir = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not index_dir.exists():
        raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")

    # Imported lazily so a missing chromadb install only surfaces here.
    try:
        import chromadb
    except ImportError as import_error:
        raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from import_error

    chroma_client = chromadb.PersistentClient(path=str(index_dir))

    # The settings lookup stays inside the try so that any failure while
    # resolving the collection is reported as RagIndexUnavailable.
    try:
        loaded = chroma_client.get_collection(settings.REGULATORY_RAG_COLLECTION)
    except Exception as lookup_error:
        raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from lookup_error
    return loaded
|