diff --git a/review_agent/knowledge_base.py b/review_agent/knowledge_base.py index 12edff7..79f3aba 100644 --- a/review_agent/knowledge_base.py +++ b/review_agent/knowledge_base.py @@ -10,8 +10,8 @@ from django.core.files.uploadedfile import UploadedFile from review_agent.models import KnowledgeBaseDocument from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations -from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider -from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path +from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider +from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file @@ -78,6 +78,8 @@ def list_source_documents(source_dir: Path) -> list[dict[str, Any]]: continue suffix = path.suffix.lower() relative_path = str(path.relative_to(source_dir)) + if is_excluded_source_path(relative_path): + continue indexed_chunk_count = source_chunk_counts.get(relative_path, 0) documents.append( { @@ -101,7 +103,7 @@ def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]: try: results = retrieve_citations( normalized, - embedding_provider=DeterministicEmbeddingProvider(), + embedding_provider=get_embedding_provider(), n_results=n_results, ) except RagIndexUnavailable as exc: @@ -210,7 +212,7 @@ def index_managed_document(document: KnowledgeBaseDocument) -> int: return 0 collection = _load_chroma_collection() texts = [chunk.text for chunk in chunks] - embeddings = DeterministicEmbeddingProvider()(texts) + embeddings = get_embedding_provider()(texts) ids = [ hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest() for chunk in chunks diff --git 
a/review_agent/management/commands/regulatory_rag_build.py b/review_agent/management/commands/regulatory_rag_build.py index b8be556..c2263aa 100644 --- a/review_agent/management/commands/regulatory_rag_build.py +++ b/review_agent/management/commands/regulatory_rag_build.py @@ -23,7 +23,7 @@ class Command(BaseCommand): raise CommandError(f"法规材料目录不存在:{source_dir}") try: provider = get_embedding_provider(options["provider"]) - count = build_chroma_index(source_dir=source_dir, embedding_provider=provider) + count = build_chroma_index(source_dir=source_dir, embedding_provider=provider, reset=True) except Exception as exc: raise CommandError(str(exc)) from exc self.stdout.write( diff --git a/review_agent/regulatory_review/services/rag_index.py b/review_agent/regulatory_review/services/rag_index.py index be80cf8..3e58826 100644 --- a/review_agent/regulatory_review/services/rag_index.py +++ b/review_agent/regulatory_review/services/rag_index.py @@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction logger = logging.getLogger("review_agent.regulatory_review.rag_index") +EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent") + @dataclass(frozen=True) class TextChunk: @@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]: for path in sorted(source_dir.rglob("*")): if not path.is_file(): continue + if is_excluded_source_path(path.relative_to(source_dir)): + continue try: text = extract_text_from_path(path) except RuntimeError as exc: @@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]: return chunks +def is_excluded_source_path(path: Path | str) -> bool: + normalized = str(path) + return any(keyword in normalized for keyword in EXCLUDED_SOURCE_KEYWORDS) + + def _is_attachment4(path: Path) -> bool: normalized = path.name.replace(" ", "") return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized @@ -249,6 +258,7 @@ def build_chroma_index( embedding_provider: EmbeddingFunction, persist_path: Path | 
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
    """Delete the on-disk Chroma index directory, if it exists.

    Args:
        persist_path: Directory to remove. Falls back to
            ``settings.REGULATORY_RAG_CHROMA_PATH`` when omitted or empty.

    Raises:
        RuntimeError: If the resolved directory is not strictly inside
            ``settings.MEDIA_ROOT`` (this includes being ``MEDIA_ROOT`` itself).
    """
    chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
    media_root = Path(settings.MEDIA_ROOT).resolve()
    # ``relative_to`` also succeeds when both paths are identical, which would
    # let the rmtree below wipe the entire MEDIA_ROOT. ``parents`` never
    # contains the path itself, so this check demands a strict descendant.
    if media_root not in chroma_path.parents:
        raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。")
    if chroma_path.exists():
        shutil.rmtree(chroma_path)
def rebuild_knowledge_base_index() -> dict[str, object]:
    """Rebuild the regulatory RAG index from the rule file's source directory.

    The source directory is ``settings.BASE_DIR`` joined with the rule set's
    ``source_material_dir`` entry. The index is rebuilt with ``reset=True``
    using the provider returned by ``get_embedding_provider()``.

    Returns:
        A dict with one key, ``chunk_count``, holding the value returned by
        ``build_chroma_index``.

    Raises:
        FileNotFoundError: If the source material directory does not exist.
    """
    rule_set = load_rule_file()
    source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
    # A reset build drops the existing collection before re-indexing. With a
    # missing directory that would leave an empty index and report success, so
    # fail first, as the regulatory_rag_build management command does.
    if not source_dir.is_dir():
        raise FileNotFoundError(f"法规材料目录不存在:{source_dir}")
    chunk_count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=get_embedding_provider(),
        reset=True,
    )
    return {"chunk_count": chunk_count}
/**
 * Trigger a rebuild of the regulatory RAG index and reflect progress in the UI.
 *
 * Only one rebuild may be in flight at a time: every row button in the source
 * table starts the same index reset, so overlapping clicks are ignored while a
 * request is pending.
 *
 * @param {HTMLElement|null} trigger Button that started the rebuild; it is
 *   disabled and relabelled while the request runs.
 * @returns {Promise<void>}
 */
async function handleRebuild(trigger) {
  if (!page.getAttribute("data-rebuild-url")) {
    return;
  }
  // In-flight guard. It stays set on success because the page reloads shortly
  // afterwards; it is cleared only on failure so the user can retry.
  if (page.getAttribute("data-rebuild-busy") === "true") {
    return;
  }
  page.setAttribute("data-rebuild-busy", "true");

  var originalText = trigger ? trigger.textContent : "";
  if (trigger) {
    trigger.disabled = true;
    trigger.textContent = "入库中";
  }
  if (rebuildButton && trigger !== rebuildButton) {
    rebuildButton.disabled = true;
  }
  if (rebuildStatus) {
    rebuildStatus.textContent = "正在重建法规 RAG 索引...";
  }
  try {
    var payload = await rebuildIndex();
    if (rebuildStatus) {
      rebuildStatus.textContent = "重建完成,入库片段 " + (payload.chunk_count || 0) + " 个。";
    }
    // Give the user a moment to read the result before refreshing the page.
    window.setTimeout(function () {
      window.location.reload();
    }, 600);
  } catch (error) {
    page.removeAttribute("data-rebuild-busy");
    if (rebuildStatus) {
      rebuildStatus.textContent = error.message || "法规索引重建失败。";
    }
    if (trigger) {
      trigger.disabled = false;
      trigger.textContent = originalText;
    }
    if (rebuildButton) {
      rebuildButton.disabled = false;
    }
  }
}
@@ -96,9 +97,10 @@

{{ knowledge_base.status.message }}

+

- +
@@ -182,6 +184,7 @@ 类型 大小 索引 + 操作 @@ -192,10 +195,13 @@ {{ source.suffix }} {{ source.size }} bytes {{ source.indexed_label }} + + + {% empty %} - 暂无法规材料 + 暂无法规材料 {% endfor %} @@ -209,5 +215,5 @@ {% endblock %} {% block scripts %} - + {% endblock %} diff --git a/tests/test_knowledge_base.py b/tests/test_knowledge_base.py index 936d822..21df46f 100644 --- a/tests/test_knowledge_base.py +++ b/tests/test_knowledge_base.py @@ -3,6 +3,7 @@ from django.core.files.uploadedfile import SimpleUploadedFile from django.urls import reverse from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base +from review_agent.views import rebuild_knowledge_base_index from review_agent.models import KnowledgeBaseDocument @@ -16,6 +17,7 @@ def test_knowledge_base_context_reports_rule_and_sources(): assert context["rule"]["requirement_count"] > 0 assert context["source_count"] > 0 assert context["collection_name"] == "nmpa_ivd_registration_v1" + assert not any("模拟题二" in source["relative_path"] for source in context["sources"]) def test_knowledge_base_page_requires_login(client): @@ -36,6 +38,11 @@ def test_knowledge_base_page_renders_for_user(client, django_user_model): content = response.content.decode("utf-8") tabbar = content[content.index('
", content.index('
def test_rebuild_knowledge_base_index_requests_reset(settings, tmp_path, monkeypatch):
    """The rebuild helper forwards the rule-set source dir, provider and reset=True."""
    from pathlib import Path

    # Safety net only: build_chroma_index is stubbed below, but pointing the
    # index settings at tmp_path keeps a reset away from a real index
    # directory should the stub ever stop applying.
    settings.MEDIA_ROOT = tmp_path
    settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
    calls = []

    def fake_build_chroma_index(source_dir, embedding_provider, reset=False):
        calls.append(
            {
                "source_dir": source_dir,
                "embedding_provider": embedding_provider,
                "reset": reset,
            }
        )
        return 8

    monkeypatch.setattr("review_agent.views.load_rule_file", lambda: {"source_material_dir": "docs/0.原始材料"})
    monkeypatch.setattr("review_agent.views.get_embedding_provider", lambda: "provider")
    monkeypatch.setattr("review_agent.views.build_chroma_index", fake_build_chroma_index)

    payload = rebuild_knowledge_base_index()

    assert payload == {"chunk_count": 8}
    # Exactly one build, aimed at BASE_DIR/<source_material_dir>, with reset.
    assert calls == [
        {
            "source_dir": Path(settings.BASE_DIR) / "docs/0.原始材料",
            "embedding_provider": "provider",
            "reset": True,
        }
    ]
def test_collect_source_chunks_excludes_demo_agent_materials(monkeypatch, tmp_path):
    """Demo-question materials are skipped both as a directory and as a file."""
    demo_name = "【模拟题二】试剂盒临床注册文件准备与审核Agent"
    source_dir = tmp_path / "sources"
    nested_demo_dir = source_dir / demo_name
    nested_demo_dir.mkdir(parents=True)
    (nested_demo_dir / f"{demo_name}.md").write_text("题目材料", encoding="utf-8")
    (source_dir / f"{demo_name}.docx").write_bytes(b"demo")
    real_source = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    real_source.write_bytes(b"rule")

    def fake_extract(path):
        # Only the genuine regulation file yields real text.
        if path == real_source:
            return "附件4 正文"
        return "不应被抽取"

    monkeypatch.setattr("review_agent.regulatory_review.services.rag_index.extract_text_from_path", fake_extract)

    chunks = collect_source_chunks(source_dir)

    assert chunks
    chunk_sources = [chunk.metadata["source"] for chunk in chunks]
    assert not any("模拟题二" in source for source in chunk_sources)
def test_build_chroma_index_reset_clears_bad_index_dir_after_chroma_cache_reset(settings, monkeypatch, tmp_path):
    """A failing client on a reset build clears the cache, wipes the index dir, then retries."""
    settings.MEDIA_ROOT = tmp_path
    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")
    events = []
    attempts = 0

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class BrokenThenFreshClient:
        def __init__(self, path):
            # The first construction fails to mimic a corrupted index; later
            # ones succeed.
            nonlocal attempts
            attempts += 1
            events.append(("client", attempts, stale_file.exists()))
            if attempts == 1:
                raise ValueError("Could not connect to tenant default_tenant")

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            events.append(("clear_cache", stale_file.exists()))

    fake_chromadb = type("FakeChromaModule", (), {"PersistentClient": BrokenThenFreshClient})
    fake_shared_module = type(
        "FakeSharedSystemClientModule",
        (),
        {"SharedSystemClient": FakeSharedSystemClient},
    )
    monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
    monkeypatch.setitem(sys.modules, "chromadb.api.shared_system_client", fake_shared_module)

    def embed(texts):
        return [[0.1, 0.2] for _ in texts]

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=embed,
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1
    expected_events = [
        ("client", 1, True),
        ("clear_cache", True),
        ("client", 2, False),
    ]
    assert events == expected_events
    assert not stale_file.exists()