97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
from pathlib import Path
|
|
|
|
from django.conf import settings
|
|
|
|
from agent_core.llm_provider import create_embedding_provider
|
|
|
|
|
|
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client for a store directory.

    Args:
        path: Directory of the Chroma store. Any falsy value (``None`` or an
            empty string) falls back to ``settings.CHROMA_PATH``.

    Returns:
        A ``chromadb.PersistentClient`` opened at the resolved location,
        which is passed to Chroma as a string.
    """
    # Deferred import: chromadb is loaded only when a client is actually created.
    import chromadb

    if path:
        location = path
    else:
        location = settings.CHROMA_PATH
    return chromadb.PersistentClient(path=str(location))
|
|
|
|
|
|
def _embedding_provider():
    """Build an embedding provider from the Django ``EMBEDDING_*`` settings.

    Reads ``EMBEDDING_API_KEY``, ``EMBEDDING_BASE_URL`` and ``EMBEDDING_MODEL``
    from ``settings``. It passes them to ``create_embedding_provider`` as one
    dict keyed by those same setting names.
    """
    setting_names = ("EMBEDDING_API_KEY", "EMBEDDING_BASE_URL", "EMBEDDING_MODEL")
    provider_config = {name: getattr(settings, name) for name in setting_names}
    return create_embedding_provider(provider_config)
|
|
|
|
|
|
def upsert_chunks(
|
|
collection: str,
|
|
chunks: list[dict],
|
|
store_path: str | Path | None = None,
|
|
) -> None:
|
|
client = _client(store_path)
|
|
chroma_collection = client.get_or_create_collection(collection)
|
|
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
|
for document_id in document_ids:
|
|
chroma_collection.delete(where={"document_id": document_id})
|
|
texts = [chunk["content"] for chunk in chunks]
|
|
embeddings = _embedding_provider().embed_texts(texts)
|
|
chroma_collection.upsert(
|
|
ids=[chunk["chunk_id"] for chunk in chunks],
|
|
documents=texts,
|
|
embeddings=embeddings,
|
|
metadatas=[
|
|
{
|
|
"scenario_id": chunk["scenario_id"],
|
|
"document_id": chunk["document_id"],
|
|
"source": chunk["source"],
|
|
"chunk_id": chunk["chunk_id"],
|
|
"created_at": chunk["created_at"],
|
|
}
|
|
for chunk in chunks
|
|
],
|
|
)
|
|
|
|
|
|
def query_chunks(
|
|
scenario_id: str,
|
|
query: str,
|
|
collection: str,
|
|
top_k: int = 5,
|
|
document_ids: list[int] | None = None,
|
|
store_path: str | Path | None = None,
|
|
) -> list[dict]:
|
|
client = _client(store_path)
|
|
chroma_collection = client.get_or_create_collection(collection)
|
|
where: dict = {"scenario_id": scenario_id}
|
|
if document_ids:
|
|
where = {
|
|
"$and": [
|
|
{"scenario_id": scenario_id},
|
|
{"document_id": {"$in": document_ids}},
|
|
]
|
|
}
|
|
embedding = _embedding_provider().embed_texts([query])[0]
|
|
result = chroma_collection.query(
|
|
query_embeddings=[embedding],
|
|
n_results=top_k,
|
|
where=where,
|
|
include=["documents", "metadatas", "distances"],
|
|
)
|
|
chunks = []
|
|
documents = result.get("documents", [[]])[0]
|
|
metadatas = result.get("metadatas", [[]])[0]
|
|
distances = result.get("distances", [[]])[0]
|
|
for content, metadata, distance in zip(documents, metadatas, distances):
|
|
chunks.append(
|
|
{
|
|
"scenario_id": metadata.get("scenario_id"),
|
|
"document_id": metadata.get("document_id"),
|
|
"collection": collection,
|
|
"source": metadata.get("source"),
|
|
"chunk_id": metadata.get("chunk_id"),
|
|
"content": content,
|
|
"created_at": metadata.get("created_at"),
|
|
"score": round(1 / (1 + float(distance)), 4),
|
|
}
|
|
)
|
|
return chunks
|