fix(kb): 完善知识库入库和重建索引

This commit is contained in:
2026-06-08 23:45:34 +08:00
parent d8cd95e590
commit 2b5093040d
9 changed files with 355 additions and 9 deletions

View File

@@ -10,8 +10,8 @@ from django.core.files.uploadedfile import UploadedFile
from review_agent.models import KnowledgeBaseDocument
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
@@ -78,6 +78,8 @@ def list_source_documents(source_dir: Path) -> list[dict[str, Any]]:
continue
suffix = path.suffix.lower()
relative_path = str(path.relative_to(source_dir))
if is_excluded_source_path(relative_path):
continue
indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
documents.append(
{
@@ -101,7 +103,7 @@ def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
try:
results = retrieve_citations(
normalized,
embedding_provider=DeterministicEmbeddingProvider(),
embedding_provider=get_embedding_provider(),
n_results=n_results,
)
except RagIndexUnavailable as exc:
@@ -210,7 +212,7 @@ def index_managed_document(document: KnowledgeBaseDocument) -> int:
return 0
collection = _load_chroma_collection()
texts = [chunk.text for chunk in chunks]
embeddings = DeterministicEmbeddingProvider()(texts)
embeddings = get_embedding_provider()(texts)
ids = [
hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
for chunk in chunks

View File

@@ -23,7 +23,7 @@ class Command(BaseCommand):
raise CommandError(f"法规材料目录不存在:{source_dir}")
try:
provider = get_embedding_provider(options["provider"])
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider, reset=True)
except Exception as exc:
raise CommandError(str(exc)) from exc
self.stdout.write(

View File

@@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction
# Module-level logger for the regulatory-review RAG indexing code.
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
# Substrings that mark a source path as excluded from the knowledge base.
# is_excluded_source_path() tests each one against the stringified path.
EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent")
@dataclass(frozen=True)
class TextChunk:
@@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
for path in sorted(source_dir.rglob("*")):
if not path.is_file():
continue
if is_excluded_source_path(path.relative_to(source_dir)):
continue
try:
text = extract_text_from_path(path)
except RuntimeError as exc:
@@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
return chunks
def is_excluded_source_path(path: Path | str) -> bool:
    """Report whether a source path matches any excluded keyword.

    Args:
        path: A path object or string; it is converted with ``str()``
            before matching.

    Returns:
        True when at least one entry of ``EXCLUDED_SOURCE_KEYWORDS`` occurs
        as a plain substring of the stringified path, otherwise False.
    """
    path_text = str(path)
    for keyword in EXCLUDED_SOURCE_KEYWORDS:
        if keyword in path_text:
            return True
    return False
def _is_attachment4(path: Path) -> bool:
    """Tell whether a file name identifies the "附件4" requirements document.

    Only the final path component is inspected. ASCII spaces are removed
    first so that names such as "附件 4 ..." still match.

    Args:
        path: Path whose ``name`` is checked.

    Returns:
        True when the space-stripped file name contains both required
        markers, otherwise False.
    """
    compact_name = "".join(path.name.split(" "))
    required_markers = ("附件4", "体外诊断试剂注册申报资料要求及说明")
    return all(marker in compact_name for marker in required_markers)
@@ -249,6 +258,7 @@ def build_chroma_index(
embedding_provider: EmbeddingFunction,
persist_path: Path | None = None,
collection_name: str | None = None,
reset: bool = False,
) -> int:
try:
import chromadb
@@ -259,7 +269,22 @@ def build_chroma_index(
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
persist_path.mkdir(parents=True, exist_ok=True)
chunks = collect_source_chunks(source_dir)
client = chromadb.PersistentClient(path=str(persist_path))
try:
client = chromadb.PersistentClient(path=str(persist_path))
except Exception:
if not reset:
raise
clear_chroma_system_cache()
clear_chroma_index_dir(persist_path)
persist_path.mkdir(parents=True, exist_ok=True)
client = chromadb.PersistentClient(path=str(persist_path))
if reset:
try:
client.delete_collection(collection_name)
clear_chroma_system_cache()
client = chromadb.PersistentClient(path=str(persist_path))
except Exception:
pass
collection = client.get_or_create_collection(collection_name)
if not chunks:
return 0
@@ -276,3 +301,22 @@ def build_chroma_index(
embeddings=embeddings,
)
return len(chunks)
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
    """Delete the on-disk Chroma index directory, if it exists.

    The target is resolved to an absolute path and must be a strict
    descendant of ``settings.MEDIA_ROOT``. This keeps a misconfigured path
    from removing unrelated files.

    Args:
        persist_path: Directory to remove. Falls back to
            ``settings.REGULATORY_RAG_CHROMA_PATH`` when omitted or empty.

    Raises:
        RuntimeError: If the resolved directory is outside ``MEDIA_ROOT`` or
            is ``MEDIA_ROOT`` itself.
    """
    chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
    media_root = Path(settings.MEDIA_ROOT).resolve()
    try:
        relative = chroma_path.relative_to(media_root)
    except ValueError as exc:
        raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。") from exc
    # relative_to() also succeeds when both paths are equal (result "."),
    # which would let rmtree wipe the whole media root. Refuse that case.
    if not relative.parts:
        raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。")
    if chroma_path.exists():
        shutil.rmtree(chroma_path)
def clear_chroma_system_cache() -> None:
    """Clear chromadb's shared system cache when the class can be imported.

    If importing ``SharedSystemClient`` fails for any reason, the function
    does nothing. Errors raised by ``clear_system_cache()`` itself are not
    suppressed.
    """
    try:
        from chromadb.api.shared_system_client import SharedSystemClient as shared_client
    except Exception:
        # Best effort: without the class there is no cache to clear.
        pass
    else:
        shared_client.clear_system_cache()

View File

@@ -25,6 +25,7 @@ from .views import (
knowledge_base_document_detail,
knowledge_base_document_index,
knowledge_base_documents,
knowledge_base_rebuild_index,
knowledge_base_search,
knowledge_base_status,
)
@@ -121,6 +122,11 @@ urlpatterns = [
knowledge_base_search,
name="knowledge_base_search",
),
path(
"api/review-agent/knowledge-base/rebuild-index/",
knowledge_base_rebuild_index,
name="knowledge_base_rebuild_index",
),
path(
"api/review-agent/knowledge-base/documents/",
knowledge_base_documents,

View File

@@ -1,6 +1,8 @@
from django.contrib.auth.decorators import login_required
from django.conf import settings
from django.db.models import Count, Q, Sum
import json
from pathlib import Path
from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse
from django.shortcuts import redirect, render
@@ -27,6 +29,9 @@ from .knowledge_base import (
)
from .models import KnowledgeBaseDocument
from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates
from .regulatory_review.services.rag_embedding import get_embedding_provider
from .regulatory_review.services.rag_index import build_chroma_index
from .regulatory_review.services.rule_loader import load_rule_file
@login_required
@@ -151,6 +156,24 @@ def knowledge_base_status(request: HttpRequest) -> JsonResponse:
return JsonResponse(build_knowledge_base_context_for_user(request.user))
@login_required
@require_http_methods(["POST"])
def knowledge_base_rebuild_index(request: HttpRequest) -> JsonResponse:
    """Rebuild the knowledge-base index and return the refreshed status.

    The rebuild runs first, so the context built afterwards is computed
    after the index has been recreated.

    Args:
        request: The incoming request; ``request.user`` is used to build the
            knowledge-base context.

    Returns:
        A JSON response holding the context under ``"knowledge_base"``,
        merged with the keys returned by ``rebuild_knowledge_base_index()``.
    """
    rebuild_result = rebuild_knowledge_base_index()
    response_body = {"knowledge_base": build_knowledge_base_context_for_user(request.user)}
    # Rebuild keys are merged last, so they win on a name clash.
    response_body.update(rebuild_result)
    return JsonResponse(response_body)
def rebuild_knowledge_base_index() -> dict[str, object]:
    """Rebuild the Chroma index from the rule set's source material directory.

    The directory is taken from the loaded rule file's
    ``"source_material_dir"`` entry, joined onto ``settings.BASE_DIR``.
    The index is built with ``reset=True``.

    Returns:
        A dict with one key, ``"chunk_count"``, holding the value returned
        by ``build_chroma_index``.
    """
    rules = load_rule_file()
    materials_dir = Path(settings.BASE_DIR) / rules["source_material_dir"]
    provider = get_embedding_provider()
    indexed = build_chroma_index(
        source_dir=materials_dir,
        embedding_provider=provider,
        reset=True,
    )
    return {"chunk_count": indexed}
@login_required
@require_http_methods(["POST"])
def knowledge_base_search(request: HttpRequest) -> JsonResponse: