refactor(rag): 梳理文档入库与检索服务结构

This commit is contained in:
2026-05-30 00:44:52 +08:00
parent f68b44f325
commit ccfe5eb667
6 changed files with 284 additions and 103 deletions

View File

@@ -1,6 +1,6 @@
import importlib.util
import json
import re
import importlib.util
from pathlib import Path
from django.conf import settings
@@ -8,6 +8,52 @@ from django.conf import settings
from .chroma_store import query_chunks
def retrieve(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
"""
统一对外提供检索入口。
与 ingest_document 保持一致:
- 真实运行优先走 Chroma
- 测试或降级模式走本地 JSON + 轻量文本打分
"""
if _should_use_chroma(store_path):
return query_chunks(
scenario_id=scenario_id,
query=query,
collection=collection,
top_k=top_k,
document_ids=document_ids,
)
resolved_store_path = Path(store_path) if store_path else _default_store_path()
query_tokens = _tokens(query)
allowed_document_ids = set(document_ids or [])
scored_chunks = []
for chunk in _load_store(resolved_store_path):
if not _matches_scope(
chunk=chunk,
scenario_id=scenario_id,
collection=collection,
allowed_document_ids=allowed_document_ids,
):
continue
score = _score(query_tokens, chunk.get("content", ""))
if score <= 0:
continue
scored_chunks.append({**chunk, "score": score})
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
def _should_use_chroma(store_path: str | Path | None) -> bool:
return store_path is None and importlib.util.find_spec("chromadb") is not None
def _default_store_path() -> Path:
return Path(settings.CHROMA_PATH) / "rag_store.json"
@@ -19,7 +65,30 @@ def _load_store(store_path: Path) -> list[dict]:
return json.load(file)
def _matches_scope(
chunk: dict,
scenario_id: str,
collection: str,
allowed_document_ids: set[int],
) -> bool:
"""先按场景、collection 和可选文档范围过滤,再进行相关性打分。"""
if chunk.get("scenario_id") != scenario_id:
return False
if chunk.get("collection") != collection:
return False
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
return False
return True
def _tokens(text: str) -> set[str]:
"""
兼容中英文的轻量分词策略。
该分词仅用于 fallback 模式,不替代真实向量检索:
- 英文/数字按词提取
- 中文按连续词片段和单字同时保留,提升短查询命中率
"""
lowered = text.lower()
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
@@ -28,42 +97,9 @@ def _tokens(text: str) -> set[str]:
def _score(query_tokens: set[str], content: str) -> float:
"""使用交集占比计算一个便于排序的简化相关性分数。"""
content_tokens = _tokens(content)
if not query_tokens or not content_tokens:
return 0.0
overlap = query_tokens & content_tokens
return round(len(overlap) / len(query_tokens), 4)
def retrieve(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
if store_path is None and importlib.util.find_spec("chromadb") is not None:
return query_chunks(
scenario_id=scenario_id,
query=query,
collection=collection,
top_k=top_k,
document_ids=document_ids,
)
resolved_store_path = Path(store_path) if store_path else _default_store_path()
query_tokens = _tokens(query)
allowed_document_ids = set(document_ids or [])
scored_chunks = []
for chunk in _load_store(resolved_store_path):
if chunk.get("scenario_id") != scenario_id:
continue
if chunk.get("collection") != collection:
continue
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
continue
score = _score(query_tokens, chunk.get("content", ""))
if score <= 0:
continue
scored_chunks.append({**chunk, "score": score})
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]