import importlib.util
import json
import re
from pathlib import Path

from django.conf import settings

from .chroma_store import query_chunks

# Patterns used by the fallback tokenizer; compiled once because `_tokens`
# runs for every candidate chunk on every fallback query.
_ASCII_TOKEN_RE = re.compile(r"[a-z0-9_]+")
_CJK_RUN_RE = re.compile(r"[\u4e00-\u9fff]{2,}")


def retrieve(
    scenario_id: str,
    query: str,
    collection: str,
    top_k: int = 5,
    document_ids: list[int] | None = None,
    store_path: str | Path | None = None,
) -> list[dict]:
    """Single public entry point for retrieval.

    Two modes are supported:

    - Chroma mode: used when ``store_path`` is None and the ``chromadb``
      package can be found; the call is delegated to ``query_chunks``
      unchanged.
    - Fallback mode (tests / degraded operation): chunks are read from a
      local JSON file and ranked with a lightweight token-overlap score.

    Args:
        scenario_id: Only chunks whose ``scenario_id`` equals this are kept.
        query: Free-text query.
        collection: Only chunks whose ``collection`` equals this are kept.
        top_k: Maximum number of results. In fallback mode a non-positive
            value yields an empty list.
        document_ids: Optional restriction to these document ids. ``None``
            or an empty list means "no restriction" in fallback mode.
        store_path: Path of the JSON store. Passing any non-None value forces
            fallback mode; an empty value falls back to the default path
            under ``settings.CHROMA_PATH``.

    Returns:
        In fallback mode, copies of the matching chunks with an added
        ``"score"`` key, sorted by descending score (ties keep store order),
        at most ``top_k`` long. In Chroma mode, whatever ``query_chunks``
        returns.
    """
    if _should_use_chroma(store_path):
        return query_chunks(
            scenario_id=scenario_id,
            query=query,
            collection=collection,
            top_k=top_k,
            document_ids=document_ids,
        )

    # A negative ``top_k`` would make the final slice drop results from the
    # tail instead of returning nothing, so reject non-positive limits early.
    if top_k <= 0:
        return []

    query_tokens = _tokens(query)
    if not query_tokens:
        # Nothing can score above zero; skip loading and tokenizing the store.
        return []

    resolved_store_path = Path(store_path) if store_path else _default_store_path()
    allowed_document_ids = set(document_ids or [])

    scored_chunks = []
    for chunk in _load_store(resolved_store_path):
        if not _matches_scope(
            chunk=chunk,
            scenario_id=scenario_id,
            collection=collection,
            allowed_document_ids=allowed_document_ids,
        ):
            continue

        # ``or ""`` also covers a chunk stored with ``"content": null``.
        score = _score(query_tokens, chunk.get("content") or "")
        if score > 0:
            scored_chunks.append({**chunk, "score": score})

    # list.sort is stable, so equally scored chunks keep their store order.
    scored_chunks.sort(key=lambda item: item["score"], reverse=True)
    return scored_chunks[:top_k]


def _should_use_chroma(store_path: str | Path | None) -> bool:
    """Return True when no explicit store path is given and chromadb is importable."""
    return store_path is None and importlib.util.find_spec("chromadb") is not None


def _default_store_path() -> Path:
    """Return the default JSON store location under ``settings.CHROMA_PATH``."""
    return Path(settings.CHROMA_PATH) / "rag_store.json"


def _load_store(store_path: Path) -> list[dict]:
    """Load the JSON store, returning an empty list when the file is missing.

    Any other failure (unreadable file, invalid JSON) propagates to the
    caller.
    """
    # EAFP: open directly instead of exists()+open(), which can race with a
    # concurrent delete.
    try:
        with store_path.open("r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        return []


def _matches_scope(
    chunk: dict,
    scenario_id: str,
    collection: str,
    allowed_document_ids: set[int],
) -> bool:
    """Filter by scenario, collection and optional document scope before scoring.

    An empty ``allowed_document_ids`` disables the document filter.
    """
    if chunk.get("scenario_id") != scenario_id:
        return False
    if chunk.get("collection") != collection:
        return False
    if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
        return False
    return True


def _tokens(text: str) -> set[str]:
    """Lightweight tokenizer that handles both English and Chinese text.

    Used only in fallback mode; it does not replace real vector retrieval.

    - ASCII letters, digits and underscores are extracted as whole words
      (after lowercasing).
    - Each run of two or more consecutive CJK characters is kept as one
      token, and every individual CJK character is kept as well, which
      raises the hit rate for short queries.
    """
    lowered = text.lower()
    ascii_tokens = set(_ASCII_TOKEN_RE.findall(lowered))
    cjk_tokens = set(_CJK_RUN_RE.findall(lowered))
    chars = {char for char in lowered if "\u4e00" <= char <= "\u9fff"}
    return ascii_tokens | cjk_tokens | chars


def _score(query_tokens: set[str], content: str) -> float:
    """Return the fraction of query tokens found in ``content``.

    The result is rounded to four decimals and is ``0.0`` when either side
    has no tokens. It is a simplified relevance score meant only for ranking.
    """
    if not query_tokens:
        # Check the query first so the content is not tokenized needlessly.
        return 0.0
    content_tokens = _tokens(content)
    if not content_tokens:
        return 0.0
    overlap = query_tokens & content_tokens
    return round(len(overlap) / len(query_tokens), 4)