Files
DEMO-AGENT/agent_core/rag/chroma_store.py

97 lines
3.0 KiB
Python

from pathlib import Path
from django.conf import settings
from agent_core.llm_provider import create_embedding_provider
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client rooted at *path*.

    When *path* is falsy (``None`` or an empty string) the directory from
    ``settings.CHROMA_PATH`` is used instead. ``chromadb`` is imported inside
    the function, so the import is deferred until the first call.
    """
    import chromadb

    if not path:
        path = settings.CHROMA_PATH
    return chromadb.PersistentClient(path=str(path))
def _embedding_provider():
    """Build an embedding provider from the Django ``EMBEDDING_*`` settings.

    The provider factory receives a dict holding the API key, base URL and
    model name, keyed by the same names as the Django settings.
    """
    config = {
        name: getattr(settings, name)
        for name in ("EMBEDDING_API_KEY", "EMBEDDING_BASE_URL", "EMBEDDING_MODEL")
    }
    return create_embedding_provider(config)
def upsert_chunks(
collection: str,
chunks: list[dict],
store_path: str | Path | None = None,
) -> None:
client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection)
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
for document_id in document_ids:
chroma_collection.delete(where={"document_id": document_id})
texts = [chunk["content"] for chunk in chunks]
embeddings = _embedding_provider().embed_texts(texts)
chroma_collection.upsert(
ids=[chunk["chunk_id"] for chunk in chunks],
documents=texts,
embeddings=embeddings,
metadatas=[
{
"scenario_id": chunk["scenario_id"],
"document_id": chunk["document_id"],
"source": chunk["source"],
"chunk_id": chunk["chunk_id"],
"created_at": chunk["created_at"],
}
for chunk in chunks
],
)
def query_chunks(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection)
where: dict = {"scenario_id": scenario_id}
if document_ids:
where = {
"$and": [
{"scenario_id": scenario_id},
{"document_id": {"$in": document_ids}},
]
}
embedding = _embedding_provider().embed_texts([query])[0]
result = chroma_collection.query(
query_embeddings=[embedding],
n_results=top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
chunks = []
documents = result.get("documents", [[]])[0]
metadatas = result.get("metadatas", [[]])[0]
distances = result.get("distances", [[]])[0]
for content, metadata, distance in zip(documents, metadatas, distances):
chunks.append(
{
"scenario_id": metadata.get("scenario_id"),
"document_id": metadata.get("document_id"),
"collection": collection,
"source": metadata.get("source"),
"chunk_id": metadata.get("chunk_id"),
"content": content,
"created_at": metadata.get("created_at"),
"score": round(1 / (1 + float(distance)), 4),
}
)
return chunks