Files
DEMO-AGENT/agent_core/rag/chroma_store.py

105 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from django.conf import settings
from agent_core.llm_provider import create_embedding_provider
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client rooted at ``path``.

    When ``path`` is falsy, ``settings.CHROMA_PATH`` is used instead.
    """
    import chromadb

    target = path if path else settings.CHROMA_PATH
    return chromadb.PersistentClient(path=str(target))
def _embedding_provider():
    """Build the embedding provider from Django settings.

    Centralizes the embedding-related settings lookups so business code does
    not read them directly.
    """
    config = {
        key: getattr(settings, key)
        for key in ("EMBEDDING_API_KEY", "EMBEDDING_BASE_URL", "EMBEDDING_MODEL")
    }
    return create_embedding_provider(config)
def upsert_chunks(
collection: str,
chunks: list[dict],
store_path: str | Path | None = None,
) -> None:
"""
将 chunk 写入 Chroma。
同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。
"""
client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection)
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
for document_id in document_ids:
chroma_collection.delete(where={"document_id": document_id})
texts = [chunk["content"] for chunk in chunks]
embeddings = _embedding_provider().embed_texts(texts)
chroma_collection.upsert(
ids=[chunk["chunk_id"] for chunk in chunks],
documents=texts,
embeddings=embeddings,
metadatas=[
{
"scenario_id": chunk["scenario_id"],
"document_id": chunk["document_id"],
"source": chunk["source"],
"chunk_id": chunk["chunk_id"],
"created_at": chunk["created_at"],
}
for chunk in chunks
],
)
def query_chunks(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
"""执行向量检索,并把 Chroma 原始结果转换为统一引用结构。"""
client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection)
where: dict = {"scenario_id": scenario_id}
if document_ids:
where = {
"$and": [
{"scenario_id": scenario_id},
{"document_id": {"$in": document_ids}},
]
}
embedding = _embedding_provider().embed_texts([query])[0]
result = chroma_collection.query(
query_embeddings=[embedding],
n_results=top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
chunks = []
documents = result.get("documents", [[]])[0]
metadatas = result.get("metadatas", [[]])[0]
distances = result.get("distances", [[]])[0]
for content, metadata, distance in zip(documents, metadatas, distances):
chunks.append(
{
"scenario_id": metadata.get("scenario_id"),
"document_id": metadata.get("document_id"),
"collection": collection,
"source": metadata.get("source"),
"chunk_id": metadata.get("chunk_id"),
"content": content,
"created_at": metadata.get("created_at"),
"score": round(1 / (1 + float(distance)), 4),
}
)
return chunks