refactor(rag): 梳理文档入库与检索服务结构

This commit is contained in:
2026-05-30 00:44:52 +08:00
parent f68b44f325
commit ccfe5eb667
6 changed files with 284 additions and 103 deletions

View File

@@ -6,6 +6,7 @@ from agent_core.llm_provider import create_embedding_provider
def _client(path: str | Path | None = None): def _client(path: str | Path | None = None):
"""按给定路径初始化 Chroma 持久化客户端。"""
import chromadb import chromadb
resolved_path = str(path or settings.CHROMA_PATH) resolved_path = str(path or settings.CHROMA_PATH)
@@ -13,6 +14,7 @@ def _client(path: str | Path | None = None):
def _embedding_provider(): def _embedding_provider():
"""从 Django settings 构造 Embedding Provider避免在业务层散落配置读取。"""
return create_embedding_provider( return create_embedding_provider(
{ {
"EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY, "EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
@@ -27,6 +29,11 @@ def upsert_chunks(
chunks: list[dict], chunks: list[dict],
store_path: str | Path | None = None, store_path: str | Path | None = None,
) -> None: ) -> None:
"""
将 chunk 写入 Chroma。
同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。
"""
client = _client(store_path) client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection) chroma_collection = client.get_or_create_collection(collection)
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None} document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
@@ -59,6 +66,7 @@ def query_chunks(
document_ids: list[int] | None = None, document_ids: list[int] | None = None,
store_path: str | Path | None = None, store_path: str | Path | None = None,
) -> list[dict]: ) -> list[dict]:
"""执行向量检索,并把 Chroma 原始结果转换为统一引用结构。"""
client = _client(store_path) client = _client(store_path)
chroma_collection = client.get_or_create_collection(collection) chroma_collection = client.get_or_create_collection(collection)
where: dict = {"scenario_id": scenario_id} where: dict = {"scenario_id": scenario_id}

View File

@@ -1,6 +1,6 @@
import importlib.util
import json import json
import re import re
import importlib.util
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@@ -12,11 +12,61 @@ from .chroma_store import upsert_chunks
@dataclass @dataclass
class IngestResult: class IngestResult:
"""RAG 入库统一返回结构,供 Documents 模块稳定消费。"""
success: bool success: bool
chunks_count: int = 0 chunks_count: int = 0
error: str = "" error: str = ""
def ingest_document(
scenario_id: str,
source_file: str,
text: str,
collection: str,
document_id: int | None = None,
store_path: str | Path | None = None,
) -> IngestResult:
"""
将单个文档文本切分后写入知识库。
运行策略:
- 如果显式传入 `store_path`,说明当前是测试或降级模式,走本地 JSON 存储。
- 如果未传入且环境可用 chromadb则走真实 Chroma 持久化。
"""
if not text.strip():
return IngestResult(success=False, error="文档内容为空")
if _should_use_chroma(store_path):
return _ingest_chroma_document(
document_id=document_id,
scenario_id=scenario_id,
source_file=source_file,
text=text,
collection=collection,
)
resolved_store_path = Path(store_path) if store_path else _default_store_path()
chunks = _build_chunks(
scenario_id=scenario_id,
source_file=source_file,
text=text,
collection=collection,
document_id=document_id,
chunk_id_prefix=source_file,
)
persisted_chunks = _filter_out_same_document_chunks(
_load_store(resolved_store_path),
scenario_id=scenario_id,
collection=collection,
document_id=document_id,
)
_save_store(resolved_store_path, [*persisted_chunks, *chunks])
return IngestResult(success=True, chunks_count=len(chunks))
def _should_use_chroma(store_path: str | Path | None) -> bool:
"""只在未指定测试存储路径且安装 chromadb 时启用真实向量库。"""
return store_path is None and importlib.util.find_spec("chromadb") is not None
def _default_store_path() -> Path: def _default_store_path() -> Path:
return Path(settings.CHROMA_PATH) / "rag_store.json" return Path(settings.CHROMA_PATH) / "rag_store.json"
@@ -35,6 +85,13 @@ def _save_store(store_path: Path, chunks: list[dict]) -> None:
def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]: def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]:
"""
使用固定窗口 + overlap 切分文本。
该策略简单但稳定,便于解释:
- chunk_size 控制每个片段最大长度
- overlap 保证相邻片段共享上下文,降低边界信息丢失
"""
normalized = re.sub(r"\s+", " ", text).strip() normalized = re.sub(r"\s+", " ", text).strip()
if not normalized: if not normalized:
return [] return []
@@ -49,44 +106,46 @@ def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[st
return chunks return chunks
def ingest_document( def _build_chunks(
scenario_id: str, scenario_id: str,
source_file: str, source_file: str,
text: str, text: str,
collection: str, collection: str,
document_id: int | None = None, document_id: int | None,
store_path: str | Path | None = None, chunk_id_prefix: str,
) -> IngestResult: ) -> list[dict]:
if not text.strip(): """把原始文本切分并封装为统一 chunk 结构。"""
return IngestResult(success=False, error="文档内容为空") created_at = datetime.now(timezone.utc).isoformat()
if store_path is None and importlib.util.find_spec("chromadb") is not None: return [
return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection) {
resolved_store_path = Path(store_path) if store_path else _default_store_path() "scenario_id": scenario_id,
existing_chunks = [ "document_id": document_id,
"collection": collection,
"source": source_file,
"chunk_id": f"{scenario_id}:{chunk_id_prefix}:{index}",
"content": chunk_text,
"created_at": created_at,
}
for index, chunk_text in enumerate(_split_text(text), start=1)
]
def _filter_out_same_document_chunks(
chunks: list[dict],
scenario_id: str,
collection: str,
document_id: int | None,
) -> list[dict]:
"""重新入库同一 document_id 时,先删除旧 chunk避免重复检索。"""
return [
chunk chunk
for chunk in _load_store(resolved_store_path) for chunk in chunks
if not ( if not (
chunk.get("document_id") == document_id chunk.get("document_id") == document_id
and chunk.get("scenario_id") == scenario_id and chunk.get("scenario_id") == scenario_id
and chunk.get("collection") == collection and chunk.get("collection") == collection
) )
] ]
created_at = datetime.now(timezone.utc).isoformat()
new_chunks = []
for index, chunk_text in enumerate(_split_text(text), start=1):
new_chunks.append(
{
"scenario_id": scenario_id,
"document_id": document_id,
"collection": collection,
"source": source_file,
"chunk_id": f"{scenario_id}:{source_file}:{index}",
"content": chunk_text,
"created_at": created_at,
}
)
_save_store(resolved_store_path, [*existing_chunks, *new_chunks])
return IngestResult(success=True, chunks_count=len(new_chunks))
def _ingest_chroma_document( def _ingest_chroma_document(
@@ -96,19 +155,15 @@ def _ingest_chroma_document(
text: str, text: str,
collection: str, collection: str,
) -> IngestResult: ) -> IngestResult:
created_at = datetime.now(timezone.utc).isoformat() """真实 Chroma 模式的入库分支。"""
chunks = [ chunks = _build_chunks(
{ scenario_id=scenario_id,
"scenario_id": scenario_id, source_file=source_file,
"document_id": document_id, text=text,
"collection": collection, collection=collection,
"source": source_file, document_id=document_id,
"chunk_id": f"{scenario_id}:{document_id or source_file}:{index}", chunk_id_prefix=str(document_id or source_file),
"content": chunk_text, )
"created_at": created_at,
}
for index, chunk_text in enumerate(_split_text(text), start=1)
]
try: try:
upsert_chunks(collection=collection, chunks=chunks) upsert_chunks(collection=collection, chunks=chunks)
except Exception as exc: except Exception as exc:

View File

@@ -1,6 +1,6 @@
import importlib.util
import json import json
import re import re
import importlib.util
from pathlib import Path from pathlib import Path
from django.conf import settings from django.conf import settings
@@ -8,6 +8,52 @@ from django.conf import settings
from .chroma_store import query_chunks from .chroma_store import query_chunks
def retrieve(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
"""
统一对外提供检索入口。
与 ingest_document 保持一致:
- 真实运行优先走 Chroma
- 测试或降级模式走本地 JSON + 轻量文本打分
"""
if _should_use_chroma(store_path):
return query_chunks(
scenario_id=scenario_id,
query=query,
collection=collection,
top_k=top_k,
document_ids=document_ids,
)
resolved_store_path = Path(store_path) if store_path else _default_store_path()
query_tokens = _tokens(query)
allowed_document_ids = set(document_ids or [])
scored_chunks = []
for chunk in _load_store(resolved_store_path):
if not _matches_scope(
chunk=chunk,
scenario_id=scenario_id,
collection=collection,
allowed_document_ids=allowed_document_ids,
):
continue
score = _score(query_tokens, chunk.get("content", ""))
if score <= 0:
continue
scored_chunks.append({**chunk, "score": score})
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
def _should_use_chroma(store_path: str | Path | None) -> bool:
return store_path is None and importlib.util.find_spec("chromadb") is not None
def _default_store_path() -> Path: def _default_store_path() -> Path:
return Path(settings.CHROMA_PATH) / "rag_store.json" return Path(settings.CHROMA_PATH) / "rag_store.json"
@@ -19,7 +65,30 @@ def _load_store(store_path: Path) -> list[dict]:
return json.load(file) return json.load(file)
def _matches_scope(
chunk: dict,
scenario_id: str,
collection: str,
allowed_document_ids: set[int],
) -> bool:
"""先按场景、collection 和可选文档范围过滤,再进行相关性打分。"""
if chunk.get("scenario_id") != scenario_id:
return False
if chunk.get("collection") != collection:
return False
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
return False
return True
def _tokens(text: str) -> set[str]: def _tokens(text: str) -> set[str]:
"""
兼容中英文的轻量分词策略。
该分词仅用于 fallback 模式,不替代真实向量检索:
- 英文/数字按词提取
- 中文按连续词片段和单字同时保留,提升短查询命中率
"""
lowered = text.lower() lowered = text.lower()
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered)) ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered)) cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
@@ -28,42 +97,9 @@ def _tokens(text: str) -> set[str]:
def _score(query_tokens: set[str], content: str) -> float: def _score(query_tokens: set[str], content: str) -> float:
"""使用交集占比计算一个便于排序的简化相关性分数。"""
content_tokens = _tokens(content) content_tokens = _tokens(content)
if not query_tokens or not content_tokens: if not query_tokens or not content_tokens:
return 0.0 return 0.0
overlap = query_tokens & content_tokens overlap = query_tokens & content_tokens
return round(len(overlap) / len(query_tokens), 4) return round(len(overlap) / len(query_tokens), 4)
def retrieve(
scenario_id: str,
query: str,
collection: str,
top_k: int = 5,
document_ids: list[int] | None = None,
store_path: str | Path | None = None,
) -> list[dict]:
if store_path is None and importlib.util.find_spec("chromadb") is not None:
return query_chunks(
scenario_id=scenario_id,
query=query,
collection=collection,
top_k=top_k,
document_ids=document_ids,
)
resolved_store_path = Path(store_path) if store_path else _default_store_path()
query_tokens = _tokens(query)
allowed_document_ids = set(document_ids or [])
scored_chunks = []
for chunk in _load_store(resolved_store_path):
if chunk.get("scenario_id") != scenario_id:
continue
if chunk.get("collection") != collection:
continue
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
continue
score = _score(query_tokens, chunk.get("content", ""))
if score <= 0:
continue
scored_chunks.append({**chunk, "score": score})
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]

View File

@@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from zipfile import BadZipFile, ZipFile
import re import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from zipfile import BadZipFile, ZipFile
from agent_core.rag.ingest import ingest_document from agent_core.rag.ingest import ingest_document
@@ -9,7 +9,13 @@ from .models import UploadedDocument
def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocument: def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocument:
extension = Path(uploaded_file.name).suffix.lower().lstrip(".") """
保存上传文件的元数据记录。
Documents 模块只记录文件与场景关系、原始名称、类型和大小,
真正的入库动作由用户后续主动触发,避免上传阶段就耦合 RAG 流程。
"""
extension = _detect_extension(uploaded_file.name)
return UploadedDocument.objects.create( return UploadedDocument.objects.create(
scenario_id=scenario_id, scenario_id=scenario_id,
original_name=uploaded_file.name, original_name=uploaded_file.name,
@@ -21,6 +27,14 @@ def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocumen
def extract_text(document: UploadedDocument) -> str: def extract_text(document: UploadedDocument) -> str:
"""
根据文档类型选择合适的文本抽取策略。
V1 的目标是“可演示且稳定”,因此:
- `.txt` / `.md` 直接按文本读取
- `.pdf` 优先走 pypdf失败时回退为二进制容错读取
- `.docx` 优先解析 Word XML失败时回退为二进制容错读取
"""
path = Path(document.file.path) path = Path(document.file.path)
extension = f".{document.file_type.lower().lstrip('.')}" extension = f".{document.file_type.lower().lstrip('.')}"
if extension == ".pdf": if extension == ".pdf":
@@ -30,7 +44,47 @@ def extract_text(document: UploadedDocument) -> str:
return _read_text_file(path) return _read_text_file(path)
def index_document(document: UploadedDocument) -> UploadedDocument:
"""
触发单个文档入库,并把成功/失败状态回写到 UploadedDocument。
这里故意不抛业务异常给 View
View 层只需要知道“最终状态是什么”,而错误信息统一落到模型字段中,
便于页面重试和演示。
"""
try:
text = extract_text(document)
ingest_result = ingest_document(
document_id=document.id,
scenario_id=document.scenario_id,
source_file=document.original_name,
text=text,
collection=document.scenario_id,
)
_apply_ingest_result(document, ingest_result.success, ingest_result.error)
except Exception as exc:
_apply_ingest_result(document, success=False, error=str(exc))
document.save(update_fields=["status", "error_message", "updated_at"])
return document
def _apply_ingest_result(document: UploadedDocument, success: bool, error: str = "") -> None:
"""把入库结果映射为 UploadedDocument 的稳定状态字段。"""
if success:
document.status = UploadedDocument.STATUS_INDEXED
document.error_message = ""
return
document.status = UploadedDocument.STATUS_FAILED
document.error_message = error
def _detect_extension(file_name: str) -> str:
"""统一将扩展名转成小写且去掉前导点,便于模型字段存储。"""
return Path(file_name).suffix.lower().lstrip(".")
def _read_text_file(path: Path) -> str: def _read_text_file(path: Path) -> str:
"""优先按 UTF-8 读取;失败时回退到系统默认编码。"""
try: try:
return path.read_text(encoding="utf-8") return path.read_text(encoding="utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
@@ -38,6 +92,7 @@ def _read_text_file(path: Path) -> str:
def _extract_pdf_text(path: Path) -> str: def _extract_pdf_text(path: Path) -> str:
"""优先使用 pypdf 抽取 PDF 文本,失败时回退到容错方案。"""
try: try:
import pypdf import pypdf
@@ -48,6 +103,7 @@ def _extract_pdf_text(path: Path) -> str:
def _extract_docx_text(path: Path) -> str: def _extract_docx_text(path: Path) -> str:
"""提取 Word XML 中的可见文字内容,不追求保留样式。"""
try: try:
with ZipFile(path) as archive: with ZipFile(path) as archive:
document_xml = archive.read("word/document.xml") document_xml = archive.read("word/document.xml")
@@ -60,30 +116,12 @@ def _extract_docx_text(path: Path) -> str:
def _read_binary_text_fallback(path: Path) -> str: def _read_binary_text_fallback(path: Path) -> str:
"""
当结构化抽取失败时,退回到“尽可能保留纯文本”的保底方案。
该方案不保证版式,但足以支撑 V1 入库和演示。
"""
data = path.read_bytes() data = path.read_bytes()
text = data.decode("utf-8", errors="ignore") text = data.decode("utf-8", errors="ignore")
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text) text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text)
return text.strip() return text.strip()
def index_document(document: UploadedDocument) -> UploadedDocument:
try:
text = extract_text(document)
result = ingest_document(
document_id=document.id,
scenario_id=document.scenario_id,
source_file=document.original_name,
text=text,
collection=document.scenario_id,
)
if result.success:
document.status = UploadedDocument.STATUS_INDEXED
document.error_message = ""
else:
document.status = UploadedDocument.STATUS_FAILED
document.error_message = result.error
except Exception as exc:
document.status = UploadedDocument.STATUS_FAILED
document.error_message = str(exc)
document.save(update_fields=["status", "error_message", "updated_at"])
return document

View File

@@ -1,5 +1,5 @@
from agent_core.orchestrator import build_messages, run_agent from agent_core.orchestrator import build_messages, run_agent
from agent_core.rag.ingest import ingest_document from agent_core.rag.ingest import _split_text, ingest_document
from agent_core.rag.retriever import retrieve from agent_core.rag.retriever import retrieve
@@ -221,3 +221,30 @@ def test_run_agent_uses_retrieved_document_chunks(tmp_path):
assert result.references[0]["source"] == "sop.md" assert result.references[0]["source"] == "sop.md"
assert "隔离现场" in result.references[0]["content"] assert "隔离现场" in result.references[0]["content"]
def test_rag_split_text_keeps_overlap_and_non_empty_chunks():
chunks = _split_text("A" * 20, chunk_size=8, overlap=3)
assert chunks == ["AAAAAAAA", "AAAAAAAA", "AAAAAAAA", "AAAAA"]
def test_retrieve_returns_empty_when_query_has_no_overlap(tmp_path):
store_path = tmp_path / "rag_store.json"
ingest_document(
scenario_id="knowledge_qa",
source_file="rules.md",
text="这里描述的是报销流程和审批链。",
collection="knowledge_qa",
store_path=store_path,
)
chunks = retrieve(
scenario_id="knowledge_qa",
query="设备点检",
collection="knowledge_qa",
top_k=3,
store_path=store_path,
)
assert chunks == []

View File

@@ -3,7 +3,7 @@ from django.urls import reverse
from apps.documents.forms import DocumentUploadForm from apps.documents.forms import DocumentUploadForm
from apps.documents.models import UploadedDocument from apps.documents.models import UploadedDocument
from apps.documents.services import extract_text from apps.documents.services import extract_text, index_document
def test_upload_txt_document_creates_uploaded_record(client, db): def test_upload_txt_document_creates_uploaded_record(client, db):
@@ -128,3 +128,20 @@ def test_index_failure_message_is_visible_on_document_list(client, db, monkeypat
assert response.status_code == 200 assert response.status_code == 200
assert "文档入库失败,请检查错误原因后重试" in content assert "文档入库失败,请检查错误原因后重试" in content
assert "模拟入库失败" in content assert "模拟入库失败" in content
def test_index_document_marks_failed_when_extracted_text_is_empty(db, monkeypatch):
document = UploadedDocument.objects.create(
scenario_id="knowledge_qa",
original_name="empty.md",
file_type="md",
size=0,
status="uploaded",
)
monkeypatch.setattr("apps.documents.services.extract_text", lambda target: " ")
updated_document = index_document(document)
assert updated_document.status == UploadedDocument.STATUS_FAILED
assert "文档内容为空" in updated_document.error_message