refactor(rag): 梳理文档入库与检索服务结构
This commit is contained in:
@@ -6,6 +6,7 @@ from agent_core.llm_provider import create_embedding_provider
|
||||
|
||||
|
||||
def _client(path: str | Path | None = None):
|
||||
"""按给定路径初始化 Chroma 持久化客户端。"""
|
||||
import chromadb
|
||||
|
||||
resolved_path = str(path or settings.CHROMA_PATH)
|
||||
@@ -13,6 +14,7 @@ def _client(path: str | Path | None = None):
|
||||
|
||||
|
||||
def _embedding_provider():
|
||||
"""从 Django settings 构造 Embedding Provider,避免在业务层散落配置读取。"""
|
||||
return create_embedding_provider(
|
||||
{
|
||||
"EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
|
||||
@@ -27,6 +29,11 @@ def upsert_chunks(
|
||||
chunks: list[dict],
|
||||
store_path: str | Path | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
将 chunk 写入 Chroma。
|
||||
|
||||
同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。
|
||||
"""
|
||||
client = _client(store_path)
|
||||
chroma_collection = client.get_or_create_collection(collection)
|
||||
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
||||
@@ -59,6 +66,7 @@ def query_chunks(
|
||||
document_ids: list[int] | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> list[dict]:
|
||||
"""执行向量检索,并把 Chroma 原始结果转换为统一引用结构。"""
|
||||
client = _client(store_path)
|
||||
chroma_collection = client.get_or_create_collection(collection)
|
||||
where: dict = {"scenario_id": scenario_id}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import re
|
||||
import importlib.util
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -12,11 +12,61 @@ from .chroma_store import upsert_chunks
|
||||
|
||||
@dataclass
|
||||
class IngestResult:
|
||||
"""RAG 入库统一返回结构,供 Documents 模块稳定消费。"""
|
||||
success: bool
|
||||
chunks_count: int = 0
|
||||
error: str = ""
|
||||
|
||||
|
||||
def ingest_document(
|
||||
scenario_id: str,
|
||||
source_file: str,
|
||||
text: str,
|
||||
collection: str,
|
||||
document_id: int | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> IngestResult:
|
||||
"""
|
||||
将单个文档文本切分后写入知识库。
|
||||
|
||||
运行策略:
|
||||
- 如果显式传入 `store_path`,说明当前是测试或降级模式,走本地 JSON 存储。
|
||||
- 如果未传入且环境可用 chromadb,则走真实 Chroma 持久化。
|
||||
"""
|
||||
if not text.strip():
|
||||
return IngestResult(success=False, error="文档内容为空")
|
||||
if _should_use_chroma(store_path):
|
||||
return _ingest_chroma_document(
|
||||
document_id=document_id,
|
||||
scenario_id=scenario_id,
|
||||
source_file=source_file,
|
||||
text=text,
|
||||
collection=collection,
|
||||
)
|
||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||
chunks = _build_chunks(
|
||||
scenario_id=scenario_id,
|
||||
source_file=source_file,
|
||||
text=text,
|
||||
collection=collection,
|
||||
document_id=document_id,
|
||||
chunk_id_prefix=source_file,
|
||||
)
|
||||
persisted_chunks = _filter_out_same_document_chunks(
|
||||
_load_store(resolved_store_path),
|
||||
scenario_id=scenario_id,
|
||||
collection=collection,
|
||||
document_id=document_id,
|
||||
)
|
||||
_save_store(resolved_store_path, [*persisted_chunks, *chunks])
|
||||
return IngestResult(success=True, chunks_count=len(chunks))
|
||||
|
||||
|
||||
def _should_use_chroma(store_path: str | Path | None) -> bool:
|
||||
"""只在未指定测试存储路径且安装 chromadb 时启用真实向量库。"""
|
||||
return store_path is None and importlib.util.find_spec("chromadb") is not None
|
||||
|
||||
|
||||
def _default_store_path() -> Path:
|
||||
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
||||
|
||||
@@ -35,6 +85,13 @@ def _save_store(store_path: Path, chunks: list[dict]) -> None:
|
||||
|
||||
|
||||
def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]:
|
||||
"""
|
||||
使用固定窗口 + overlap 切分文本。
|
||||
|
||||
该策略简单但稳定,便于解释:
|
||||
- chunk_size 控制每个片段最大长度
|
||||
- overlap 保证相邻片段共享上下文,降低边界信息丢失
|
||||
"""
|
||||
normalized = re.sub(r"\s+", " ", text).strip()
|
||||
if not normalized:
|
||||
return []
|
||||
@@ -49,44 +106,46 @@ def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[st
|
||||
return chunks
|
||||
|
||||
|
||||
def ingest_document(
|
||||
def _build_chunks(
|
||||
scenario_id: str,
|
||||
source_file: str,
|
||||
text: str,
|
||||
collection: str,
|
||||
document_id: int | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> IngestResult:
|
||||
if not text.strip():
|
||||
return IngestResult(success=False, error="文档内容为空")
|
||||
if store_path is None and importlib.util.find_spec("chromadb") is not None:
|
||||
return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection)
|
||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||
existing_chunks = [
|
||||
document_id: int | None,
|
||||
chunk_id_prefix: str,
|
||||
) -> list[dict]:
|
||||
"""把原始文本切分并封装为统一 chunk 结构。"""
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
return [
|
||||
{
|
||||
"scenario_id": scenario_id,
|
||||
"document_id": document_id,
|
||||
"collection": collection,
|
||||
"source": source_file,
|
||||
"chunk_id": f"{scenario_id}:{chunk_id_prefix}:{index}",
|
||||
"content": chunk_text,
|
||||
"created_at": created_at,
|
||||
}
|
||||
for index, chunk_text in enumerate(_split_text(text), start=1)
|
||||
]
|
||||
|
||||
|
||||
def _filter_out_same_document_chunks(
|
||||
chunks: list[dict],
|
||||
scenario_id: str,
|
||||
collection: str,
|
||||
document_id: int | None,
|
||||
) -> list[dict]:
|
||||
"""重新入库同一 document_id 时,先删除旧 chunk,避免重复检索。"""
|
||||
return [
|
||||
chunk
|
||||
for chunk in _load_store(resolved_store_path)
|
||||
for chunk in chunks
|
||||
if not (
|
||||
chunk.get("document_id") == document_id
|
||||
and chunk.get("scenario_id") == scenario_id
|
||||
and chunk.get("collection") == collection
|
||||
)
|
||||
]
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
new_chunks = []
|
||||
for index, chunk_text in enumerate(_split_text(text), start=1):
|
||||
new_chunks.append(
|
||||
{
|
||||
"scenario_id": scenario_id,
|
||||
"document_id": document_id,
|
||||
"collection": collection,
|
||||
"source": source_file,
|
||||
"chunk_id": f"{scenario_id}:{source_file}:{index}",
|
||||
"content": chunk_text,
|
||||
"created_at": created_at,
|
||||
}
|
||||
)
|
||||
_save_store(resolved_store_path, [*existing_chunks, *new_chunks])
|
||||
return IngestResult(success=True, chunks_count=len(new_chunks))
|
||||
|
||||
|
||||
def _ingest_chroma_document(
|
||||
@@ -96,19 +155,15 @@ def _ingest_chroma_document(
|
||||
text: str,
|
||||
collection: str,
|
||||
) -> IngestResult:
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
chunks = [
|
||||
{
|
||||
"scenario_id": scenario_id,
|
||||
"document_id": document_id,
|
||||
"collection": collection,
|
||||
"source": source_file,
|
||||
"chunk_id": f"{scenario_id}:{document_id or source_file}:{index}",
|
||||
"content": chunk_text,
|
||||
"created_at": created_at,
|
||||
}
|
||||
for index, chunk_text in enumerate(_split_text(text), start=1)
|
||||
]
|
||||
"""真实 Chroma 模式的入库分支。"""
|
||||
chunks = _build_chunks(
|
||||
scenario_id=scenario_id,
|
||||
source_file=source_file,
|
||||
text=text,
|
||||
collection=collection,
|
||||
document_id=document_id,
|
||||
chunk_id_prefix=str(document_id or source_file),
|
||||
)
|
||||
try:
|
||||
upsert_chunks(collection=collection, chunks=chunks)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import re
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
@@ -8,6 +8,52 @@ from django.conf import settings
|
||||
from .chroma_store import query_chunks
|
||||
|
||||
|
||||
def retrieve(
|
||||
scenario_id: str,
|
||||
query: str,
|
||||
collection: str,
|
||||
top_k: int = 5,
|
||||
document_ids: list[int] | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
统一对外提供检索入口。
|
||||
|
||||
与 ingest_document 保持一致:
|
||||
- 真实运行优先走 Chroma
|
||||
- 测试或降级模式走本地 JSON + 轻量文本打分
|
||||
"""
|
||||
if _should_use_chroma(store_path):
|
||||
return query_chunks(
|
||||
scenario_id=scenario_id,
|
||||
query=query,
|
||||
collection=collection,
|
||||
top_k=top_k,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||
query_tokens = _tokens(query)
|
||||
allowed_document_ids = set(document_ids or [])
|
||||
scored_chunks = []
|
||||
for chunk in _load_store(resolved_store_path):
|
||||
if not _matches_scope(
|
||||
chunk=chunk,
|
||||
scenario_id=scenario_id,
|
||||
collection=collection,
|
||||
allowed_document_ids=allowed_document_ids,
|
||||
):
|
||||
continue
|
||||
score = _score(query_tokens, chunk.get("content", ""))
|
||||
if score <= 0:
|
||||
continue
|
||||
scored_chunks.append({**chunk, "score": score})
|
||||
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
|
||||
|
||||
|
||||
def _should_use_chroma(store_path: str | Path | None) -> bool:
|
||||
return store_path is None and importlib.util.find_spec("chromadb") is not None
|
||||
|
||||
|
||||
def _default_store_path() -> Path:
|
||||
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
||||
|
||||
@@ -19,7 +65,30 @@ def _load_store(store_path: Path) -> list[dict]:
|
||||
return json.load(file)
|
||||
|
||||
|
||||
def _matches_scope(
|
||||
chunk: dict,
|
||||
scenario_id: str,
|
||||
collection: str,
|
||||
allowed_document_ids: set[int],
|
||||
) -> bool:
|
||||
"""先按场景、collection 和可选文档范围过滤,再进行相关性打分。"""
|
||||
if chunk.get("scenario_id") != scenario_id:
|
||||
return False
|
||||
if chunk.get("collection") != collection:
|
||||
return False
|
||||
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _tokens(text: str) -> set[str]:
|
||||
"""
|
||||
兼容中英文的轻量分词策略。
|
||||
|
||||
该分词仅用于 fallback 模式,不替代真实向量检索:
|
||||
- 英文/数字按词提取
|
||||
- 中文按连续词片段和单字同时保留,提升短查询命中率
|
||||
"""
|
||||
lowered = text.lower()
|
||||
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
|
||||
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
|
||||
@@ -28,42 +97,9 @@ def _tokens(text: str) -> set[str]:
|
||||
|
||||
|
||||
def _score(query_tokens: set[str], content: str) -> float:
|
||||
"""使用交集占比计算一个便于排序的简化相关性分数。"""
|
||||
content_tokens = _tokens(content)
|
||||
if not query_tokens or not content_tokens:
|
||||
return 0.0
|
||||
overlap = query_tokens & content_tokens
|
||||
return round(len(overlap) / len(query_tokens), 4)
|
||||
|
||||
|
||||
def retrieve(
|
||||
scenario_id: str,
|
||||
query: str,
|
||||
collection: str,
|
||||
top_k: int = 5,
|
||||
document_ids: list[int] | None = None,
|
||||
store_path: str | Path | None = None,
|
||||
) -> list[dict]:
|
||||
if store_path is None and importlib.util.find_spec("chromadb") is not None:
|
||||
return query_chunks(
|
||||
scenario_id=scenario_id,
|
||||
query=query,
|
||||
collection=collection,
|
||||
top_k=top_k,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||
query_tokens = _tokens(query)
|
||||
allowed_document_ids = set(document_ids or [])
|
||||
scored_chunks = []
|
||||
for chunk in _load_store(resolved_store_path):
|
||||
if chunk.get("scenario_id") != scenario_id:
|
||||
continue
|
||||
if chunk.get("collection") != collection:
|
||||
continue
|
||||
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
|
||||
continue
|
||||
score = _score(query_tokens, chunk.get("content", ""))
|
||||
if score <= 0:
|
||||
continue
|
||||
scored_chunks.append({**chunk, "score": score})
|
||||
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
|
||||
|
||||
Reference in New Issue
Block a user