fix(kb): 完善知识库入库和重建索引
This commit is contained in:
@@ -10,8 +10,8 @@ from django.core.files.uploadedfile import UploadedFile
|
|||||||
|
|
||||||
from review_agent.models import KnowledgeBaseDocument
|
from review_agent.models import KnowledgeBaseDocument
|
||||||
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
|
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
|
||||||
from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider
|
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
|
||||||
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path
|
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path
|
||||||
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
|
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
|
||||||
|
|
||||||
|
|
||||||
@@ -78,6 +78,8 @@ def list_source_documents(source_dir: Path) -> list[dict[str, Any]]:
|
|||||||
continue
|
continue
|
||||||
suffix = path.suffix.lower()
|
suffix = path.suffix.lower()
|
||||||
relative_path = str(path.relative_to(source_dir))
|
relative_path = str(path.relative_to(source_dir))
|
||||||
|
if is_excluded_source_path(relative_path):
|
||||||
|
continue
|
||||||
indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
|
indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
|
||||||
documents.append(
|
documents.append(
|
||||||
{
|
{
|
||||||
@@ -101,7 +103,7 @@ def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
|
|||||||
try:
|
try:
|
||||||
results = retrieve_citations(
|
results = retrieve_citations(
|
||||||
normalized,
|
normalized,
|
||||||
embedding_provider=DeterministicEmbeddingProvider(),
|
embedding_provider=get_embedding_provider(),
|
||||||
n_results=n_results,
|
n_results=n_results,
|
||||||
)
|
)
|
||||||
except RagIndexUnavailable as exc:
|
except RagIndexUnavailable as exc:
|
||||||
@@ -210,7 +212,7 @@ def index_managed_document(document: KnowledgeBaseDocument) -> int:
|
|||||||
return 0
|
return 0
|
||||||
collection = _load_chroma_collection()
|
collection = _load_chroma_collection()
|
||||||
texts = [chunk.text for chunk in chunks]
|
texts = [chunk.text for chunk in chunks]
|
||||||
embeddings = DeterministicEmbeddingProvider()(texts)
|
embeddings = get_embedding_provider()(texts)
|
||||||
ids = [
|
ids = [
|
||||||
hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
|
hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
|
||||||
for chunk in chunks
|
for chunk in chunks
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class Command(BaseCommand):
|
|||||||
raise CommandError(f"法规材料目录不存在:{source_dir}")
|
raise CommandError(f"法规材料目录不存在:{source_dir}")
|
||||||
try:
|
try:
|
||||||
provider = get_embedding_provider(options["provider"])
|
provider = get_embedding_provider(options["provider"])
|
||||||
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
|
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider, reset=True)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise CommandError(str(exc)) from exc
|
raise CommandError(str(exc)) from exc
|
||||||
self.stdout.write(
|
self.stdout.write(
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction
|
|||||||
|
|
||||||
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||||||
|
|
||||||
|
# Substrings that mark a source path as excluded: is_excluded_source_path()
# returns True for any path whose string form contains one of these keywords.
EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TextChunk:
|
class TextChunk:
|
||||||
@@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
|||||||
for path in sorted(source_dir.rglob("*")):
|
for path in sorted(source_dir.rglob("*")):
|
||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
continue
|
continue
|
||||||
|
if is_excluded_source_path(path.relative_to(source_dir)):
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
text = extract_text_from_path(path)
|
text = extract_text_from_path(path)
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
@@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def is_excluded_source_path(path: Path | str) -> bool:
    """Report whether ``path`` contains any excluded-source keyword.

    Args:
        path: A path object or string; it is compared in its ``str()`` form.

    Returns:
        True when at least one entry of ``EXCLUDED_SOURCE_KEYWORDS`` occurs
        as a substring of the path text, otherwise False.
    """
    path_text = str(path)
    for keyword in EXCLUDED_SOURCE_KEYWORDS:
        if keyword in path_text:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
def _is_attachment4(path: Path) -> bool:
|
def _is_attachment4(path: Path) -> bool:
|
||||||
normalized = path.name.replace(" ", "")
|
normalized = path.name.replace(" ", "")
|
||||||
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
|
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
|
||||||
@@ -249,6 +258,7 @@ def build_chroma_index(
|
|||||||
embedding_provider: EmbeddingFunction,
|
embedding_provider: EmbeddingFunction,
|
||||||
persist_path: Path | None = None,
|
persist_path: Path | None = None,
|
||||||
collection_name: str | None = None,
|
collection_name: str | None = None,
|
||||||
|
reset: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
try:
|
try:
|
||||||
import chromadb
|
import chromadb
|
||||||
@@ -259,7 +269,22 @@ def build_chroma_index(
|
|||||||
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
||||||
persist_path.mkdir(parents=True, exist_ok=True)
|
persist_path.mkdir(parents=True, exist_ok=True)
|
||||||
chunks = collect_source_chunks(source_dir)
|
chunks = collect_source_chunks(source_dir)
|
||||||
client = chromadb.PersistentClient(path=str(persist_path))
|
try:
|
||||||
|
client = chromadb.PersistentClient(path=str(persist_path))
|
||||||
|
except Exception:
|
||||||
|
if not reset:
|
||||||
|
raise
|
||||||
|
clear_chroma_system_cache()
|
||||||
|
clear_chroma_index_dir(persist_path)
|
||||||
|
persist_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
client = chromadb.PersistentClient(path=str(persist_path))
|
||||||
|
if reset:
|
||||||
|
try:
|
||||||
|
client.delete_collection(collection_name)
|
||||||
|
clear_chroma_system_cache()
|
||||||
|
client = chromadb.PersistentClient(path=str(persist_path))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
collection = client.get_or_create_collection(collection_name)
|
collection = client.get_or_create_collection(collection_name)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return 0
|
return 0
|
||||||
@@ -276,3 +301,22 @@ def build_chroma_index(
|
|||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
)
|
)
|
||||||
return len(chunks)
|
return len(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
    """Delete the on-disk Chroma index directory, if it exists.

    Args:
        persist_path: Directory to remove. When omitted (or falsy),
            ``settings.REGULATORY_RAG_CHROMA_PATH`` is used instead.

    Raises:
        RuntimeError: If the resolved directory lies outside
            ``settings.MEDIA_ROOT`` or is ``MEDIA_ROOT`` itself.
    """
    chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
    media_root = Path(settings.MEDIA_ROOT).resolve()
    try:
        chroma_path.relative_to(media_root)
    except ValueError as exc:
        raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。") from exc
    # relative_to() also succeeds when both paths are identical; removing the
    # tree in that case would wipe the whole media root, so refuse it.
    if chroma_path == media_root:
        raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。")
    if chroma_path.exists():
        shutil.rmtree(chroma_path)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_chroma_system_cache() -> None:
    """Clear Chroma's shared system-client cache when it can be imported.

    If ``SharedSystemClient`` cannot be imported for any reason, the function
    returns without doing anything. Errors raised by the cache clearing call
    itself are not suppressed.
    """
    try:
        from chromadb.api.shared_system_client import SharedSystemClient
    except Exception:
        # Best effort: nothing to clear when the client class is unavailable.
        return None
    else:
        SharedSystemClient.clear_system_cache()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from .views import (
|
|||||||
knowledge_base_document_detail,
|
knowledge_base_document_detail,
|
||||||
knowledge_base_document_index,
|
knowledge_base_document_index,
|
||||||
knowledge_base_documents,
|
knowledge_base_documents,
|
||||||
|
knowledge_base_rebuild_index,
|
||||||
knowledge_base_search,
|
knowledge_base_search,
|
||||||
knowledge_base_status,
|
knowledge_base_status,
|
||||||
)
|
)
|
||||||
@@ -121,6 +122,11 @@ urlpatterns = [
|
|||||||
knowledge_base_search,
|
knowledge_base_search,
|
||||||
name="knowledge_base_search",
|
name="knowledge_base_search",
|
||||||
),
|
),
|
||||||
|
path(
|
||||||
|
"api/review-agent/knowledge-base/rebuild-index/",
|
||||||
|
knowledge_base_rebuild_index,
|
||||||
|
name="knowledge_base_rebuild_index",
|
||||||
|
),
|
||||||
path(
|
path(
|
||||||
"api/review-agent/knowledge-base/documents/",
|
"api/review-agent/knowledge-base/documents/",
|
||||||
knowledge_base_documents,
|
knowledge_base_documents,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from django.contrib.auth.decorators import login_required
|
from django.contrib.auth.decorators import login_required
|
||||||
|
from django.conf import settings
|
||||||
from django.db.models import Count, Q, Sum
|
from django.db.models import Count, Q, Sum
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse
|
from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse
|
||||||
from django.shortcuts import redirect, render
|
from django.shortcuts import redirect, render
|
||||||
@@ -27,6 +29,9 @@ from .knowledge_base import (
|
|||||||
)
|
)
|
||||||
from .models import KnowledgeBaseDocument
|
from .models import KnowledgeBaseDocument
|
||||||
from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates
|
from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates
|
||||||
|
from .regulatory_review.services.rag_embedding import get_embedding_provider
|
||||||
|
from .regulatory_review.services.rag_index import build_chroma_index
|
||||||
|
from .regulatory_review.services.rule_loader import load_rule_file
|
||||||
|
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
@@ -151,6 +156,24 @@ def knowledge_base_status(request: HttpRequest) -> JsonResponse:
|
|||||||
return JsonResponse(build_knowledge_base_context_for_user(request.user))
|
return JsonResponse(build_knowledge_base_context_for_user(request.user))
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
@require_http_methods(["POST"])
def knowledge_base_rebuild_index(request: HttpRequest) -> JsonResponse:
    """Rebuild the knowledge-base index and return the refreshed context as JSON.

    The response carries the user's knowledge-base context under the
    ``knowledge_base`` key, merged with the keys returned by
    ``rebuild_knowledge_base_index()``.
    """
    rebuild_result = rebuild_knowledge_base_index()
    # The context is built after the rebuild has run.
    response_body = {"knowledge_base": build_knowledge_base_context_for_user(request.user)}
    response_body.update(rebuild_result)
    return JsonResponse(response_body)
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_knowledge_base_index() -> dict[str, object]:
    """Rebuild the Chroma index from the rule set's source-material directory.

    The directory is the loaded rule file's ``source_material_dir`` entry,
    resolved against ``settings.BASE_DIR``. The build is requested with
    ``reset=True`` and the default embedding provider.

    Returns:
        A dict with ``chunk_count`` set to the value returned by
        ``build_chroma_index``.
    """
    source_material_dir = load_rule_file()["source_material_dir"]
    indexed_chunks = build_chroma_index(
        source_dir=Path(settings.BASE_DIR) / source_material_dir,
        embedding_provider=get_embedding_provider(),
        reset=True,
    )
    return {"chunk_count": indexed_chunks}
|
||||||
|
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
@require_http_methods(["POST"])
|
@require_http_methods(["POST"])
|
||||||
def knowledge_base_search(request: HttpRequest) -> JsonResponse:
|
def knowledge_base_search(request: HttpRequest) -> JsonResponse:
|
||||||
|
|||||||
@@ -15,6 +15,8 @@
|
|||||||
var sourceTable = document.getElementById("knowledgeSourceTable");
|
var sourceTable = document.getElementById("knowledgeSourceTable");
|
||||||
var documentFileInput = document.getElementById("knowledgeDocumentFile");
|
var documentFileInput = document.getElementById("knowledgeDocumentFile");
|
||||||
var uploadDropzone = document.getElementById("knowledgeUploadDropzone");
|
var uploadDropzone = document.getElementById("knowledgeUploadDropzone");
|
||||||
|
var rebuildButton = document.getElementById("knowledgeRebuildIndexButton");
|
||||||
|
var rebuildStatus = document.getElementById("knowledgeRebuildStatus");
|
||||||
|
|
||||||
function csrfToken() {
|
function csrfToken() {
|
||||||
var cookie = document.cookie.split("; ").find(function (item) {
|
var cookie = document.cookie.split("; ").find(function (item) {
|
||||||
@@ -68,6 +70,17 @@
|
|||||||
return response.json();
|
return response.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// POST to the URL in the page's data-rebuild-url attribute, sending the CSRF
// token header. Resolves with the parsed JSON body; throws when the response
// status is not OK.
async function rebuildIndex() {
  var rebuildUrl = page.getAttribute("data-rebuild-url");
  var requestOptions = {
    method: "POST",
    headers: { "X-CSRFToken": csrfToken() },
  };
  var response = await fetch(rebuildUrl, requestOptions);
  if (response.ok) {
    return response.json();
  }
  throw new Error("法规索引重建失败。");
}
|
||||||
|
|
||||||
function renderResults(payload) {
|
function renderResults(payload) {
|
||||||
if (!results) {
|
if (!results) {
|
||||||
return;
|
return;
|
||||||
@@ -196,6 +209,59 @@
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run an index rebuild on behalf of `trigger` (the clicked button, if any).
// While the request is in flight the trigger and the main rebuild button are
// disabled. On success the status line is updated and the page reloads after
// 600 ms; on failure the message is shown and the buttons are restored.
async function handleRebuild(trigger) {
  var rebuildUrl = page.getAttribute("data-rebuild-url");
  if (!rebuildUrl) {
    return;
  }

  var labelBeforeRebuild = "";
  if (trigger) {
    labelBeforeRebuild = trigger.textContent;
    trigger.disabled = true;
    trigger.textContent = "入库中";
  }
  // Lock the main rebuild button too when a different button started the run.
  if (rebuildButton && trigger !== rebuildButton) {
    rebuildButton.disabled = true;
  }
  if (rebuildStatus) {
    rebuildStatus.textContent = "正在重建法规 RAG 索引...";
  }

  try {
    var payload = await rebuildIndex();
    if (rebuildStatus) {
      var chunkCount = payload.chunk_count || 0;
      rebuildStatus.textContent = "重建完成,入库片段 " + chunkCount + " 个。";
    }
    // Buttons stay disabled on success; the reload resets the page state.
    window.setTimeout(function () {
      window.location.reload();
    }, 600);
  } catch (error) {
    if (rebuildStatus) {
      rebuildStatus.textContent = error.message || "法规索引重建失败。";
    }
    if (trigger) {
      trigger.textContent = labelBeforeRebuild;
      trigger.disabled = false;
    }
    if (rebuildButton) {
      rebuildButton.disabled = false;
    }
  }
}
|
||||||
|
|
||||||
|
// The main rebuild button starts a rebuild with itself as the trigger.
if (rebuildButton) {
  rebuildButton.addEventListener("click", function () {
    handleRebuild(rebuildButton);
  });
}

// Delegated handler: a click inside the source table starts a rebuild only
// when it lands on (or within) an element marked data-source-action="index".
if (sourceTable) {
  sourceTable.addEventListener("click", function (event) {
    var indexButton = event.target.closest("[data-source-action='index']");
    if (indexButton) {
      handleRebuild(indexButton);
    }
  });
}
|
||||||
|
|
||||||
if (searchForm && queryInput) {
|
if (searchForm && queryInput) {
|
||||||
searchForm.addEventListener("submit", async function (event) {
|
searchForm.addEventListener("submit", async function (event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
class="knowledge-page"
|
class="knowledge-page"
|
||||||
data-document-url="{% url 'knowledge_base_document_list' %}"
|
data-document-url="{% url 'knowledge_base_document_list' %}"
|
||||||
data-search-url="{% url 'knowledge_base_search' %}"
|
data-search-url="{% url 'knowledge_base_search' %}"
|
||||||
|
data-rebuild-url="{% url 'knowledge_base_rebuild_index' %}"
|
||||||
>
|
>
|
||||||
<header class="attachment-manager-hero attachment-manager-toolbar">
|
<header class="attachment-manager-hero attachment-manager-toolbar">
|
||||||
<div>
|
<div>
|
||||||
@@ -96,9 +97,10 @@
|
|||||||
</div>
|
</div>
|
||||||
</dl>
|
</dl>
|
||||||
<p class="knowledge-panel-note">{{ knowledge_base.status.message }}</p>
|
<p class="knowledge-panel-note">{{ knowledge_base.status.message }}</p>
|
||||||
|
<p class="upload-status" id="knowledgeRebuildStatus"></p>
|
||||||
<div class="knowledge-form-actions">
|
<div class="knowledge-form-actions">
|
||||||
<button type="button" onclick="window.location.reload()">刷新状态</button>
|
<button type="button" onclick="window.location.reload()">刷新状态</button>
|
||||||
<button type="button" disabled>重建索引</button>
|
<button type="button" id="knowledgeRebuildIndexButton">重建索引</button>
|
||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
|
|
||||||
@@ -182,6 +184,7 @@
|
|||||||
<th>类型</th>
|
<th>类型</th>
|
||||||
<th>大小</th>
|
<th>大小</th>
|
||||||
<th>索引</th>
|
<th>索引</th>
|
||||||
|
<th>操作</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
@@ -192,10 +195,13 @@
|
|||||||
<td>{{ source.suffix }}</td>
|
<td>{{ source.suffix }}</td>
|
||||||
<td>{{ source.size }} bytes</td>
|
<td>{{ source.size }} bytes</td>
|
||||||
<td>{{ source.indexed_label }}</td>
|
<td>{{ source.indexed_label }}</td>
|
||||||
|
<td class="attachment-actions">
|
||||||
|
<button type="button" data-source-action="index">手动入库</button>
|
||||||
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{% empty %}
|
{% empty %}
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="5" class="table-empty">暂无法规材料</td>
|
<td colspan="6" class="table-empty">暂无法规材料</td>
|
||||||
</tr>
|
</tr>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</tbody>
|
</tbody>
|
||||||
@@ -209,5 +215,5 @@
|
|||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|
||||||
{% block scripts %}
|
{% block scripts %}
|
||||||
<script src="{% static 'js/knowledge_base.js' %}?v=20260608-kb5"></script>
|
<script src="{% static 'js/knowledge_base.js' %}?v=20260608-kb6"></script>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from django.core.files.uploadedfile import SimpleUploadedFile
|
|||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base
|
from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base
|
||||||
|
from review_agent.views import rebuild_knowledge_base_index
|
||||||
from review_agent.models import KnowledgeBaseDocument
|
from review_agent.models import KnowledgeBaseDocument
|
||||||
|
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ def test_knowledge_base_context_reports_rule_and_sources():
|
|||||||
assert context["rule"]["requirement_count"] > 0
|
assert context["rule"]["requirement_count"] > 0
|
||||||
assert context["source_count"] > 0
|
assert context["source_count"] > 0
|
||||||
assert context["collection_name"] == "nmpa_ivd_registration_v1"
|
assert context["collection_name"] == "nmpa_ivd_registration_v1"
|
||||||
|
assert not any("模拟题二" in source["relative_path"] for source in context["sources"])
|
||||||
|
|
||||||
|
|
||||||
def test_knowledge_base_page_requires_login(client):
|
def test_knowledge_base_page_requires_login(client):
|
||||||
@@ -36,6 +38,11 @@ def test_knowledge_base_page_renders_for_user(client, django_user_model):
|
|||||||
content = response.content.decode("utf-8")
|
content = response.content.decode("utf-8")
|
||||||
tabbar = content[content.index('<div class="tabbar"') : content.index("</div>", content.index('<div class="tabbar"'))]
|
tabbar = content[content.index('<div class="tabbar"') : content.index("</div>", content.index('<div class="tabbar"'))]
|
||||||
assert tabbar.index("审核智能体") < tabbar.index("知识库管理") < tabbar.index("附件管理")
|
assert tabbar.index("审核智能体") < tabbar.index("知识库管理") < tabbar.index("附件管理")
|
||||||
|
assert "data-rebuild-url=" in content
|
||||||
|
assert 'id="knowledgeRebuildIndexButton"' in content
|
||||||
|
assert "重建索引" in content
|
||||||
|
assert 'data-source-action="index"' in content
|
||||||
|
assert "手动入库" in content
|
||||||
|
|
||||||
|
|
||||||
def test_knowledge_base_status_api(client, django_user_model):
|
def test_knowledge_base_status_api(client, django_user_model):
|
||||||
@@ -48,6 +55,53 @@ def test_knowledge_base_status_api(client, django_user_model):
|
|||||||
assert response.json()["rule"]["code"] == "nmpa_ivd_registration_v1"
|
assert response.json()["rule"]["code"] == "nmpa_ivd_registration_v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_knowledge_base_rebuild_index_api(client, django_user_model, monkeypatch):
    """POSTing to the rebuild endpoint returns the stubbed chunk count and context."""
    owner = django_user_model.objects.create_user(username="owner", password="pass")
    client.force_login(owner)
    rebuild_calls = []

    def fake_rebuild():
        # Stand-in for the real rebuild: record the call, return a fixed payload.
        rebuild_calls.append("rebuild")
        return {"chunk_count": 12}

    monkeypatch.setattr("review_agent.views.rebuild_knowledge_base_index", fake_rebuild)

    response = client.post(reverse("knowledge_base_rebuild_index"))

    assert response.status_code == 200
    assert response.json()["chunk_count"] == 12
    assert response.json()["knowledge_base"]["collection"]["count"] >= 0
    assert rebuild_calls == ["rebuild"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_rebuild_knowledge_base_index_requests_reset(settings, tmp_path, monkeypatch):
    """rebuild_knowledge_base_index() forwards the provider and asks for reset=True."""
    settings.MEDIA_ROOT = tmp_path
    chroma_dir = tmp_path / "chroma"
    settings.REGULATORY_RAG_CHROMA_PATH = chroma_dir
    chroma_dir.mkdir()
    stale_file = chroma_dir / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")
    build_calls = []

    def fake_build_chroma_index(source_dir, embedding_provider, reset=False):
        # Capture the arguments the view-level helper passes to the builder.
        build_calls.append(
            {
                "source_dir": source_dir,
                "embedding_provider": embedding_provider,
                "reset": reset,
            }
        )
        return 8

    monkeypatch.setattr("review_agent.views.load_rule_file", lambda: {"source_material_dir": "docs/0.原始材料"})
    monkeypatch.setattr("review_agent.views.get_embedding_provider", lambda: "provider")
    monkeypatch.setattr("review_agent.views.build_chroma_index", fake_build_chroma_index)

    payload = rebuild_knowledge_base_index()

    assert payload["chunk_count"] == 8
    first_call = build_calls[0]
    assert first_call["embedding_provider"] == "provider"
    assert first_call["reset"] is True
|
||||||
|
|
||||||
|
|
||||||
def test_knowledge_base_search_rejects_blank_query():
|
def test_knowledge_base_search_rejects_blank_query():
|
||||||
payload = search_knowledge_base("")
|
payload = search_knowledge_base("")
|
||||||
|
|
||||||
@@ -103,6 +157,8 @@ def test_knowledge_base_search_api_returns_payload(client, django_user_model):
|
|||||||
|
|
||||||
def test_knowledge_base_document_crud_api(client, settings, tmp_path, django_user_model):
|
def test_knowledge_base_document_crud_api(client, settings, tmp_path, django_user_model):
|
||||||
settings.MEDIA_ROOT = tmp_path
|
settings.MEDIA_ROOT = tmp_path
|
||||||
|
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
|
||||||
|
settings.REGULATORY_RAG_PROVIDER = "deterministic"
|
||||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
client.force_login(user)
|
client.force_login(user)
|
||||||
|
|
||||||
@@ -199,6 +255,8 @@ def test_knowledge_base_document_api_is_scoped_to_owner(client, django_user_mode
|
|||||||
|
|
||||||
def test_knowledge_base_document_manual_index_api(client, settings, tmp_path, django_user_model):
|
def test_knowledge_base_document_manual_index_api(client, settings, tmp_path, django_user_model):
|
||||||
settings.MEDIA_ROOT = tmp_path
|
settings.MEDIA_ROOT = tmp_path
|
||||||
|
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "chroma"
|
||||||
|
settings.REGULATORY_RAG_PROVIDER = "deterministic"
|
||||||
user = django_user_model.objects.create_user(username="owner", password="pass")
|
user = django_user_model.objects.create_user(username="owner", password="pass")
|
||||||
client.force_login(user)
|
client.force_login(user)
|
||||||
source_path = tmp_path / "manual.md"
|
source_path = tmp_path / "manual.md"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from review_agent.regulatory_review.services.rag_citation import (
|
from review_agent.regulatory_review.services.rag_citation import (
|
||||||
@@ -7,6 +9,7 @@ from review_agent.regulatory_review.services.rag_citation import (
|
|||||||
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
||||||
from review_agent.regulatory_review.services.rag_index import chunk_text
|
from review_agent.regulatory_review.services.rag_index import chunk_text
|
||||||
from review_agent.regulatory_review.services.rag_index import collect_source_chunks
|
from review_agent.regulatory_review.services.rag_index import collect_source_chunks
|
||||||
|
from review_agent.regulatory_review.services.rag_index import build_chroma_index
|
||||||
|
|
||||||
|
|
||||||
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
|
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
|
||||||
@@ -86,3 +89,141 @@ def test_collect_source_chunks_requires_attachment4_extraction(monkeypatch, tmp_
|
|||||||
|
|
||||||
with pytest.raises(RuntimeError, match="附件 4"):
|
with pytest.raises(RuntimeError, match="附件 4"):
|
||||||
collect_source_chunks(source_dir)
|
collect_source_chunks(source_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_source_chunks_excludes_demo_agent_materials(monkeypatch, tmp_path):
    """Files whose path carries an excluded keyword yield no chunks."""
    source_dir = tmp_path / "sources"
    source_dir.mkdir()

    # Excluded material: once inside a keyword-named directory, once at top level.
    demo_dir = source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent"
    demo_dir.mkdir()
    demo_markdown = demo_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.md"
    demo_markdown.write_text("题目材料", encoding="utf-8")
    demo_docx = source_dir / "【模拟题二】试剂盒临床注册文件准备与审核Agent.docx"
    demo_docx.write_bytes(b"demo")

    # The only file that is expected to be chunked.
    real_source = source_dir / "附件 4 体外诊断试剂注册申报资料要求及说明.doc"
    real_source.write_bytes(b"rule")

    def fake_extract(path):
        if path == real_source:
            return "附件4 正文"
        return "不应被抽取"

    monkeypatch.setattr("review_agent.regulatory_review.services.rag_index.extract_text_from_path", fake_extract)

    chunks = collect_source_chunks(source_dir)

    assert chunks
    assert all("模拟题二" not in chunk.metadata["source"] for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_chroma_index_reset_recreates_collection_without_deleting_index_dir(settings, monkeypatch, tmp_path):
    """With a working client, reset drops the collection but keeps the index files."""
    settings.MEDIA_ROOT = tmp_path
    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")

    # Ordered log of client constructions and cache clears, each paired with
    # whether the stale file still existed at that moment.
    client_states = []
    deleted_collections = []

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class FakeClient:
        def __init__(self, path):
            client_states.append({"path": path, "stale_exists": stale_file.exists()})

        def delete_collection(self, name):
            deleted_collections.append(name)

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            client_states.append({"path": "cache-cleared", "stale_exists": stale_file.exists()})

    # Replace the chromadb modules with minimal stand-ins for the duration of the test.
    fake_chromadb = type("FakeChromaModule", (), {"PersistentClient": FakeClient})
    fake_shared_module = type(
        "FakeSharedSystemClientModule",
        (),
        {"SharedSystemClient": FakeSharedSystemClient},
    )
    monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
    monkeypatch.setitem(sys.modules, "chromadb.api.shared_system_client", fake_shared_module)

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1
    expected_states = [
        {"path": str(persist_path), "stale_exists": True},
        {"path": "cache-cleared", "stale_exists": True},
        {"path": str(persist_path), "stale_exists": True},
    ]
    assert client_states == expected_states
    assert stale_file.exists()
    assert deleted_collections == ["test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_chroma_index_reset_clears_bad_index_dir_after_chroma_cache_reset(settings, monkeypatch, tmp_path):
    """When the first client construction fails, reset wipes the index dir and retries."""
    settings.MEDIA_ROOT = tmp_path
    persist_path = tmp_path / "chroma"
    persist_path.mkdir()
    stale_file = persist_path / "chroma.sqlite3"
    stale_file.write_text("stale", encoding="utf-8")
    source_dir = tmp_path / "sources"
    source_dir.mkdir()
    (source_dir / "rule.md").write_text("注册检验报告要求", encoding="utf-8")

    # Ordered log of what happened, each entry noting whether the stale file existed.
    events = []

    class FakeCollection:
        def upsert(self, **kwargs):
            return None

    class BrokenThenFreshClient:
        attempts = 0

        def __init__(self, path):
            BrokenThenFreshClient.attempts += 1
            events.append(("client", BrokenThenFreshClient.attempts, stale_file.exists()))
            # Only the very first construction fails.
            if BrokenThenFreshClient.attempts == 1:
                raise ValueError("Could not connect to tenant default_tenant")

        def get_or_create_collection(self, name):
            return FakeCollection()

    class FakeSharedSystemClient:
        @staticmethod
        def clear_system_cache():
            events.append(("clear_cache", stale_file.exists()))

    fake_chromadb = type("FakeChromaModule", (), {"PersistentClient": BrokenThenFreshClient})
    fake_shared_module = type(
        "FakeSharedSystemClientModule",
        (),
        {"SharedSystemClient": FakeSharedSystemClient},
    )
    monkeypatch.setitem(sys.modules, "chromadb", fake_chromadb)
    monkeypatch.setitem(sys.modules, "chromadb.api.shared_system_client", fake_shared_module)

    count = build_chroma_index(
        source_dir=source_dir,
        embedding_provider=lambda texts: [[0.1, 0.2] for _ in texts],
        persist_path=persist_path,
        collection_name="test",
        reset=True,
    )

    assert count == 1
    expected_events = [
        ("client", 1, True),
        ("clear_cache", True),
        ("client", 2, False),
    ]
    assert events == expected_events
    assert not stale_file.exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user