from pathlib import Path

from django.conf import settings

from agent_core.llm_provider import create_embedding_provider


def _client(path: str | Path | None = None):
    """Open a Chroma persistent client rooted at *path*.

    Falls back to ``settings.CHROMA_PATH`` when *path* is falsy (``None`` or an
    empty string). ``chromadb`` is imported inside the function, so importing
    this module does not import ``chromadb``.
    """
    import chromadb

    resolved_path = str(path or settings.CHROMA_PATH)
    return chromadb.PersistentClient(path=resolved_path)


def _embedding_provider():
    """Build the embedding provider from Django settings.

    Centralizes the settings lookup so business code never reads the embedding
    configuration (API key, base URL, model) on its own.
    """
    return create_embedding_provider(
        {
            "EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
            "EMBEDDING_BASE_URL": settings.EMBEDDING_BASE_URL,
            "EMBEDDING_MODEL": settings.EMBEDDING_MODEL,
        }
    )


def upsert_chunks(
    collection: str,
    chunks: list[dict],
    store_path: str | Path | None = None,
) -> None:
    """Write *chunks* into the Chroma collection named *collection*.

    Before writing, every stored record whose ``document_id`` matches one of
    the incoming chunks is deleted, so a re-ingested document keeps exactly one
    set of vectors.

    Args:
        collection: Chroma collection name; created if it does not exist.
        chunks: Chunk dicts. Each must provide ``chunk_id``, ``content``,
            ``scenario_id``, ``document_id``, ``source`` and ``created_at``.
            Chunks whose ``document_id`` is ``None`` are written but do not
            trigger deletion of older records.
        store_path: Optional Chroma storage directory; defaults to
            ``settings.CHROMA_PATH``.

    An empty *chunks* list is a no-op.
    """
    if not chunks:
        # Nothing to write: skip the embedding call and the empty-batch upsert.
        return

    client = _client(store_path)
    chroma_collection = client.get_or_create_collection(collection)

    texts = [chunk["content"] for chunk in chunks]
    # Embed BEFORE deleting anything: if the embedding call raises, the
    # previously stored vectors for these documents are left intact instead of
    # being removed with nothing to replace them.
    embeddings = _embedding_provider().embed_texts(texts)

    document_ids = {
        chunk["document_id"]
        for chunk in chunks
        if chunk.get("document_id") is not None
    }
    for document_id in document_ids:
        chroma_collection.delete(where={"document_id": document_id})

    chroma_collection.upsert(
        ids=[chunk["chunk_id"] for chunk in chunks],
        documents=texts,
        embeddings=embeddings,
        metadatas=[
            {
                "scenario_id": chunk["scenario_id"],
                "document_id": chunk["document_id"],
                "source": chunk["source"],
                "chunk_id": chunk["chunk_id"],
                "created_at": chunk["created_at"],
            }
            for chunk in chunks
        ],
    )


def query_chunks(
    scenario_id: str,
    query: str,
    collection: str,
    top_k: int = 5,
    document_ids: list[int] | None = None,
    store_path: str | Path | None = None,
) -> list[dict]:
    """Run a vector search and convert Chroma's raw result into citation dicts.

    Args:
        scenario_id: Only records whose ``scenario_id`` metadata equals this
            value are searched.
        query: Query text; it is embedded with the configured provider.
        collection: Chroma collection name; created if it does not exist.
        top_k: Number of results requested from Chroma (``n_results``).
        document_ids: If non-empty, results are further restricted to these
            ``document_id`` values. ``None`` or ``[]`` means no restriction.
        store_path: Optional Chroma storage directory; defaults to
            ``settings.CHROMA_PATH``.

    Returns:
        One dict per hit, in the order Chroma returns them, with keys
        ``scenario_id``, ``document_id``, ``collection``, ``source``,
        ``chunk_id``, ``content``, ``created_at`` and ``score``. ``score`` is
        ``1 / (1 + distance)`` rounded to 4 decimals, so a smaller distance
        yields a larger score.
    """
    client = _client(store_path)
    chroma_collection = client.get_or_create_collection(collection)

    where: dict = {"scenario_id": scenario_id}
    if document_ids:
        # Combine the scenario filter and the document filter under "$and".
        where = {
            "$and": [
                {"scenario_id": scenario_id},
                {"document_id": {"$in": document_ids}},
            ]
        }

    embedding = _embedding_provider().embed_texts([query])[0]
    result = chroma_collection.query(
        query_embeddings=[embedding],
        n_results=top_k,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    # Results are nested one list per query embedding; exactly one was sent.
    # `or [[]]` tolerates a missing/None/empty outer list.
    documents = (result.get("documents") or [[]])[0]
    metadatas = (result.get("metadatas") or [[]])[0]
    distances = (result.get("distances") or [[]])[0]

    chunks = []
    for content, metadata, distance in zip(documents, metadatas, distances):
        # Guard against a record stored without metadata.
        metadata = metadata or {}
        chunks.append(
            {
                "scenario_id": metadata.get("scenario_id"),
                "document_id": metadata.get("document_id"),
                "collection": collection,
                "source": metadata.get("source"),
                "chunk_id": metadata.get("chunk_id"),
                "content": content,
                "created_at": metadata.get("created_at"),
                "score": round(1 / (1 + float(distance)), 4),
            }
        )
    return chunks