feat(regulatory): 增加本地法规RAG索引检索
This commit is contained in:
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from .rag_embedding import EmbeddingFunction, get_embedding_provider
|
||||
|
||||
|
||||
class RagIndexUnavailable(RuntimeError):
    """Signal that the local regulatory RAG index cannot be used.

    Raised by the collection loader in this module when the persisted
    index path is missing, ``chromadb`` cannot be imported, or the
    configured collection cannot be fetched.
    """
|
||||
|
||||
|
||||
def retrieve_citations(
    query: str,
    *,
    embedding_provider: EmbeddingFunction | None = None,
    collection=None,
    n_results: int = 3,
) -> list[dict[str, object]]:
    """Look up the passages closest to *query* in the regulatory collection.

    Args:
        query: Free-text query to embed and search with.
        embedding_provider: Callable that receives a list of texts and
            returns their embeddings. When ``None``, the provider returned
            by ``get_embedding_provider()`` is used.
        collection: Object exposing ``query(query_embeddings=..., n_results=...)``
            and returning a mapping with ``documents`` / ``metadatas`` /
            ``distances`` entries. When ``None``, it is obtained from
            ``_load_collection()``.
        n_results: Number of hits requested from the collection.

    Returns:
        One dict per returned document with the keys ``source``, ``text``
        and ``score`` (the distance reported by the collection, or ``None``
        when no distance is available). If the collection returns no
        documents, a single placeholder citation is returned instead.

    Raises:
        RagIndexUnavailable: Propagated from ``_load_collection()`` when no
            collection was supplied and the index cannot be loaded.
    """
    # Explicit ``is None`` checks: a caller-supplied callable must be used
    # even if the object happens to evaluate as falsy.
    provider = (
        get_embedding_provider() if embedding_provider is None else embedding_provider
    )
    if collection is None:
        collection = _load_collection()

    embeddings = provider([query])
    result = collection.query(query_embeddings=embeddings, n_results=n_results)

    # Each field is a list with one row per query; only one query was sent.
    # Either the field or its single row may be missing/None, so fall back
    # to an empty list at both levels.
    documents = (result.get("documents") or [[]])[0] or []
    metadatas = (result.get("metadatas") or [[]])[0] or []
    distances = (result.get("distances") or [[]])[0] or []

    if not documents:
        return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]

    citations: list[dict[str, object]] = []
    for index, document in enumerate(documents):
        # A record stored without metadata comes back as ``None``; treat it
        # like a missing entry so the default source label applies.
        metadata = (metadatas[index] if index < len(metadatas) else None) or {}
        distance = distances[index] if index < len(distances) else None
        citations.append(
            {
                "source": metadata.get("source", "法规材料"),
                "text": document,
                "score": distance,
            }
        )
    return citations
|
||||
|
||||
|
||||
def _load_collection():
    """Open the persisted Chroma collection named in the Django settings.

    Reads ``settings.REGULATORY_RAG_CHROMA_PATH`` for the on-disk location
    and ``settings.REGULATORY_RAG_COLLECTION`` for the collection name.

    Returns:
        The collection object returned by the Chroma client.

    Raises:
        RagIndexUnavailable: If the persist path does not exist, if
            ``chromadb`` cannot be imported, or if fetching the collection
            fails for any reason.
    """
    index_dir = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not index_dir.exists():
        raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")

    # Imported lazily so a missing optional dependency surfaces as a
    # RagIndexUnavailable instead of breaking module import.
    try:
        import chromadb
    except ImportError as import_error:
        raise RagIndexUnavailable(
            "chromadb 未安装,请先安装 requirements.txt。"
        ) from import_error

    chroma_client = chromadb.PersistentClient(path=str(index_dir))
    try:
        # The settings lookup stays inside the try so any failure here is
        # reported the same way as a missing collection.
        loaded = chroma_client.get_collection(settings.REGULATORY_RAG_COLLECTION)
    except Exception as lookup_error:
        raise RagIndexUnavailable(
            "法规 RAG collection 不存在,请先运行 regulatory_rag_build。"
        ) from lookup_error
    return loaded
|
||||
Reference in New Issue
Block a user