refactor(rag): 梳理文档入库与检索服务结构
This commit is contained in:
@@ -6,6 +6,7 @@ from agent_core.llm_provider import create_embedding_provider
|
|||||||
|
|
||||||
|
|
||||||
def _client(path: str | Path | None = None):
|
def _client(path: str | Path | None = None):
|
||||||
|
"""按给定路径初始化 Chroma 持久化客户端。"""
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
resolved_path = str(path or settings.CHROMA_PATH)
|
resolved_path = str(path or settings.CHROMA_PATH)
|
||||||
@@ -13,6 +14,7 @@ def _client(path: str | Path | None = None):
|
|||||||
|
|
||||||
|
|
||||||
def _embedding_provider():
|
def _embedding_provider():
|
||||||
|
"""从 Django settings 构造 Embedding Provider,避免在业务层散落配置读取。"""
|
||||||
return create_embedding_provider(
|
return create_embedding_provider(
|
||||||
{
|
{
|
||||||
"EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
|
"EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
|
||||||
@@ -27,6 +29,11 @@ def upsert_chunks(
|
|||||||
chunks: list[dict],
|
chunks: list[dict],
|
||||||
store_path: str | Path | None = None,
|
store_path: str | Path | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
将 chunk 写入 Chroma。
|
||||||
|
|
||||||
|
同一 document_id 重新入库前会先删除旧记录,保证一次文档只有一份有效向量数据。
|
||||||
|
"""
|
||||||
client = _client(store_path)
|
client = _client(store_path)
|
||||||
chroma_collection = client.get_or_create_collection(collection)
|
chroma_collection = client.get_or_create_collection(collection)
|
||||||
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
|
||||||
@@ -59,6 +66,7 @@ def query_chunks(
|
|||||||
document_ids: list[int] | None = None,
|
document_ids: list[int] | None = None,
|
||||||
store_path: str | Path | None = None,
|
store_path: str | Path | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
|
"""执行向量检索,并把 Chroma 原始结果转换为统一引用结构。"""
|
||||||
client = _client(store_path)
|
client = _client(store_path)
|
||||||
chroma_collection = client.get_or_create_collection(collection)
|
chroma_collection = client.get_or_create_collection(collection)
|
||||||
where: dict = {"scenario_id": scenario_id}
|
where: dict = {"scenario_id": scenario_id}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import importlib.util
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -12,11 +12,61 @@ from .chroma_store import upsert_chunks
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IngestResult:
|
class IngestResult:
|
||||||
|
"""RAG 入库统一返回结构,供 Documents 模块稳定消费。"""
|
||||||
success: bool
|
success: bool
|
||||||
chunks_count: int = 0
|
chunks_count: int = 0
|
||||||
error: str = ""
|
error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_document(
|
||||||
|
scenario_id: str,
|
||||||
|
source_file: str,
|
||||||
|
text: str,
|
||||||
|
collection: str,
|
||||||
|
document_id: int | None = None,
|
||||||
|
store_path: str | Path | None = None,
|
||||||
|
) -> IngestResult:
|
||||||
|
"""
|
||||||
|
将单个文档文本切分后写入知识库。
|
||||||
|
|
||||||
|
运行策略:
|
||||||
|
- 如果显式传入 `store_path`,说明当前是测试或降级模式,走本地 JSON 存储。
|
||||||
|
- 如果未传入且环境可用 chromadb,则走真实 Chroma 持久化。
|
||||||
|
"""
|
||||||
|
if not text.strip():
|
||||||
|
return IngestResult(success=False, error="文档内容为空")
|
||||||
|
if _should_use_chroma(store_path):
|
||||||
|
return _ingest_chroma_document(
|
||||||
|
document_id=document_id,
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
source_file=source_file,
|
||||||
|
text=text,
|
||||||
|
collection=collection,
|
||||||
|
)
|
||||||
|
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||||
|
chunks = _build_chunks(
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
source_file=source_file,
|
||||||
|
text=text,
|
||||||
|
collection=collection,
|
||||||
|
document_id=document_id,
|
||||||
|
chunk_id_prefix=source_file,
|
||||||
|
)
|
||||||
|
persisted_chunks = _filter_out_same_document_chunks(
|
||||||
|
_load_store(resolved_store_path),
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
collection=collection,
|
||||||
|
document_id=document_id,
|
||||||
|
)
|
||||||
|
_save_store(resolved_store_path, [*persisted_chunks, *chunks])
|
||||||
|
return IngestResult(success=True, chunks_count=len(chunks))
|
||||||
|
|
||||||
|
|
||||||
|
def _should_use_chroma(store_path: str | Path | None) -> bool:
|
||||||
|
"""只在未指定测试存储路径且安装 chromadb 时启用真实向量库。"""
|
||||||
|
return store_path is None and importlib.util.find_spec("chromadb") is not None
|
||||||
|
|
||||||
|
|
||||||
def _default_store_path() -> Path:
|
def _default_store_path() -> Path:
|
||||||
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
||||||
|
|
||||||
@@ -35,6 +85,13 @@ def _save_store(store_path: Path, chunks: list[dict]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]:
|
def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]:
|
||||||
|
"""
|
||||||
|
使用固定窗口 + overlap 切分文本。
|
||||||
|
|
||||||
|
该策略简单但稳定,便于解释:
|
||||||
|
- chunk_size 控制每个片段最大长度
|
||||||
|
- overlap 保证相邻片段共享上下文,降低边界信息丢失
|
||||||
|
"""
|
||||||
normalized = re.sub(r"\s+", " ", text).strip()
|
normalized = re.sub(r"\s+", " ", text).strip()
|
||||||
if not normalized:
|
if not normalized:
|
||||||
return []
|
return []
|
||||||
@@ -49,44 +106,46 @@ def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[st
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def ingest_document(
|
def _build_chunks(
|
||||||
scenario_id: str,
|
scenario_id: str,
|
||||||
source_file: str,
|
source_file: str,
|
||||||
text: str,
|
text: str,
|
||||||
collection: str,
|
collection: str,
|
||||||
document_id: int | None = None,
|
document_id: int | None,
|
||||||
store_path: str | Path | None = None,
|
chunk_id_prefix: str,
|
||||||
) -> IngestResult:
|
) -> list[dict]:
|
||||||
if not text.strip():
|
"""把原始文本切分并封装为统一 chunk 结构。"""
|
||||||
return IngestResult(success=False, error="文档内容为空")
|
created_at = datetime.now(timezone.utc).isoformat()
|
||||||
if store_path is None and importlib.util.find_spec("chromadb") is not None:
|
return [
|
||||||
return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection)
|
{
|
||||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
"scenario_id": scenario_id,
|
||||||
existing_chunks = [
|
"document_id": document_id,
|
||||||
|
"collection": collection,
|
||||||
|
"source": source_file,
|
||||||
|
"chunk_id": f"{scenario_id}:{chunk_id_prefix}:{index}",
|
||||||
|
"content": chunk_text,
|
||||||
|
"created_at": created_at,
|
||||||
|
}
|
||||||
|
for index, chunk_text in enumerate(_split_text(text), start=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_out_same_document_chunks(
|
||||||
|
chunks: list[dict],
|
||||||
|
scenario_id: str,
|
||||||
|
collection: str,
|
||||||
|
document_id: int | None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""重新入库同一 document_id 时,先删除旧 chunk,避免重复检索。"""
|
||||||
|
return [
|
||||||
chunk
|
chunk
|
||||||
for chunk in _load_store(resolved_store_path)
|
for chunk in chunks
|
||||||
if not (
|
if not (
|
||||||
chunk.get("document_id") == document_id
|
chunk.get("document_id") == document_id
|
||||||
and chunk.get("scenario_id") == scenario_id
|
and chunk.get("scenario_id") == scenario_id
|
||||||
and chunk.get("collection") == collection
|
and chunk.get("collection") == collection
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
created_at = datetime.now(timezone.utc).isoformat()
|
|
||||||
new_chunks = []
|
|
||||||
for index, chunk_text in enumerate(_split_text(text), start=1):
|
|
||||||
new_chunks.append(
|
|
||||||
{
|
|
||||||
"scenario_id": scenario_id,
|
|
||||||
"document_id": document_id,
|
|
||||||
"collection": collection,
|
|
||||||
"source": source_file,
|
|
||||||
"chunk_id": f"{scenario_id}:{source_file}:{index}",
|
|
||||||
"content": chunk_text,
|
|
||||||
"created_at": created_at,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
_save_store(resolved_store_path, [*existing_chunks, *new_chunks])
|
|
||||||
return IngestResult(success=True, chunks_count=len(new_chunks))
|
|
||||||
|
|
||||||
|
|
||||||
def _ingest_chroma_document(
|
def _ingest_chroma_document(
|
||||||
@@ -96,19 +155,15 @@ def _ingest_chroma_document(
|
|||||||
text: str,
|
text: str,
|
||||||
collection: str,
|
collection: str,
|
||||||
) -> IngestResult:
|
) -> IngestResult:
|
||||||
created_at = datetime.now(timezone.utc).isoformat()
|
"""真实 Chroma 模式的入库分支。"""
|
||||||
chunks = [
|
chunks = _build_chunks(
|
||||||
{
|
scenario_id=scenario_id,
|
||||||
"scenario_id": scenario_id,
|
source_file=source_file,
|
||||||
"document_id": document_id,
|
text=text,
|
||||||
"collection": collection,
|
collection=collection,
|
||||||
"source": source_file,
|
document_id=document_id,
|
||||||
"chunk_id": f"{scenario_id}:{document_id or source_file}:{index}",
|
chunk_id_prefix=str(document_id or source_file),
|
||||||
"content": chunk_text,
|
)
|
||||||
"created_at": created_at,
|
|
||||||
}
|
|
||||||
for index, chunk_text in enumerate(_split_text(text), start=1)
|
|
||||||
]
|
|
||||||
try:
|
try:
|
||||||
upsert_chunks(collection=collection, chunks=chunks)
|
upsert_chunks(collection=collection, chunks=chunks)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import importlib.util
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
@@ -8,6 +8,52 @@ from django.conf import settings
|
|||||||
from .chroma_store import query_chunks
|
from .chroma_store import query_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
scenario_id: str,
|
||||||
|
query: str,
|
||||||
|
collection: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
document_ids: list[int] | None = None,
|
||||||
|
store_path: str | Path | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
统一对外提供检索入口。
|
||||||
|
|
||||||
|
与 ingest_document 保持一致:
|
||||||
|
- 真实运行优先走 Chroma
|
||||||
|
- 测试或降级模式走本地 JSON + 轻量文本打分
|
||||||
|
"""
|
||||||
|
if _should_use_chroma(store_path):
|
||||||
|
return query_chunks(
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
query=query,
|
||||||
|
collection=collection,
|
||||||
|
top_k=top_k,
|
||||||
|
document_ids=document_ids,
|
||||||
|
)
|
||||||
|
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
||||||
|
query_tokens = _tokens(query)
|
||||||
|
allowed_document_ids = set(document_ids or [])
|
||||||
|
scored_chunks = []
|
||||||
|
for chunk in _load_store(resolved_store_path):
|
||||||
|
if not _matches_scope(
|
||||||
|
chunk=chunk,
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
collection=collection,
|
||||||
|
allowed_document_ids=allowed_document_ids,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
score = _score(query_tokens, chunk.get("content", ""))
|
||||||
|
if score <= 0:
|
||||||
|
continue
|
||||||
|
scored_chunks.append({**chunk, "score": score})
|
||||||
|
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
|
||||||
|
|
||||||
|
|
||||||
|
def _should_use_chroma(store_path: str | Path | None) -> bool:
|
||||||
|
return store_path is None and importlib.util.find_spec("chromadb") is not None
|
||||||
|
|
||||||
|
|
||||||
def _default_store_path() -> Path:
|
def _default_store_path() -> Path:
|
||||||
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
return Path(settings.CHROMA_PATH) / "rag_store.json"
|
||||||
|
|
||||||
@@ -19,7 +65,30 @@ def _load_store(store_path: Path) -> list[dict]:
|
|||||||
return json.load(file)
|
return json.load(file)
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_scope(
|
||||||
|
chunk: dict,
|
||||||
|
scenario_id: str,
|
||||||
|
collection: str,
|
||||||
|
allowed_document_ids: set[int],
|
||||||
|
) -> bool:
|
||||||
|
"""先按场景、collection 和可选文档范围过滤,再进行相关性打分。"""
|
||||||
|
if chunk.get("scenario_id") != scenario_id:
|
||||||
|
return False
|
||||||
|
if chunk.get("collection") != collection:
|
||||||
|
return False
|
||||||
|
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _tokens(text: str) -> set[str]:
|
def _tokens(text: str) -> set[str]:
|
||||||
|
"""
|
||||||
|
兼容中英文的轻量分词策略。
|
||||||
|
|
||||||
|
该分词仅用于 fallback 模式,不替代真实向量检索:
|
||||||
|
- 英文/数字按词提取
|
||||||
|
- 中文按连续词片段和单字同时保留,提升短查询命中率
|
||||||
|
"""
|
||||||
lowered = text.lower()
|
lowered = text.lower()
|
||||||
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
|
ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered))
|
||||||
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
|
cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
|
||||||
@@ -28,42 +97,9 @@ def _tokens(text: str) -> set[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _score(query_tokens: set[str], content: str) -> float:
|
def _score(query_tokens: set[str], content: str) -> float:
|
||||||
|
"""使用交集占比计算一个便于排序的简化相关性分数。"""
|
||||||
content_tokens = _tokens(content)
|
content_tokens = _tokens(content)
|
||||||
if not query_tokens or not content_tokens:
|
if not query_tokens or not content_tokens:
|
||||||
return 0.0
|
return 0.0
|
||||||
overlap = query_tokens & content_tokens
|
overlap = query_tokens & content_tokens
|
||||||
return round(len(overlap) / len(query_tokens), 4)
|
return round(len(overlap) / len(query_tokens), 4)
|
||||||
|
|
||||||
|
|
||||||
def retrieve(
|
|
||||||
scenario_id: str,
|
|
||||||
query: str,
|
|
||||||
collection: str,
|
|
||||||
top_k: int = 5,
|
|
||||||
document_ids: list[int] | None = None,
|
|
||||||
store_path: str | Path | None = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
if store_path is None and importlib.util.find_spec("chromadb") is not None:
|
|
||||||
return query_chunks(
|
|
||||||
scenario_id=scenario_id,
|
|
||||||
query=query,
|
|
||||||
collection=collection,
|
|
||||||
top_k=top_k,
|
|
||||||
document_ids=document_ids,
|
|
||||||
)
|
|
||||||
resolved_store_path = Path(store_path) if store_path else _default_store_path()
|
|
||||||
query_tokens = _tokens(query)
|
|
||||||
allowed_document_ids = set(document_ids or [])
|
|
||||||
scored_chunks = []
|
|
||||||
for chunk in _load_store(resolved_store_path):
|
|
||||||
if chunk.get("scenario_id") != scenario_id:
|
|
||||||
continue
|
|
||||||
if chunk.get("collection") != collection:
|
|
||||||
continue
|
|
||||||
if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids:
|
|
||||||
continue
|
|
||||||
score = _score(query_tokens, chunk.get("content", ""))
|
|
||||||
if score <= 0:
|
|
||||||
continue
|
|
||||||
scored_chunks.append({**chunk, "score": score})
|
|
||||||
return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k]
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zipfile import BadZipFile, ZipFile
|
|
||||||
import re
|
import re
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
from zipfile import BadZipFile, ZipFile
|
||||||
|
|
||||||
from agent_core.rag.ingest import ingest_document
|
from agent_core.rag.ingest import ingest_document
|
||||||
|
|
||||||
@@ -9,7 +9,13 @@ from .models import UploadedDocument
|
|||||||
|
|
||||||
|
|
||||||
def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocument:
|
def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocument:
|
||||||
extension = Path(uploaded_file.name).suffix.lower().lstrip(".")
|
"""
|
||||||
|
保存上传文件的元数据记录。
|
||||||
|
|
||||||
|
Documents 模块只记录文件与场景关系、原始名称、类型和大小,
|
||||||
|
真正的入库动作由用户后续主动触发,避免上传阶段就耦合 RAG 流程。
|
||||||
|
"""
|
||||||
|
extension = _detect_extension(uploaded_file.name)
|
||||||
return UploadedDocument.objects.create(
|
return UploadedDocument.objects.create(
|
||||||
scenario_id=scenario_id,
|
scenario_id=scenario_id,
|
||||||
original_name=uploaded_file.name,
|
original_name=uploaded_file.name,
|
||||||
@@ -21,6 +27,14 @@ def create_uploaded_document(scenario_id: str, uploaded_file) -> UploadedDocumen
|
|||||||
|
|
||||||
|
|
||||||
def extract_text(document: UploadedDocument) -> str:
|
def extract_text(document: UploadedDocument) -> str:
|
||||||
|
"""
|
||||||
|
根据文档类型选择合适的文本抽取策略。
|
||||||
|
|
||||||
|
V1 的目标是“可演示且稳定”,因此:
|
||||||
|
- `.txt` / `.md` 直接按文本读取
|
||||||
|
- `.pdf` 优先走 pypdf,失败时回退为二进制容错读取
|
||||||
|
- `.docx` 优先解析 Word XML,失败时回退为二进制容错读取
|
||||||
|
"""
|
||||||
path = Path(document.file.path)
|
path = Path(document.file.path)
|
||||||
extension = f".{document.file_type.lower().lstrip('.')}"
|
extension = f".{document.file_type.lower().lstrip('.')}"
|
||||||
if extension == ".pdf":
|
if extension == ".pdf":
|
||||||
@@ -30,7 +44,47 @@ def extract_text(document: UploadedDocument) -> str:
|
|||||||
return _read_text_file(path)
|
return _read_text_file(path)
|
||||||
|
|
||||||
|
|
||||||
|
def index_document(document: UploadedDocument) -> UploadedDocument:
|
||||||
|
"""
|
||||||
|
触发单个文档入库,并把成功/失败状态回写到 UploadedDocument。
|
||||||
|
|
||||||
|
这里故意不抛业务异常给 View:
|
||||||
|
View 层只需要知道“最终状态是什么”,而错误信息统一落到模型字段中,
|
||||||
|
便于页面重试和演示。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
text = extract_text(document)
|
||||||
|
ingest_result = ingest_document(
|
||||||
|
document_id=document.id,
|
||||||
|
scenario_id=document.scenario_id,
|
||||||
|
source_file=document.original_name,
|
||||||
|
text=text,
|
||||||
|
collection=document.scenario_id,
|
||||||
|
)
|
||||||
|
_apply_ingest_result(document, ingest_result.success, ingest_result.error)
|
||||||
|
except Exception as exc:
|
||||||
|
_apply_ingest_result(document, success=False, error=str(exc))
|
||||||
|
document.save(update_fields=["status", "error_message", "updated_at"])
|
||||||
|
return document
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_ingest_result(document: UploadedDocument, success: bool, error: str = "") -> None:
|
||||||
|
"""把入库结果映射为 UploadedDocument 的稳定状态字段。"""
|
||||||
|
if success:
|
||||||
|
document.status = UploadedDocument.STATUS_INDEXED
|
||||||
|
document.error_message = ""
|
||||||
|
return
|
||||||
|
document.status = UploadedDocument.STATUS_FAILED
|
||||||
|
document.error_message = error
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_extension(file_name: str) -> str:
|
||||||
|
"""统一将扩展名转成小写且去掉前导点,便于模型字段存储。"""
|
||||||
|
return Path(file_name).suffix.lower().lstrip(".")
|
||||||
|
|
||||||
|
|
||||||
def _read_text_file(path: Path) -> str:
|
def _read_text_file(path: Path) -> str:
|
||||||
|
"""优先按 UTF-8 读取;失败时回退到系统默认编码。"""
|
||||||
try:
|
try:
|
||||||
return path.read_text(encoding="utf-8")
|
return path.read_text(encoding="utf-8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
@@ -38,6 +92,7 @@ def _read_text_file(path: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _extract_pdf_text(path: Path) -> str:
|
def _extract_pdf_text(path: Path) -> str:
|
||||||
|
"""优先使用 pypdf 抽取 PDF 文本,失败时回退到容错方案。"""
|
||||||
try:
|
try:
|
||||||
import pypdf
|
import pypdf
|
||||||
|
|
||||||
@@ -48,6 +103,7 @@ def _extract_pdf_text(path: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _extract_docx_text(path: Path) -> str:
|
def _extract_docx_text(path: Path) -> str:
|
||||||
|
"""提取 Word XML 中的可见文字内容,不追求保留样式。"""
|
||||||
try:
|
try:
|
||||||
with ZipFile(path) as archive:
|
with ZipFile(path) as archive:
|
||||||
document_xml = archive.read("word/document.xml")
|
document_xml = archive.read("word/document.xml")
|
||||||
@@ -60,30 +116,12 @@ def _extract_docx_text(path: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _read_binary_text_fallback(path: Path) -> str:
|
def _read_binary_text_fallback(path: Path) -> str:
|
||||||
|
"""
|
||||||
|
当结构化抽取失败时,退回到“尽可能保留纯文本”的保底方案。
|
||||||
|
|
||||||
|
该方案不保证版式,但足以支撑 V1 入库和演示。
|
||||||
|
"""
|
||||||
data = path.read_bytes()
|
data = path.read_bytes()
|
||||||
text = data.decode("utf-8", errors="ignore")
|
text = data.decode("utf-8", errors="ignore")
|
||||||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text)
|
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text)
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def index_document(document: UploadedDocument) -> UploadedDocument:
|
|
||||||
try:
|
|
||||||
text = extract_text(document)
|
|
||||||
result = ingest_document(
|
|
||||||
document_id=document.id,
|
|
||||||
scenario_id=document.scenario_id,
|
|
||||||
source_file=document.original_name,
|
|
||||||
text=text,
|
|
||||||
collection=document.scenario_id,
|
|
||||||
)
|
|
||||||
if result.success:
|
|
||||||
document.status = UploadedDocument.STATUS_INDEXED
|
|
||||||
document.error_message = ""
|
|
||||||
else:
|
|
||||||
document.status = UploadedDocument.STATUS_FAILED
|
|
||||||
document.error_message = result.error
|
|
||||||
except Exception as exc:
|
|
||||||
document.status = UploadedDocument.STATUS_FAILED
|
|
||||||
document.error_message = str(exc)
|
|
||||||
document.save(update_fields=["status", "error_message", "updated_at"])
|
|
||||||
return document
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from agent_core.orchestrator import build_messages, run_agent
|
from agent_core.orchestrator import build_messages, run_agent
|
||||||
from agent_core.rag.ingest import ingest_document
|
from agent_core.rag.ingest import _split_text, ingest_document
|
||||||
from agent_core.rag.retriever import retrieve
|
from agent_core.rag.retriever import retrieve
|
||||||
|
|
||||||
|
|
||||||
@@ -221,3 +221,30 @@ def test_run_agent_uses_retrieved_document_chunks(tmp_path):
|
|||||||
|
|
||||||
assert result.references[0]["source"] == "sop.md"
|
assert result.references[0]["source"] == "sop.md"
|
||||||
assert "隔离现场" in result.references[0]["content"]
|
assert "隔离现场" in result.references[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_split_text_keeps_overlap_and_non_empty_chunks():
|
||||||
|
chunks = _split_text("A" * 20, chunk_size=8, overlap=3)
|
||||||
|
|
||||||
|
assert chunks == ["AAAAAAAA", "AAAAAAAA", "AAAAAAAA", "AAAAA"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_returns_empty_when_query_has_no_overlap(tmp_path):
|
||||||
|
store_path = tmp_path / "rag_store.json"
|
||||||
|
ingest_document(
|
||||||
|
scenario_id="knowledge_qa",
|
||||||
|
source_file="rules.md",
|
||||||
|
text="这里描述的是报销流程和审批链。",
|
||||||
|
collection="knowledge_qa",
|
||||||
|
store_path=store_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = retrieve(
|
||||||
|
scenario_id="knowledge_qa",
|
||||||
|
query="设备点检",
|
||||||
|
collection="knowledge_qa",
|
||||||
|
top_k=3,
|
||||||
|
store_path=store_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunks == []
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from django.urls import reverse
|
|||||||
|
|
||||||
from apps.documents.forms import DocumentUploadForm
|
from apps.documents.forms import DocumentUploadForm
|
||||||
from apps.documents.models import UploadedDocument
|
from apps.documents.models import UploadedDocument
|
||||||
from apps.documents.services import extract_text
|
from apps.documents.services import extract_text, index_document
|
||||||
|
|
||||||
|
|
||||||
def test_upload_txt_document_creates_uploaded_record(client, db):
|
def test_upload_txt_document_creates_uploaded_record(client, db):
|
||||||
@@ -128,3 +128,20 @@ def test_index_failure_message_is_visible_on_document_list(client, db, monkeypat
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert "文档入库失败,请检查错误原因后重试" in content
|
assert "文档入库失败,请检查错误原因后重试" in content
|
||||||
assert "模拟入库失败" in content
|
assert "模拟入库失败" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_index_document_marks_failed_when_extracted_text_is_empty(db, monkeypatch):
|
||||||
|
document = UploadedDocument.objects.create(
|
||||||
|
scenario_id="knowledge_qa",
|
||||||
|
original_name="empty.md",
|
||||||
|
file_type="md",
|
||||||
|
size=0,
|
||||||
|
status="uploaded",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("apps.documents.services.extract_text", lambda target: " ")
|
||||||
|
|
||||||
|
updated_document = index_document(document)
|
||||||
|
|
||||||
|
assert updated_document.status == UploadedDocument.STATUS_FAILED
|
||||||
|
assert "文档内容为空" in updated_document.error_message
|
||||||
|
|||||||
Reference in New Issue
Block a user