feat(agent-core): 增加智能编排与模型工具基础
This commit is contained in:
96
agent_core/rag/chroma_store.py
Normal file
96
agent_core/rag/chroma_store.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from agent_core.llm_provider import create_embedding_provider
|
||||
|
||||
|
||||
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client.

    Uses ``path`` when it is truthy and falls back to
    ``settings.CHROMA_PATH`` otherwise. The chosen location is converted to
    ``str`` before it is handed to ``chromadb.PersistentClient``.
    """
    # Imported inside the function, so chromadb is only loaded when a client
    # is actually requested.
    import chromadb

    location = path if path else settings.CHROMA_PATH
    return chromadb.PersistentClient(path=str(location))
|
||||
|
||||
|
||||
def _embedding_provider():
    """Create an embedding provider from the ``EMBEDDING_*`` Django settings.

    Reads ``EMBEDDING_API_KEY``, ``EMBEDDING_BASE_URL`` and
    ``EMBEDDING_MODEL`` from ``settings`` and passes them, keyed by the same
    names, to ``create_embedding_provider``.
    """
    setting_names = (
        "EMBEDDING_API_KEY",
        "EMBEDDING_BASE_URL",
        "EMBEDDING_MODEL",
    )
    provider_config = {name: getattr(settings, name) for name in setting_names}
    return create_embedding_provider(provider_config)
|
||||
|
||||
|
||||
def upsert_chunks(
|
||||
collection: str,
|
||||
chunks: list[dict],
|
||||
store_path: str | Path | None = None,
|
||||
) -> None:
|
||||
client = _client(store_path)
|
||||
chroma_collection = client.get_or_create_collection(collection)
|
||||
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
||||
for document_id in document_ids:
|
||||
chroma_collection.delete(where={"document_id": document_id})
|
||||
texts = [chunk["content"] for chunk in chunks]
|
||||
embeddings = _embedding_provider().embed_texts(texts)
|
||||
chroma_collection.upsert(
|
||||
ids=[chunk["chunk_id"] for chunk in chunks],
|
||||
documents=texts,
|
||||
embeddings=embeddings,
|
||||
metadatas=[
|
||||
{
|
||||
"scenario_id": chunk["scenario_id"],
|
||||
"document_id": chunk["document_id"],
|
||||
"source": chunk["source"],
|
||||
"chunk_id": chunk["chunk_id"],
|
||||
"created_at": chunk["created_at"],
|
||||
}
|
||||
for chunk in chunks
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def query_chunks(
|
||||
scenario_id: str,
|
||||
query: str,
|
||||
collection: str,
|
||||
top_k: int = 5,
|
||||
document_ids: list[int] | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> list[dict]:
|
||||
client = _client(store_path)
|
||||
chroma_collection = client.get_or_create_collection(collection)
|
||||
where: dict = {"scenario_id": scenario_id}
|
||||
if document_ids:
|
||||
where = {
|
||||
"$and": [
|
||||
{"scenario_id": scenario_id},
|
||||
{"document_id": {"$in": document_ids}},
|
||||
]
|
||||
}
|
||||
embedding = _embedding_provider().embed_texts([query])[0]
|
||||
result = chroma_collection.query(
|
||||
query_embeddings=[embedding],
|
||||
n_results=top_k,
|
||||
where=where,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
chunks = []
|
||||
documents = result.get("documents", [[]])[0]
|
||||
metadatas = result.get("metadatas", [[]])[0]
|
||||
distances = result.get("distances", [[]])[0]
|
||||
for content, metadata, distance in zip(documents, metadatas, distances):
|
||||
chunks.append(
|
||||
{
|
||||
"scenario_id": metadata.get("scenario_id"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"collection": collection,
|
||||
"source": metadata.get("source"),
|
||||
"chunk_id": metadata.get("chunk_id"),
|
||||
"content": content,
|
||||
"created_at": metadata.get("created_at"),
|
||||
"score": round(1 / (1 + float(distance)), 4),
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
Reference in New Issue
Block a user