70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
import json
|
|
import re
|
|
import importlib.util
|
|
from pathlib import Path
|
|
|
|
from django.conf import settings
|
|
|
|
from .chroma_store import query_chunks
|
|
|
|
|
|
def _default_store_path() -> Path:
    """Return the fallback store file: ``rag_store.json`` under ``settings.CHROMA_PATH``."""
    base_dir = Path(settings.CHROMA_PATH)
    return base_dir.joinpath("rag_store.json")
|
|
|
|
|
|
def _load_store(store_path: Path) -> list[dict]:
|
|
if not store_path.exists():
|
|
return []
|
|
with store_path.open("r", encoding="utf-8") as file:
|
|
return json.load(file)
|
|
|
|
|
|
def _tokens(text: str) -> set[str]:
|
|
lowered = text.lower()
|
|
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
|
|
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
|
|
chars = {char for char in lowered if "\u4e00" <= char <= "\u9fff"}
|
|
return ascii_tokens | cjk_tokens | chars
|
|
|
|
|
|
def _score(query_tokens: set[str], content: str) -> float:
    """Return the fraction of *query_tokens* that also occur in *content*.

    The content is tokenized with ``_tokens``. The result is rounded to four
    decimal places, and is ``0.0`` when either token set is empty.
    """
    candidate_tokens = _tokens(content)
    if query_tokens and candidate_tokens:
        shared = query_tokens & candidate_tokens
        ratio = len(shared) / len(query_tokens)
        return round(ratio, 4)
    return 0.0
|
|
|
|
|
|
def retrieve(
    scenario_id: str,
    query: str,
    collection: str,
    top_k: int = 5,
    document_ids: list[int] | None = None,
    store_path: str | Path | None = None,
) -> list[dict]:
    """Return the stored chunks that best match *query*.

    When ``store_path`` is ``None`` and the ``chromadb`` package is
    importable, the lookup is delegated to ``query_chunks`` and its result is
    returned unchanged. Otherwise the JSON store is scanned and chunks are
    ranked by token overlap with the query.

    Args:
        scenario_id: Only chunks whose ``"scenario_id"`` equals this are kept.
        query: Free text that is tokenized and matched against chunk content.
        collection: Only chunks whose ``"collection"`` equals this are kept.
        top_k: Maximum number of chunks to return. In the JSON fallback a
            value of zero or less yields an empty list.
        document_ids: Optional allow-list matched against ``"document_id"``.
            ``None`` or an empty list disables this filter.
        store_path: Explicit JSON store file. Passing any non-``None`` value
            skips the ``chromadb`` branch; an empty value falls back to the
            default store path.

    Returns:
        In the JSON fallback: copies of the matching chunk dicts with an
        added ``"score"`` key, ordered by descending score and truncated to
        ``top_k``. Chunks with a score of zero are omitted.
    """
    if store_path is None and importlib.util.find_spec("chromadb") is not None:
        return query_chunks(
            scenario_id=scenario_id,
            query=query,
            collection=collection,
            top_k=top_k,
            document_ids=document_ids,
        )

    # A negative limit would make the final slice drop results from the end
    # instead of returning none. An untyped ``None`` keeps plain-slice
    # semantics (no limit).
    if top_k is not None and top_k <= 0:
        return []

    resolved_store_path = Path(store_path) if store_path else _default_store_path()
    query_tokens = _tokens(query)
    allowed_document_ids = set(document_ids or [])

    scored_chunks = []
    for chunk in _load_store(resolved_store_path):
        if chunk.get("scenario_id") != scenario_id:
            continue
        if chunk.get("collection") != collection:
            continue
        if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
            continue
        # ``or ""`` guards against a stored null content, which would
        # otherwise reach the tokenizer as ``None`` and raise.
        score = _score(query_tokens, chunk.get("content") or "")
        if score <= 0:
            continue
        scored_chunks.append({**chunk, "score": score})

    # The sort is stable, so equally scored chunks keep their store order.
    scored_chunks.sort(key=lambda item: item["score"], reverse=True)
    return scored_chunks[:top_k]
|