From ccfe5eb667e845973a9efd524e76a147a63b747f Mon Sep 17 00:00:00 2001 From: bruce Date: Sat, 30 May 2026 00:44:52 +0800 Subject: [PATCH] =?UTF-8?q?refactor(rag):=20=E6=A2=B3=E7=90=86=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=85=A5=E5=BA=93=E4=B8=8E=E6=A3=80=E7=B4=A2=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_core/rag/chroma_store.py | 8 ++ agent_core/rag/ingest.py | 137 +++++++++++++++++++++++---------- agent_core/rag/retriever.py | 106 ++++++++++++++++--------- apps/documents/services.py | 88 +++++++++++++++------ tests/test_agent_core.py | 29 ++++++- tests/test_documents.py | 19 ++++- 6 files changed, 284 insertions(+), 103 deletions(-) diff --git a/agent_core/rag/chroma_store.py b/agent_core/rag/chroma_store.py index 01cfb6f..ec61414 100644 --- a/agent_core/rag/chroma_store.py +++ b/agent_core/rag/chroma_store.py @@ -6,6 +6,7 @@ from agent_core.llm_provider import create_embedding_provider def _client(path: str | Path | None = None): + """按给定路径初始化 Chroma 持久化客户端。""" import chromadb resolved_path = str(path or settings.CHROMA_PATH) @@ -13,6 +14,7 @@ def _client(path: str | Path | None = None): def _embedding_provider(): + """从 Django settings 构造 Embedding Provider,避免在业务层散落配置读取。""" return create_embedding_provider( { "EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY, @@ -27,6 +29,11 @@ def upsert_chunks( chunks: list[dict], store_path: str | Path | None = None, ) -> None: + """ + 将 chunk 写入 Chroma。 + + 同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。 + """ client = _client(store_path) chroma_collection = client.get_or_create_collection(collection) document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None} @@ -59,6 +66,7 @@ def query_chunks( document_ids: list[int] | None = None, store_path: str | Path | None = None, ) -> list[dict]: + """执行向量检索,并把 Chroma 原始结果转换为统一引用结构。""" client = _client(store_path) chroma_collection = client.get_or_create_collection(collection) where: dict = {"scenario_id": scenario_id} diff --git a/agent_core/rag/ingest.py b/agent_core/rag/ingest.py index a3187a6..2a4ad32 100644 --- a/agent_core/rag/ingest.py +++ b/agent_core/rag/ingest.py @@ -1,6 +1,6 @@ +import importlib.util import json import re -import importlib.util from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path @@ -12,11 +12,61 @@ from .chroma_store import upsert_chunks @dataclass class IngestResult: + """RAG 入库统一返回结构,供 Documents 模块稳定消费。""" success: bool chunks_count: int = 0 error: str = "" +def ingest_document( + scenario_id: str, + source_file: str, + text: str, + collection: str, + document_id: int | None = None, + store_path: str | Path | None = None, +) -> IngestResult: + """ + 将单个文档文本切分后写入知识库。 + + 运行策略: + - 如果显式传入 `store_path`,说明当前是测试或降级模式,走本地 JSON 存储。 + - 如果未传入且环境可用 chromadb,则走真实 Chroma 持久化。 + """ + if not text.strip(): + return IngestResult(success=False, error="文档内容为空") + if _should_use_chroma(store_path): + return _ingest_chroma_document( + document_id=document_id, + scenario_id=scenario_id, + source_file=source_file, + text=text, + collection=collection, + ) + resolved_store_path = Path(store_path) if store_path else _default_store_path() + chunks = _build_chunks( + scenario_id=scenario_id, + source_file=source_file, + text=text, + collection=collection, + document_id=document_id, + chunk_id_prefix=source_file, + ) + persisted_chunks = _filter_out_same_document_chunks( + _load_store(resolved_store_path), + scenario_id=scenario_id, + collection=collection, + document_id=document_id, + ) + _save_store(resolved_store_path, [*persisted_chunks, *chunks]) + return IngestResult(success=True, chunks_count=len(chunks)) + + +def _should_use_chroma(store_path: str | Path | None) -> bool: + """只在未指定测试存储路径且安装 chromadb 时启用真实向量库。""" + return store_path is None and importlib.util.find_spec("chromadb") is not None + + def _default_store_path() -> Path: return Path(settings.CHROMA_PATH) / "rag_store.json" @@ -35,6 +85,13 @@ def _save_store(store_path: Path, chunks: list[dict]) -> None: def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]: + """ + 使用固定窗口 + overlap 切分文本。 + + 该策略简单但稳定,便于解释: + - chunk_size 控制每个片段最大长度 + - overlap 保证相邻片段共享上下文,降低边界信息丢失 + """ normalized = re.sub(r"\s+", " ", text).strip() if not normalized: return [] @@ -49,44 +106,46 @@ def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[st return chunks -def ingest_document( +def _build_chunks( scenario_id: str, source_file: str, text: str, collection: str, - document_id: int | None = None, - store_path: str | Path | None = None, -) -> IngestResult: - if not text.strip(): - return IngestResult(success=False, error="文档内容为空") - if store_path is None and importlib.util.find_spec("chromadb") is not None: - return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection) - resolved_store_path = Path(store_path) if store_path else _default_store_path() - existing_chunks = [ + document_id: int | None, + chunk_id_prefix: str, +) -> list[dict]: + """把原始文本切分并封装为统一 chunk 结构。""" + created_at = datetime.now(timezone.utc).isoformat() + return [ + { + "scenario_id": scenario_id, + "document_id": document_id, + "collection": collection, + "source": source_file, + "chunk_id": f"{scenario_id}:{chunk_id_prefix}:{index}", + "content": chunk_text, + "created_at": created_at, + } + for index, chunk_text in enumerate(_split_text(text), start=1) + ] + + +def _filter_out_same_document_chunks( + chunks: list[dict], + scenario_id: str, + collection: str, + document_id: int | None, +) -> list[dict]: + """重新入库同一 document_id 时,先删除旧 chunk,避免重复检索。""" + return [ chunk - for chunk in _load_store(resolved_store_path) + for chunk in chunks if not ( chunk.get("document_id") == document_id and chunk.get("scenario_id") == scenario_id and chunk.get("collection") == collection ) ] - created_at = datetime.now(timezone.utc).isoformat() - new_chunks = [] - for index, chunk_text in enumerate(_split_text(text), start=1): - new_chunks.append( - { - "scenario_id": scenario_id, - "document_id": document_id, - "collection": collection, - "source": source_file, - "chunk_id": f"{scenario_id}:{source_file}:{index}", - "content": chunk_text, - "created_at": created_at, - } - ) - _save_store(resolved_store_path, [*existing_chunks, *new_chunks]) - return IngestResult(success=True, chunks_count=len(new_chunks)) def _ingest_chroma_document( @@ -96,19 +155,15 @@ def _ingest_chroma_document( text: str, collection: str, ) -> IngestResult: - created_at = datetime.now(timezone.utc).isoformat() - chunks = [ - { - "scenario_id": scenario_id, - "document_id": document_id, - "collection": collection, - "source": source_file, - "chunk_id": f"{scenario_id}:{document_id or source_file}:{index}", - "content": chunk_text, - "created_at": created_at, - } - for index, chunk_text in enumerate(_split_text(text), start=1) - ] + """真实 Chroma 模式的入库分支。""" + chunks = _build_chunks( + scenario_id=scenario_id, + source_file=source_file, + text=text, + collection=collection, + document_id=document_id, + chunk_id_prefix=str(document_id or source_file), + ) try: upsert_chunks(collection=collection, chunks=chunks) except Exception as exc: diff --git a/agent_core/rag/retriever.py b/agent_core/rag/retriever.py index d03a99b..f17def7 100644 --- a/agent_core/rag/retriever.py +++ b/agent_core/rag/retriever.py @@ -1,6 +1,6 @@ +import importlib.util import json import re -import importlib.util from pathlib import Path from django.conf import settings @@ -8,6 +8,52 @@ from django.conf import settings from .chroma_store import query_chunks +def retrieve( + scenario_id: str, + query: str, + collection: str, + top_k: int = 5, + document_ids: list[int] | None = None, + store_path: str | Path | None = None, +) -> list[dict]: + """ + 统一对外提供检索入口。 + + 与 ingest_document 保持一致: + - 真实运行优先走 Chroma + - 测试或降级模式走本地 JSON + 轻量文本打分 + """ + if _should_use_chroma(store_path): + return query_chunks( + scenario_id=scenario_id, + query=query, + collection=collection, + top_k=top_k, + document_ids=document_ids, + ) + resolved_store_path = Path(store_path) if store_path else _default_store_path() + query_tokens = _tokens(query) + allowed_document_ids = set(document_ids or []) + scored_chunks = [] + for chunk in _load_store(resolved_store_path): + if not _matches_scope( + chunk=chunk, + scenario_id=scenario_id, + collection=collection, + allowed_document_ids=allowed_document_ids, + ): + continue + score = _score(query_tokens, chunk.get("content", "")) + if score <= 0: + continue + scored_chunks.append({**chunk, "score": score}) + return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k] + + +def _should_use_chroma(store_path: str | Path | None) -> bool: + return store_path is None and importlib.util.find_spec("chromadb") is not None + + def _default_store_path() -> Path: return Path(settings.CHROMA_PATH) / "rag_store.json" @@ -19,7 +65,30 @@ def _load_store(store_path: Path) -> list[dict]: return json.load(file) +def _matches_scope( + chunk: dict, + scenario_id: str, + collection: str, + allowed_document_ids: set[int], +) -> bool: + """先按场景、collection 和可选文档范围过滤,再进行相关性打分。""" + if chunk.get("scenario_id") != scenario_id: + return False + if chunk.get("collection") != collection: + return False + if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids: + return False + return True + + def _tokens(text: str) -> set[str]: + """ + 兼容中英文的轻量分词策略。 + + 该分词仅用于 fallback 模式,不替代真实向量检索: + - 英文/数字按词提取 + - 中文按连续词片段和单字同时保留,提升短查询命中率 + """ lowered = text.lower() ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered)) cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered)) @@ -28,42 +97,9 @@ def _tokens(text: str) -> set[str]: def _score(query_tokens: set[str], content: str) -> float: + """使用交集占比计算一个便于排序的简化相关性分数。""" content_tokens = _tokens(content) if not query_tokens or not content_tokens: return 0.0 overlap = query_tokens & content_tokens return round(len(overlap) / len(query_tokens), 4) - - -def retrieve( - scenario_id: str, - query: str, - collection: str, - top_k: int = 5, - document_ids: list[int] | None = None, - store_path: str | Path | None = None, -) -> list[dict]: - if store_path is None and importlib.util.find_spec("chromadb") is not None: - return query_chunks( - scenario_id=scenario_id, - query=query, - collection=collection, - top_k=top_k, - document_ids=document_ids, - ) - resolved_store_path = Path(store_path) if store_path else _default_store_path() - query_tokens = _tokens(query) - allowed_document_ids = set(document_ids or []) - scored_chunks = [] - for chunk in _load_store(resolved_store_path): - if chunk.get("scenario_id") != scenario_id: - continue - if chunk.get("collection") != collection: - continue - if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids: - continue - score = _score(query_tokens, chunk.get("content", "")) - if score <= 0: - continue - scored_chunks.append({**chunk, "score": score}) - return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k] diff --git a/apps/documents/services.py b/apps/documents/services.py index 68976d8..a3b8e73 100644 --- a/apps/documents/services.py +++ b/apps/documents/services.py @@ -1,7 +1,7 @@ from pathlib import Path -from zipfile import BadZipFile, ZipFile import re import xml.etree.ElementTree as ET +from zipfile import BadZipFile, ZipFile from agent_core.rag.ingest import ingest_document @@ -9,7 +9,13 @@ from .models import UploadedDocument def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocument: - extension = Path(uploaded_file.name).suffix.lower().lstrip(".") + """ + 保存上传文件的元数据记录。 + + Documents 模块只记录文件与场景关系、原始名称、类型和大小, + 真正的入库动作由用户后续主动触发,避免上传阶段就耦合 RAG 流程。 + """ + extension = _detect_extension(uploaded_file.name) return UploadedDocument.objects.create( scenario_id=scenario_id, original_name=uploaded_file.name, @@ -21,6 +27,14 @@ def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocumen def extract_text(document: UploadedDocument) -> str: + """ + 根据文档类型选择合适的文本抽取策略。 + + V1 的目标是“可演示且稳定”,因此: + - `.txt` / `.md` 直接按文本读取 + - `.pdf` 优先走 pypdf,失败时回退为二进制容错读取 + - `.docx` 优先解析 Word XML,失败时回退为二进制容错读取 + """ path = Path(document.file.path) extension = f".{document.file_type.lower().lstrip('.')}" if extension == ".pdf": @@ -30,7 +44,47 @@ def extract_text(document: UploadedDocument) -> str: return _read_text_file(path) +def index_document(document: UploadedDocument) -> UploadedDocument: + """ + 触发单个文档入库,并把成功/失败状态回写到 UploadedDocument。 + + 这里故意不抛业务异常给 View: + View 层只需要知道“最终状态是什么”,而错误信息统一落到模型字段中, + 便于页面重试和演示。 + """ + try: + text = extract_text(document) + ingest_result = ingest_document( + document_id=document.id, + scenario_id=document.scenario_id, + source_file=document.original_name, + text=text, + collection=document.scenario_id, + ) + _apply_ingest_result(document, ingest_result.success, ingest_result.error) + except Exception as exc: + _apply_ingest_result(document, success=False, error=str(exc)) + document.save(update_fields=["status", "error_message", "updated_at"]) + return document + + +def _apply_ingest_result(document: UploadedDocument, success: bool, error: str = "") -> None: + """把入库结果映射为 UploadedDocument 的稳定状态字段。""" + if success: + document.status = UploadedDocument.STATUS_INDEXED + document.error_message = "" + return + document.status = UploadedDocument.STATUS_FAILED + document.error_message = error + + +def _detect_extension(file_name: str) -> str: + """统一将扩展名转成小写且去掉前导点,便于模型字段存储。""" + return Path(file_name).suffix.lower().lstrip(".") + + def _read_text_file(path: Path) -> str: + """优先按 UTF-8 读取;失败时回退到系统默认编码。""" try: return path.read_text(encoding="utf-8") except UnicodeDecodeError: @@ -38,6 +92,7 @@ def _read_text_file(path: Path) -> str: def _extract_pdf_text(path: Path) -> str: + """优先使用 pypdf 抽取 PDF 文本,失败时回退到容错方案。""" try: import pypdf @@ -48,6 +103,7 @@ def _extract_pdf_text(path: Path) -> str: def _extract_docx_text(path: Path) -> str: + """提取 Word XML 中的可见文字内容,不追求保留样式。""" try: with ZipFile(path) as archive: document_xml = archive.read("word/document.xml") @@ -60,30 +116,12 @@ def _extract_docx_text(path: Path) -> str: def _read_binary_text_fallback(path: Path) -> str: + """ + 当结构化抽取失败时,退回到“尽可能保留纯文本”的保底方案。 + + 该方案不保证版式,但足以支撑 V1 入库和演示。 + """ data = path.read_bytes() text = data.decode("utf-8", errors="ignore") text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text) return text.strip() - - -def index_document(document: UploadedDocument) -> UploadedDocument: - try: - text = extract_text(document) - result = ingest_document( - document_id=document.id, - scenario_id=document.scenario_id, - source_file=document.original_name, - text=text, - collection=document.scenario_id, - ) - if result.success: - document.status = UploadedDocument.STATUS_INDEXED - document.error_message = "" - else: - document.status = UploadedDocument.STATUS_FAILED - document.error_message = result.error - except Exception as exc: - document.status = UploadedDocument.STATUS_FAILED - document.error_message = str(exc) - document.save(update_fields=["status", "error_message", "updated_at"]) - return document diff --git a/tests/test_agent_core.py b/tests/test_agent_core.py index 277df9f..f66fb90 100644 --- a/tests/test_agent_core.py +++ b/tests/test_agent_core.py @@ -1,5 +1,5 @@ from agent_core.orchestrator import build_messages, run_agent -from agent_core.rag.ingest import ingest_document +from agent_core.rag.ingest import _split_text, ingest_document from agent_core.rag.retriever import retrieve @@ -221,3 +221,30 @@ def test_run_agent_uses_retrieved_document_chunks(tmp_path): assert result.references[0]["source"] == "sop.md" assert "隔离现场" in result.references[0]["content"] + + +def test_rag_split_text_keeps_overlap_and_non_empty_chunks(): + chunks = _split_text("A" * 20, chunk_size=8, overlap=3) + + assert chunks == ["AAAAAAAA", "AAAAAAAA", "AAAAAAAA", "AAAAA"] + + +def test_retrieve_returns_empty_when_query_has_no_overlap(tmp_path): + store_path = tmp_path / "rag_store.json" + ingest_document( + scenario_id="knowledge_qa", + source_file="rules.md", + text="这里描述的是报销流程和审批链。", + collection="knowledge_qa", + store_path=store_path, + ) + + chunks = retrieve( + scenario_id="knowledge_qa", + query="设备点检", + collection="knowledge_qa", + top_k=3, + store_path=store_path, + ) + + assert chunks == [] diff --git a/tests/test_documents.py b/tests/test_documents.py index a6b2b3c..5af8267 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -3,7 +3,7 @@ from django.urls import reverse from apps.documents.forms import DocumentUploadForm from apps.documents.models import UploadedDocument -from apps.documents.services import extract_text +from apps.documents.services import extract_text, index_document def test_upload_txt_document_creates_uploaded_record(client, db): @@ -128,3 +128,20 @@ def test_index_failure_message_is_visible_on_document_list(client, db, monkeypat assert response.status_code == 200 assert "文档入库失败,请检查错误原因后重试" in content assert "模拟入库失败" in content + + +def test_index_document_marks_failed_when_extracted_text_is_empty(db, monkeypatch): + document = UploadedDocument.objects.create( + scenario_id="knowledge_qa", + original_name="empty.md", + file_type="md", + size=0, + status="uploaded", + ) + + monkeypatch.setattr("apps.documents.services.extract_text", lambda target: " ") + + updated_document = index_document(document) + + assert updated_document.status == UploadedDocument.STATUS_FAILED + assert "文档内容为空" in updated_document.error_message