Files
DEMO-AGENT/review_agent/regulatory_review/services/rag_citation.py

59 lines
2.0 KiB
Python

from __future__ import annotations
from pathlib import Path
from django.conf import settings
from .rag_embedding import EmbeddingFunction, get_embedding_provider
class RagIndexUnavailable(RuntimeError):
    """Signal that the regulatory RAG index cannot be opened for querying."""
def retrieve_citations(
    query: str,
    *,
    embedding_provider: EmbeddingFunction | None = None,
    collection=None,
    n_results: int = 3,
) -> list[dict[str, object]]:
    """Return citation records for the passages that best match *query*.

    The query text is embedded with ``embedding_provider`` and sent to
    ``collection.query``; each returned document becomes one citation dict.

    Args:
        query: Text to search for.
        embedding_provider: Callable that receives a list of texts and returns
            their embeddings. When ``None``, ``get_embedding_provider()`` is
            used.
        collection: Object exposing
            ``query(query_embeddings=..., n_results=...)`` that returns a
            mapping with optional ``"documents"``, ``"metadatas"`` and
            ``"distances"`` entries. When ``None``, the collection is loaded
            via ``_load_collection()``.
        n_results: Number of results requested from the collection.

    Returns:
        One dict per returned document with the keys ``"source"``, ``"text"``,
        ``"score"`` (the distance reported by the collection, or ``None`` when
        absent) and ``"metadata"``. If no documents come back, a single
        placeholder dict with only ``"source"``, ``"text"`` and ``"score"`` is
        returned instead.

    Raises:
        RagIndexUnavailable: Propagated from ``_load_collection()`` when no
            collection is supplied and the index cannot be opened.
    """
    if embedding_provider is None:
        provider = get_embedding_provider()
    else:
        provider = embedding_provider
    if collection is None:
        collection = _load_collection()

    embeddings = provider([query])
    result = collection.query(query_embeddings=embeddings, n_results=n_results)

    # A single query was sent, so only the first row of each field is used.
    # A missing or None field is treated as one empty row.
    documents = (result.get("documents") or [[]])[0]
    metadatas = (result.get("metadatas") or [[]])[0]
    distances = (result.get("distances") or [[]])[0]

    if not documents:
        return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]

    citations: list[dict[str, object]] = []
    for index, document in enumerate(documents):
        # The metadata row may be shorter than the document row, and an entry
        # may itself be None (Chroma does this for records stored without
        # metadata); both cases are normalized to an empty dict so the
        # ``.get`` lookup below cannot fail.
        metadata = (metadatas[index] if index < len(metadatas) else None) or {}
        distance = distances[index] if index < len(distances) else None
        citations.append(
            {
                "source": metadata.get("source", "法规材料"),
                "text": document,
                "score": distance,
                "metadata": metadata,
            }
        )
    return citations
def _load_collection():
    """Open the persisted Chroma collection named in the Django settings.

    The store location is read from ``settings.REGULATORY_RAG_CHROMA_PATH``
    and the collection name from ``settings.REGULATORY_RAG_COLLECTION``.

    Raises:
        RagIndexUnavailable: If the store path does not exist, ``chromadb``
            cannot be imported, or fetching the collection fails.
    """
    index_dir = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not index_dir.exists():
        raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")

    # Imported here rather than at module level so that a missing chromadb
    # install surfaces as RagIndexUnavailable when the index is first needed.
    try:
        import chromadb
    except ImportError as import_error:
        raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from import_error

    chroma_client = chromadb.PersistentClient(path=str(index_dir))
    try:
        # The settings lookup stays inside the try so any failure while
        # resolving or fetching the collection is reported the same way.
        loaded = chroma_client.get_collection(settings.REGULATORY_RAG_COLLECTION)
    except Exception as lookup_error:
        raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from lookup_error
    else:
        return loaded