fix(kb): 完善知识库入库和重建索引

This commit is contained in:
2026-06-08 23:45:34 +08:00
parent d8cd95e590
commit 2b5093040d
9 changed files with 355 additions and 9 deletions

View File

@@ -10,8 +10,8 @@ from django.core.files.uploadedfile import UploadedFile
from review_agent.models import KnowledgeBaseDocument from review_agent.models import KnowledgeBaseDocument
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
@@ -78,6 +78,8 @@ def list_source_documents(source_dir: Path) -> list[dict[str, Any]]:
continue continue
suffix = path.suffix.lower() suffix = path.suffix.lower()
relative_path = str(path.relative_to(source_dir)) relative_path = str(path.relative_to(source_dir))
if is_excluded_source_path(relative_path):
continue
indexed_chunk_count = source_chunk_counts.get(relative_path, 0) indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
documents.append( documents.append(
{ {
@@ -101,7 +103,7 @@ def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
try: try:
results = retrieve_citations( results = retrieve_citations(
normalized, normalized,
embedding_provider=DeterministicEmbeddingProvider(), embedding_provider=get_embedding_provider(),
n_results=n_results, n_results=n_results,
) )
except RagIndexUnavailable as exc: except RagIndexUnavailable as exc:
@@ -210,7 +212,7 @@ def index_managed_document(document: KnowledgeBaseDocument) -> int:
return 0 return 0
collection = _load_chroma_collection() collection = _load_chroma_collection()
texts = [chunk.text for chunk in chunks] texts = [chunk.text for chunk in chunks]
embeddings = DeterministicEmbeddingProvider()(texts) embeddings = get_embedding_provider()(texts)
ids = [ ids = [
hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest() hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
for chunk in chunks for chunk in chunks

View File

@@ -23,7 +23,7 @@ class Command(BaseCommand):
raise CommandError(f"法规材料目录不存在:{source_dir}") raise CommandError(f"法规材料目录不存在:{source_dir}")
try: try:
provider = get_embedding_provider(options["provider"]) provider = get_embedding_provider(options["provider"])
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider) count = build_chroma_index(source_dir=source_dir, embedding_provider=provider, reset=True)
except Exception as exc: except Exception as exc:
raise CommandError(str(exc)) from exc raise CommandError(str(exc)) from exc
self.stdout.write( self.stdout.write(

View File

@@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction
logger = logging.getLogger("review_agent.regulatory_review.rag_index") logger = logging.getLogger("review_agent.regulatory_review.rag_index")
EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent")
@dataclass(frozen=True) @dataclass(frozen=True)
class TextChunk: class TextChunk:
@@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
for path in sorted(source_dir.rglob("*")): for path in sorted(source_dir.rglob("*")):
if not path.is_file(): if not path.is_file():
continue continue
if is_excluded_source_path(path.relative_to(source_dir)):
continue
try: try:
text = extract_text_from_path(path) text = extract_text_from_path(path)
except RuntimeError as exc: except RuntimeError as exc:
@@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
return chunks return chunks
def is_excluded_source_path(path: Path | str) -> bool:
normalized = str(path)
return any(keyword in normalized for keyword in EXCLUDED_SOURCE_KEYWORDS)
def _is_attachment4(path: Path) -> bool: def _is_attachment4(path: Path) -> bool:
normalized = path.name.replace(" ", "") normalized = path.name.replace(" ", "")
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
@@ -249,6 +258,7 @@ def build_chroma_index(
embedding_provider: EmbeddingFunction, embedding_provider: EmbeddingFunction,
persist_path: Path | None = None, persist_path: Path | None = None,
collection_name: str | None = None, collection_name: str | None = None,
reset: bool = False,
) -> int: ) -> int:
try: try:
import chromadb import chromadb
@@ -259,7 +269,22 @@ def build_chroma_index(
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
persist_path.mkdir(parents=True, exist_ok=True) persist_path.mkdir(parents=True, exist_ok=True)
chunks = collect_source_chunks(source_dir) chunks = collect_source_chunks(source_dir)
try:
client = chromadb.PersistentClient(path=str(persist_path)) client = chromadb.PersistentClient(path=str(persist_path))
except Exception:
if not reset:
raise
clear_chroma_system_cache()
clear_chroma_index_dir(persist_path)
persist_path.mkdir(parents=True, exist_ok=True)
client = chromadb.PersistentClient(path=str(persist_path))
if reset:
try:
client.delete_collection(collection_name)
clear_chroma_system_cache()
client = chromadb.PersistentClient(path=str(persist_path))
except Exception:
pass
collection = client.get_or_create_collection(collection_name) collection = client.get_or_create_collection(collection_name)
if not chunks: if not chunks:
return 0 return 0
@@ -276,3 +301,22 @@ def build_chroma_index(
embeddings=embeddings, embeddings=embeddings,
) )
return len(chunks) return len(chunks)
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
media_root = Path(settings.MEDIA_ROOT).resolve()
try:
chroma_path.relative_to(media_root)
except ValueError as exc:
raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。") from exc
if chroma_path.exists():
shutil.rmtree(chroma_path)
def clear_chroma_system_cache() -> None:
try:
from chromadb.api.shared_system_client import SharedSystemClient
except Exception:
return
SharedSystemClient.clear_system_cache()

View File

@@ -25,6 +25,7 @@ from .views import (
knowledge_base_document_detail, knowledge_base_document_detail,
knowledge_base_document_index, knowledge_base_document_index,
knowledge_base_documents, knowledge_base_documents,
knowledge_base_rebuild_index,
knowledge_base_search, knowledge_base_search,
knowledge_base_status, knowledge_base_status,
) )
@@ -121,6 +122,11 @@ urlpatterns = [
knowledge_base_search, knowledge_base_search,
name="knowledge_base_search", name="knowledge_base_search",
), ),
path(
"api/review-agent/knowledge-base/rebuild-index/",
knowledge_base_rebuild_index,
name="knowledge_base_rebuild_index",
),
path( path(
"api/review-agent/knowledge-base/documents/", "api/review-agent/knowledge-base/documents/",
knowledge_base_documents, knowledge_base_documents,

View File

@@ -1,6 +1,8 @@
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.conf import settings
from django.db.models import Count, Q, Sum from django.db.models import Count, Q, Sum
import json import json
from pathlib import Path
from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
@@ -27,6 +29,9 @@ from .knowledge_base import (
) )
from .models import KnowledgeBaseDocument from .models import KnowledgeBaseDocument
from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates
from .regulatory_review.services.rag_embedding import get_embedding_provider
from .regulatory_review.services.rag_index import build_chroma_index
from .regulatory_review.services.rule_loader import load_rule_file
@login_required @login_required
@@ -151,6 +156,24 @@ def knowledge_base_status(request: HttpRequest) -> JsonResponse:
return JsonResponse(build_knowledge_base_context_for_user(request.user)) return JsonResponse(build_knowledge_base_context_for_user(request.user))
@login_required
@require_http_methods(["POST"])
def knowledge_base_rebuild_index(request: HttpRequest) -> JsonResponse:
payload = rebuild_knowledge_base_index()
return JsonResponse({"knowledge_base": build_knowledge_base_context_for_user(request.user), **payload})
def rebuild_knowledge_base_index() -> dict[str, object]:
rule_set = load_rule_file()
source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
chunk_count = build_chroma_index(
source_dir=source_dir,
embedding_provider=get_embedding_provider(),
reset=True,
)
return {"chunk_count": chunk_count}
@login_required @login_required
@require_http_methods(["POST"]) @require_http_methods(["POST"])
def knowledge_base_search(request: HttpRequest) -> JsonResponse: def knowledge_base_search(request: HttpRequest) -> JsonResponse:

View File

@@ -15,6 +15,8 @@
var sourceTable = document.getElementById("knowledgeSourceTable"); var sourceTable = document.getElementById("knowledgeSourceTable");
var documentFileInput = document.getElementById("knowledgeDocumentFile"); var documentFileInput = document.getElementById("knowledgeDocumentFile");
var uploadDropzone = document.getElementById("knowledgeUploadDropzone"); var uploadDropzone = document.getElementById("knowledgeUploadDropzone");
var rebuildButton = document.getElementById("knowledgeRebuildIndexButton");
var rebuildStatus = document.getElementById("knowledgeRebuildStatus");
function csrfToken() { function csrfToken() {
var cookie = document.cookie.split("; ").find(function (item) { var cookie = document.cookie.split("; ").find(function (item) {
@@ -68,6 +70,17 @@
return response.json(); return response.json();
} }
async function rebuildIndex() {
var response = await fetch(page.getAttribute("data-rebuild-url"), {
method: "POST",
headers: { "X-CSRFToken": csrfToken() },
});
if (!response.ok) {
throw new Error("法规索引重建失败。");
}
return response.json();
}
function renderResults(payload) { function renderResults(payload) {
if (!results) { if (!results) {
return; return;
@@ -196,6 +209,59 @@
}); });
} }
async function handleRebuild(trigger) {
if (!page.getAttribute("data-rebuild-url")) {
return;
}
var originalText = trigger ? trigger.textContent : "";
if (trigger) {
trigger.disabled = true;
trigger.textContent = "入库中";
}
if (rebuildButton && trigger !== rebuildButton) {
rebuildButton.disabled = true;
}
if (rebuildStatus) {
rebuildStatus.textContent = "正在重建法规 RAG 索引...";
}
try {
var payload = await rebuildIndex();
if (rebuildStatus) {
rebuildStatus.textContent = "重建完成,入库片段 " + (payload.chunk_count || 0) + " 个。";
}
window.setTimeout(function () {
window.location.reload();
}, 600);
} catch (error) {
if (rebuildStatus) {
rebuildStatus.textContent = error.message || "法规索引重建失败。";
}
if (trigger) {
trigger.disabled = false;
trigger.textContent = originalText;
}
if (rebuildButton) {
rebuildButton.disabled = false;
}
}
}
if (rebuildButton) {
rebuildButton.addEventListener("click", function () {
handleRebuild(rebuildButton);
});
}
if (sourceTable) {
sourceTable.addEventListener("click", function (event) {
var button = event.target.closest("[data-source-action='index']");
if (!button) {
return;
}
handleRebuild(button);
});
}
if (searchForm && queryInput) { if (searchForm && queryInput) {
searchForm.addEventListener("submit", async function (event) { searchForm.addEventListener("submit", async function (event) {
event.preventDefault(); event.preventDefault();

View File

@@ -32,6 +32,7 @@
class="knowledge-page" class="knowledge-page"
data-document-url="{% url 'knowledge_base_document_list' %}" data-document-url="{% url 'knowledge_base_document_list' %}"
data-search-url="{% url 'knowledge_base_search' %}" data-search-url="{% url 'knowledge_base_search' %}"
data-rebuild-url="{% url 'knowledge_base_rebuild_index' %}"
> >
<header class="attachment-manager-hero attachment-manager-toolbar"> <header class="attachment-manager-hero attachment-manager-toolbar">
<div> <div>
@@ -96,9 +97,10 @@
</div> </div>
</dl> </dl>
<p class="knowledge-panel-note">{{ knowledge_base.status.message }}</p> <p class="knowledge-panel-note">{{ knowledge_base.status.message }}</p>
<p class="upload-status" id="knowledgeRebuildStatus"></p>
<div class="knowledge-form-actions"> <div class="knowledge-form-actions">
<button type="button" onclick="window.location.reload()">刷新状态</button> <button type="button" onclick="window.location.reload()">刷新状态</button>
<button type="button" disabled>重建索引</button> <button type="button" id="knowledgeRebuildIndexButton">重建索引</button>
</div> </div>
</section> </section>
@@ -182,6 +184,7 @@
<th>类型</th> <th>类型</th>
<th>大小</th> <th>大小</th>
<th>索引</th> <th>索引</th>
<th>操作</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
@@ -192,10 +195,13 @@
<td>{{ source.suffix }}</td> <td>{{ source.suffix }}</td>
<td>{{ source.size }} bytes</td> <td>{{ source.size }} bytes</td>
<td>{{ source.indexed_label }}</td> <td>{{ source.indexed_label }}</td>
<td class="attachment-actions">
<button type="button" data-source-action="index">手动入库</button>
</td>
</tr> </tr>
{% empty %} {% empty %}
<tr> <tr>
<td colspan="5" class="table-empty">暂无法规材料</td> <td colspan="6" class="table-empty">暂无法规材料</td>
</tr> </tr>
{% endfor %} {% endfor %}
</tbody> </tbody>
@@ -209,5 +215,5 @@
{% endblock %} {% endblock %}
{% block scripts %} {% block scripts %}
<script src="{% static 'js/knowledge_base.js' %}?v=20260608-kb5"></script> <script src="{% static 'js/knowledge_base.js' %}?v=20260608-kb6"></script>
{% endblock %} {% endblock %}

View File

@@ -3,6 +3,7 @@ from django.core.files.uploadedfile import SimpleUploadedFile
from django.urls import reverse from django.urls import reverse
from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base
from review_agent.views import rebuild_knowledge_base_index
from review_agent.models import KnowledgeBaseDocument from review_agent.models import KnowledgeBaseDocument
@@ -16,6 +17,7 @@ def test_knowledge_base_context_reports_rule_and_sources():
assert context["rule"]["requirement_count"] > 0 assert context["rule"]["requirement_count"] > 0
assert context["source_count"] > 0 assert context["source_count"] > 0
assert context["collection_name"] == "nmpa_ivd_registration_v1" assert context["collection_name"] == "nmpa_ivd_registration_v1"
assert not any("模拟题二" in source["relative_path"] for source in context["sources"])
def test_knowledge_base_page_requires_login(client): def test_knowledge_base_page_requires_login(client):
@@ -36,6 +38,11 @@ def test_knowledge_base_page_renders_for_user(client, django_user_model):
content = response.content.decode("utf-8") content = response.content.decode("utf-8")
tabbar = content[content.index('<div class="tabbar"') : content.index("</div>", content.index('<div class="tabbar"'))] tabbar = content[content.index('<div class="tabbar"') : content.index("</div>", content.index('<div class="tabbar"'))]
assert tabbar.index("审核智能体") < tabbar.index("知识库管理") < tabbar.index("附件管理") assert tabbar.index("审核智能体") < tabbar.index("知识库管理") < tabbar.index("附件管理")
assert "data-rebuild-url=" in content
assert 'id="knowledgeRebuildIndexButton"' in content
assert "重建索引" in content
assert 'data-source-action="index"' in content
assert "手动入库" in content
def test_knowledge_base_status_api(client, django_user_model): def test_knowledge_base_status_api(client, django_user_model):
@@ -48,6 +55,53 @@ def test_knowledge_base_status_api(client, django_user_model):
assert response.json()["rule"]["code"] == "nmpa_ivd_registration_v1" assert response.json()["rule"]["code"] == "nmpa_ivd_registration_v1"
def test_knowledge_base_rebuild_index_api(client, django_user_model, monkeypatch):
user = django_user_model.objects.create_user(username="owner", password="pass")
client.force_login(user)
calls = []
monkeypatch.setattr(
"review_agent.views.rebuild_knowledge_base_index",
lambda: calls.append("rebuild") or {"chunk_count": 12},
)
response = client.post(reverse("knowledge_base_rebuild_index"))
assert response.status_code == 200
assert response.json()["chunk_count"] == 12
assert response.json()["knowledge_base"]["collection"]["count"] >= 0
assert calls == ["rebuild"]
def test_rebuild_knowledge_base_index_requests_reset(settings, tmp_path, monkeypatch):
settings.MEDIA_ROOT = tmp_path
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
settings.REGULATORY_RAG_CHROMA_PATH.mkdir()
stale_file = settings.REGULATORY_RAG_CHROMA_PATH / "chroma.sqlite3"
stale_file.write_text("stale", encoding="utf-8")
calls = []
monkeypatch.setattr("review_agent.views.load_rule_file", lambda: {"source_material_dir": "docs/0.原始材料"})
monkeypatch.setattr("review_agent.views.get_embedding_provider", lambda: "provider")
monkeypatch.setattr(
"review_agent.views.build_chroma_index",
lambda source_dir, embedding_provider, reset=False: calls.append(
{
"source_dir": source_dir,
"embedding_provider": embedding_provider,
"reset": reset,
}
)
or 8,
)
payload = rebuild_knowledge_base_index()
assert payload["chunk_count"] == 8
assert calls[0]["embedding_provider"] == "provider"
assert calls[0]["reset"] is True
def test_knowledge_base_search_rejects_blank_query(): def test_knowledge_base_search_rejects_blank_query():
payload = search_knowledge_base("") payload = search_knowledge_base("")
@@ -103,6 +157,8 @@ def test_knowledge_base_search_api_returns_payload(client, django_user_model):
def test_knowledge_base_document_crud_api(client, settings, tmp_path, django_user_model): def test_knowledge_base_document_crud_api(client, settings, tmp_path, django_user_model):
settings.MEDIA_ROOT = tmp_path settings.MEDIA_ROOT = tmp_path
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
settings.REGULATORY_RAG_PROVIDER = "deterministic"
user = django_user_model.objects.create_user(username="owner", password="pass") user = django_user_model.objects.create_user(username="owner", password="pass")
client.force_login(user) client.force_login(user)
@@ -199,6 +255,8 @@ def test_knowledge_base_document_api_is_scoped_to_owner(client, django_user_mode
def test_knowledge_base_document_manual_index_api(client, settings, tmp_path, django_user_model): def test_knowledge_base_document_manual_index_api(client, settings, tmp_path, django_user_model):
settings.MEDIA_ROOT = tmp_path settings.MEDIA_ROOT = tmp_path
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
settings.REGULATORY_RAG_PROVIDER = "deterministic"
user = django_user_model.objects.create_user(username="owner", password="pass") user = django_user_model.objects.create_user(username="owner", password="pass")
client.force_login(user) client.force_login(user)
source_path = tmp_path / "manual.md" source_path = tmp_path / "manual.md"

View File

@@ -1,3 +1,5 @@
import sys
import pytest import pytest
from review_agent.regulatory_review.services.rag_citation import ( from review_agent.regulatory_review.services.rag_citation import (
@@ -7,6 +9,7 @@ from review_agent.regulatory_review.services.rag_citation import (
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text from review_agent.regulatory_review.services.rag_index import chunk_text
from review_agent.regulatory_review.services.rag_index import collect_source_chunks from review_agent.regulatory_review.services.rag_index import collect_source_chunks
from review_agent.regulatory_review.services.rag_index import build_chroma_index
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch): def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
@@ -86,3 +89,141 @@ def test_collect_source_chunks_requires_attachment4_extraction(monkeypatch, tmp_
with pytest.raises(RuntimeError, match="附件 4"): with pytest.raises(RuntimeError, match="附件 4"):
collect_source_chunks(source_dir) collect_source_chunks(source_dir)
def test_collect_source_chunks_excludes_demo_agent_materials(monkeypatch, tmp_path):
source_dir = tmp_path / "sources"
source_dir.mkdir()
demo_dir = source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent"
demo_dir.mkdir()
(demo_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.md").write_text("题目材料", encoding="utf-8")
(source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.docx").write_bytes(b"demo")
real_source = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
real_source.write_bytes(b"rule")
def fake_extract(path):
return "附件4 正文" if path == real_source else "不应被抽取"
monkeypatch.setattr("review_agent.regulatory_review.services.rag_index.extract_text_from_path", fake_extract)
chunks = collect_source_chunks(source_dir)
assert chunks
assert all("模拟题二" not in chunk.metadata["source"] for chunk in chunks)
def test_build_chroma_index_reset_recreates_collection_without_deleting_index_dir(settings, monkeypatch, tmp_path):
settings.MEDIA_ROOT = tmp_path
persist_path = tmp_path / "chroma"
persist_path.mkdir()
stale_file = persist_path / "chroma.sqlite3"
stale_file.write_text("stale", encoding="utf-8")
source_dir = tmp_path / "sources"
source_dir.mkdir()
(source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")
client_states = []
deleted_collections = []
class FakeCollection:
def upsert(self, **kwargs):
return None
class FakeClient:
def __init__(self, path):
client_states.append({"path": path, "stale_exists": stale_file.exists()})
def delete_collection(self, name):
deleted_collections.append(name)
def get_or_create_collection(self, name):
return FakeCollection()
class FakeSharedSystemClient:
@staticmethod
def clear_system_cache():
client_states.append({"path": "cache-cleared", "stale_exists": stale_file.exists()})
monkeypatch.setitem(sys.modules, "chromadb", type("FakeChromaModule", (), {"PersistentClient": FakeClient}))
monkeypatch.setitem(
sys.modules,
"chromadb.api.shared_system_client",
type("FakeSharedSystemClientModule", (), {"SharedSystemClient": FakeSharedSystemClient}),
)
count = build_chroma_index(
source_dir=source_dir,
embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
persist_path=persist_path,
collection_name="test",
reset=True,
)
assert count == 1
assert client_states == [
{"path": str(persist_path), "stale_exists": True},
{"path": "cache-cleared", "stale_exists": True},
{"path": str(persist_path), "stale_exists": True},
]
assert stale_file.exists()
assert deleted_collections == ["test"]
def test_build_chroma_index_reset_clears_bad_index_dir_after_chroma_cache_reset(settings, monkeypatch, tmp_path):
settings.MEDIA_ROOT = tmp_path
persist_path = tmp_path / "chroma"
persist_path.mkdir()
stale_file = persist_path / "chroma.sqlite3"
stale_file.write_text("stale", encoding="utf-8")
source_dir = tmp_path / "sources"
source_dir.mkdir()
(source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")
events = []
class FakeCollection:
def upsert(self, **kwargs):
return None
class BrokenThenFreshClient:
attempts = 0
def __init__(self, path):
BrokenThenFreshClient.attempts += 1
events.append(("client", BrokenThenFreshClient.attempts, stale_file.exists()))
if BrokenThenFreshClient.attempts == 1:
raise ValueError("Could not connect to tenant default_tenant")
def get_or_create_collection(self, name):
return FakeCollection()
class FakeSharedSystemClient:
@staticmethod
def clear_system_cache():
events.append(("clear_cache", stale_file.exists()))
fake_chromadb = type(
"FakeChromaModule",
(),
{"PersistentClient": BrokenThenFreshClient},
)
monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
monkeypatch.setitem(
sys.modules,
"chromadb.api.shared_system_client",
type("FakeSharedSystemClientModule", (), {"SharedSystemClient": FakeSharedSystemClient}),
)
count = build_chroma_index(
source_dir=source_dir,
embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
persist_path=persist_path,
collection_name="test",
reset=True,
)
assert count == 1
assert events == [
("client", 1, True),
("clear_cache", True),
("client", 2, False),
]
assert not stale_file.exists()