fix(kb): 完善知识库入库和重建索引
This commit is contained in:
@@ -10,8 +10,8 @@ from django.core.files.uploadedfile import UploadedFile
|
||||
|
||||
from review_agent.models import KnowledgeBaseDocument
|
||||
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
|
||||
from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider
|
||||
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path
|
||||
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
|
||||
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path
|
||||
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
|
||||
|
||||
|
||||
@@ -78,6 +78,8 @@ def list_source_documents(source_dir: Path) -> list[dict[str, Any]]:
|
||||
continue
|
||||
suffix = path.suffix.lower()
|
||||
relative_path = str(path.relative_to(source_dir))
|
||||
if is_excluded_source_path(relative_path):
|
||||
continue
|
||||
indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
|
||||
documents.append(
|
||||
{
|
||||
@@ -101,7 +103,7 @@ def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
|
||||
try:
|
||||
results = retrieve_citations(
|
||||
normalized,
|
||||
embedding_provider=DeterministicEmbeddingProvider(),
|
||||
embedding_provider=get_embedding_provider(),
|
||||
n_results=n_results,
|
||||
)
|
||||
except RagIndexUnavailable as exc:
|
||||
@@ -210,7 +212,7 @@ def index_managed_document(document: KnowledgeBaseDocument) -> int:
|
||||
return 0
|
||||
collection = _load_chroma_collection()
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
embeddings = DeterministicEmbeddingProvider()(texts)
|
||||
embeddings = get_embedding_provider()(texts)
|
||||
ids = [
|
||||
hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
|
||||
for chunk in chunks
|
||||
|
||||
@@ -23,7 +23,7 @@ class Command(BaseCommand):
|
||||
raise CommandError(f"法规材料目录不存在:{source_dir}")
|
||||
try:
|
||||
provider = get_embedding_provider(options["provider"])
|
||||
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
|
||||
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider, reset=True)
|
||||
except Exception as exc:
|
||||
raise CommandError(str(exc)) from exc
|
||||
self.stdout.write(
|
||||
|
||||
@@ -23,6 +23,8 @@ from .rag_embedding import EmbeddingFunction
|
||||
|
||||
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||||
|
||||
EXCLUDED_SOURCE_KEYWORDS = ("模拟题二", "试剂盒临床注册文件准备与审核Agent")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextChunk:
|
||||
@@ -227,6 +229,8 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||
for path in sorted(source_dir.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
if is_excluded_source_path(path.relative_to(source_dir)):
|
||||
continue
|
||||
try:
|
||||
text = extract_text_from_path(path)
|
||||
except RuntimeError as exc:
|
||||
@@ -238,6 +242,11 @@ def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||
return chunks
|
||||
|
||||
|
||||
def is_excluded_source_path(path: Path | str) -> bool:
|
||||
normalized = str(path)
|
||||
return any(keyword in normalized for keyword in EXCLUDED_SOURCE_KEYWORDS)
|
||||
|
||||
|
||||
def _is_attachment4(path: Path) -> bool:
|
||||
normalized = path.name.replace(" ", "")
|
||||
return "附件4" in normalized and "体外诊断试剂注册申报资料要求及说明" in normalized
|
||||
@@ -249,6 +258,7 @@ def build_chroma_index(
|
||||
embedding_provider: EmbeddingFunction,
|
||||
persist_path: Path | None = None,
|
||||
collection_name: str | None = None,
|
||||
reset: bool = False,
|
||||
) -> int:
|
||||
try:
|
||||
import chromadb
|
||||
@@ -259,7 +269,22 @@ def build_chroma_index(
|
||||
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
||||
persist_path.mkdir(parents=True, exist_ok=True)
|
||||
chunks = collect_source_chunks(source_dir)
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
try:
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
except Exception:
|
||||
if not reset:
|
||||
raise
|
||||
clear_chroma_system_cache()
|
||||
clear_chroma_index_dir(persist_path)
|
||||
persist_path.mkdir(parents=True, exist_ok=True)
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
if reset:
|
||||
try:
|
||||
client.delete_collection(collection_name)
|
||||
clear_chroma_system_cache()
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
except Exception:
|
||||
pass
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
if not chunks:
|
||||
return 0
|
||||
@@ -276,3 +301,22 @@ def build_chroma_index(
|
||||
embeddings=embeddings,
|
||||
)
|
||||
return len(chunks)
|
||||
|
||||
|
||||
def clear_chroma_index_dir(persist_path: Path | str | None = None) -> None:
|
||||
chroma_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH).resolve()
|
||||
media_root = Path(settings.MEDIA_ROOT).resolve()
|
||||
try:
|
||||
chroma_path.relative_to(media_root)
|
||||
except ValueError as exc:
|
||||
raise RuntimeError("法规 RAG 索引目录必须位于 MEDIA_ROOT 内。") from exc
|
||||
if chroma_path.exists():
|
||||
shutil.rmtree(chroma_path)
|
||||
|
||||
|
||||
def clear_chroma_system_cache() -> None:
|
||||
try:
|
||||
from chromadb.api.shared_system_client import SharedSystemClient
|
||||
except Exception:
|
||||
return
|
||||
SharedSystemClient.clear_system_cache()
|
||||
|
||||
@@ -25,6 +25,7 @@ from .views import (
|
||||
knowledge_base_document_detail,
|
||||
knowledge_base_document_index,
|
||||
knowledge_base_documents,
|
||||
knowledge_base_rebuild_index,
|
||||
knowledge_base_search,
|
||||
knowledge_base_status,
|
||||
)
|
||||
@@ -121,6 +122,11 @@ urlpatterns = [
|
||||
knowledge_base_search,
|
||||
name="knowledge_base_search",
|
||||
),
|
||||
path(
|
||||
"api/review-agent/knowledge-base/rebuild-index/",
|
||||
knowledge_base_rebuild_index,
|
||||
name="knowledge_base_rebuild_index",
|
||||
),
|
||||
path(
|
||||
"api/review-agent/knowledge-base/documents/",
|
||||
knowledge_base_documents,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.conf import settings
|
||||
from django.db.models import Count, Q, Sum
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse
|
||||
from django.shortcuts import redirect, render
|
||||
@@ -27,6 +29,9 @@ from .knowledge_base import (
|
||||
)
|
||||
from .models import KnowledgeBaseDocument
|
||||
from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates
|
||||
from .regulatory_review.services.rag_embedding import get_embedding_provider
|
||||
from .regulatory_review.services.rag_index import build_chroma_index
|
||||
from .regulatory_review.services.rule_loader import load_rule_file
|
||||
|
||||
|
||||
@login_required
|
||||
@@ -151,6 +156,24 @@ def knowledge_base_status(request: HttpRequest) -> JsonResponse:
|
||||
return JsonResponse(build_knowledge_base_context_for_user(request.user))
|
||||
|
||||
|
||||
@login_required
|
||||
@require_http_methods(["POST"])
|
||||
def knowledge_base_rebuild_index(request: HttpRequest) -> JsonResponse:
|
||||
payload = rebuild_knowledge_base_index()
|
||||
return JsonResponse({"knowledge_base": build_knowledge_base_context_for_user(request.user), **payload})
|
||||
|
||||
|
||||
def rebuild_knowledge_base_index() -> dict[str, object]:
|
||||
rule_set = load_rule_file()
|
||||
source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
|
||||
chunk_count = build_chroma_index(
|
||||
source_dir=source_dir,
|
||||
embedding_provider=get_embedding_provider(),
|
||||
reset=True,
|
||||
)
|
||||
return {"chunk_count": chunk_count}
|
||||
|
||||
|
||||
@login_required
|
||||
@require_http_methods(["POST"])
|
||||
def knowledge_base_search(request: HttpRequest) -> JsonResponse:
|
||||
|
||||
Reference in New Issue
Block a user