105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
from pathlib import Path
|
||
|
||
from django.conf import settings
|
||
|
||
from agent_core.llm_provider import create_embedding_provider
|
||
|
||
|
||
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client rooted at ``path``.

    When ``path`` is falsy (``None`` or an empty string), ``settings.CHROMA_PATH``
    is used instead. The resolved location is always passed to Chroma as a
    ``str``. The ``chromadb`` import is deferred until this function is called.
    """
    import chromadb

    if path:
        target = path
    else:
        target = settings.CHROMA_PATH
    return chromadb.PersistentClient(path=str(target))
|
||
|
||
|
||
def _embedding_provider():
    """Build the embedding provider from Django settings.

    All ``EMBEDDING_*`` settings are read here, so business code never has to
    look them up itself.
    """
    setting_names = ("EMBEDDING_API_KEY", "EMBEDDING_BASE_URL", "EMBEDDING_MODEL")
    provider_config = {name: getattr(settings, name) for name in setting_names}
    return create_embedding_provider(provider_config)
|
||
|
||
|
||
def upsert_chunks(
|
||
collection: str,
|
||
chunks: list[dict],
|
||
store_path: str | Path | None = None,
|
||
) -> None:
|
||
"""
|
||
将 chunk 写入 Chroma。
|
||
|
||
同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。
|
||
"""
|
||
client = _client(store_path)
|
||
chroma_collection = client.get_or_create_collection(collection)
|
||
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
||
for document_id in document_ids:
|
||
chroma_collection.delete(where={"document_id": document_id})
|
||
texts = [chunk["content"] for chunk in chunks]
|
||
embeddings = _embedding_provider().embed_texts(texts)
|
||
chroma_collection.upsert(
|
||
ids=[chunk["chunk_id"] for chunk in chunks],
|
||
documents=texts,
|
||
embeddings=embeddings,
|
||
metadatas=[
|
||
{
|
||
"scenario_id": chunk["scenario_id"],
|
||
"document_id": chunk["document_id"],
|
||
"source": chunk["source"],
|
||
"chunk_id": chunk["chunk_id"],
|
||
"created_at": chunk["created_at"],
|
||
}
|
||
for chunk in chunks
|
||
],
|
||
)
|
||
|
||
|
||
def query_chunks(
|
||
scenario_id: str,
|
||
query: str,
|
||
collection: str,
|
||
top_k: int = 5,
|
||
document_ids: list[int] | None = None,
|
||
store_path: str | Path | None = None,
|
||
) -> list[dict]:
|
||
"""执行向量检索,并把 Chroma 原始结果转换为统一引用结构。"""
|
||
client = _client(store_path)
|
||
chroma_collection = client.get_or_create_collection(collection)
|
||
where: dict = {"scenario_id": scenario_id}
|
||
if document_ids:
|
||
where = {
|
||
"$and": [
|
||
{"scenario_id": scenario_id},
|
||
{"document_id": {"$in": document_ids}},
|
||
]
|
||
}
|
||
embedding = _embedding_provider().embed_texts([query])[0]
|
||
result = chroma_collection.query(
|
||
query_embeddings=[embedding],
|
||
n_results=top_k,
|
||
where=where,
|
||
include=["documents", "metadatas", "distances"],
|
||
)
|
||
chunks = []
|
||
documents = result.get("documents", [[]])[0]
|
||
metadatas = result.get("metadatas", [[]])[0]
|
||
distances = result.get("distances", [[]])[0]
|
||
for content, metadata, distance in zip(documents, metadatas, distances):
|
||
chunks.append(
|
||
{
|
||
"scenario_id": metadata.get("scenario_id"),
|
||
"document_id": metadata.get("document_id"),
|
||
"collection": collection,
|
||
"source": metadata.get("source"),
|
||
"chunk_id": metadata.get("chunk_id"),
|
||
"content": content,
|
||
"created_at": metadata.get("created_at"),
|
||
"score": round(1 / (1 + float(distance)), 4),
|
||
}
|
||
)
|
||
return chunks
|