feat(knowledge-base): 增加全局知识库管理

This commit is contained in:
2026-06-08 21:37:32 +08:00
parent e6fa738fd5
commit 5ecf78c5d6
12 changed files with 1425 additions and 2 deletions

View File

@@ -0,0 +1,397 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from django.conf import settings
from django.core.files.uploadedfile import UploadedFile
from review_agent.models import KnowledgeBaseDocument
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file
# File suffixes (lower-case, including the leading dot) that are reported as
# "supported" when source documents are listed; compared against
# ``path.suffix.lower()``.
SUPPORTED_SOURCE_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".pptx", ".xlsx"}
@dataclass(frozen=True)
class ChromaCollectionState:
    """Immutable snapshot describing the Chroma collection at one point in time.

    The two optional containers default to ``None`` (not empty containers), so
    readers are expected to fall back with ``or [] `` / ``or {}``.
    """

    # True when the collection could be opened and inspected.
    exists: bool
    # Number of entries the collection reported.
    count: int = 0
    # Human-readable reason the collection is unavailable; empty otherwise.
    error_message: str = ""
    # A small sample of stored metadata dicts, for display purposes.
    sample_metadatas: list[dict[str, Any]] | None = None
    # Number of stored chunks keyed by their "source" metadata value.
    source_chunk_counts: dict[str, int] | None = None
def build_knowledge_base_context() -> dict[str, Any]:
    """Assemble the template context describing the global regulatory knowledge base.

    Combines rule-file information, the list of source documents found on disk
    and the current state of the Chroma collection.

    The collection state is loaded exactly once and shared with
    ``list_source_documents``; loading it pages through every stored metadata
    entry, so doing it twice per call is wasteful and could yield two
    inconsistent snapshots.

    Returns:
        A dict ready for rendering. ``managed_documents`` is always an empty
        list here; ``build_knowledge_base_context_for_user`` fills it in.
    """
    rule_info = _rule_info()
    source_dir = Path(settings.BASE_DIR) / str(rule_info.get("source_material_dir") or "docs/0.原始材料")
    collection = get_chroma_collection_state()
    sources = list_source_documents(source_dir, collection=collection)
    return {
        "name": "NMPA IVD 注册资料法规库",
        "description": "用于体外诊断试剂注册资料法规核查的结构化规则和 RAG 依据检索。",
        "provider": settings.REGULATORY_RAG_PROVIDER,
        "collection_name": settings.REGULATORY_RAG_COLLECTION,
        "chroma_path": settings.REGULATORY_RAG_CHROMA_PATH,
        "rule": rule_info,
        "source_dir": str(source_dir),
        "sources": sources,
        "source_count": len(sources),
        "supported_source_count": sum(1 for item in sources if item["supported"]),
        "collection": {
            "exists": collection.exists,
            "count": collection.count,
            "error_message": collection.error_message,
            "sample_metadatas": collection.sample_metadatas or [],
        },
        "status": _status_label(collection),
        "build_commands": [
            "python manage.py regulatory_rag_build --provider deterministic",
            "python manage.py regulatory_rag_build --provider siliconflow",
        ],
        "managed_documents": [],
    }


def build_knowledge_base_context_for_user(user) -> dict[str, Any]:
    """Return the global knowledge-base context extended with *user*'s managed documents.

    Adds ``managed_documents`` (serialized, non-deleted documents of the user),
    ``managed_document_count`` and ``active_managed_document_count``.
    """
    context = build_knowledge_base_context()
    documents = list_documents_for_user(user)
    context["managed_documents"] = documents
    context["managed_document_count"] = len(documents)
    context["active_managed_document_count"] = sum(1 for item in documents if item["is_active"])
    return context


def list_source_documents(
    source_dir: Path,
    *,
    collection: ChromaCollectionState | None = None,
) -> list[dict[str, Any]]:
    """Describe every file found (recursively) under *source_dir*.

    Args:
        source_dir: Directory holding the raw source material.
        collection: Optional, already-loaded collection snapshot. When omitted
            the state is loaded here, which keeps existing callers working.

    Returns:
        One dict per file, sorted by path, with its name, path relative to
        *source_dir*, suffix, size, whether the suffix is supported and how
        many indexed chunks are attributed to it. An empty list when
        *source_dir* does not exist.
    """
    if not source_dir.exists():
        return []
    if collection is None:
        collection = get_chroma_collection_state()
    source_chunk_counts = collection.source_chunk_counts or {}
    documents: list[dict[str, Any]] = []
    for path in sorted(source_dir.rglob("*")):
        if not path.is_file():
            continue
        suffix = path.suffix.lower()
        # NOTE(review): assumes the "source" metadata written by the index
        # builder equals this relative path string — verify against the builder.
        relative_path = str(path.relative_to(source_dir))
        indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
        documents.append(
            {
                "name": path.name,
                "relative_path": relative_path,
                "suffix": suffix.lstrip(".") or "unknown",
                "size": path.stat().st_size,
                "supported": suffix in SUPPORTED_SOURCE_SUFFIXES,
                "indexed": indexed_chunk_count > 0,
                "indexed_chunk_count": indexed_chunk_count,
                "indexed_label": f"已入库 {indexed_chunk_count}" if indexed_chunk_count else "未入库",
            }
        )
    return documents
def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
    """Run a citation search against the knowledge base.

    Args:
        query: Free-text question; ``None`` and blank strings are treated as empty.
        n_results: Maximum number of citations requested from the retriever.

    Returns:
        A dict with the normalized ``query``, the ``results`` list (managed
        documents that are not active are filtered out) and an
        ``error_message`` that is empty on success. Retrieval errors are
        reported through ``error_message`` rather than raised.
    """
    question = (query or "").strip()

    def _response(results: list[dict[str, Any]], error_message: str) -> dict[str, Any]:
        return {"query": question, "results": results, "error_message": error_message}

    if not question:
        return _response([], "请输入检索问题。")

    try:
        citations = retrieve_citations(
            question,
            embedding_provider=DeterministicEmbeddingProvider(),
            n_results=n_results,
        )
    except RagIndexUnavailable as unavailable:
        return _response([], str(unavailable))
    except Exception as failure:
        return _response([], f"检索失败:{failure}")

    # Filtering happens outside the try block, so its errors are not swallowed.
    return _response(filter_active_knowledge_results(citations), "")
def list_documents_for_user(user) -> list[dict[str, Any]]:
    """Return the serialized knowledge-base documents of *user*, excluding deleted ones."""
    visible_documents = KnowledgeBaseDocument.objects.filter(user=user).exclude(
        status=KnowledgeBaseDocument.Status.DELETED
    )
    serialized: list[dict[str, Any]] = []
    for record in visible_documents:
        serialized.append(serialize_document(record))
    return serialized
def create_document_from_upload(
    *,
    user,
    uploaded_file: UploadedFile,
    display_name: str = "",
    description: str = "",
    is_active: bool = True,
) -> KnowledgeBaseDocument:
    """Store an uploaded file under the user's knowledge-base folder and register it.

    The file is written to ``MEDIA_ROOT/knowledge_base/users/<user pk>/`` under
    a name that does not collide with existing files, a
    ``KnowledgeBaseDocument`` row is created, and active documents are indexed
    immediately.

    Args:
        user: Owner of the document.
        uploaded_file: The uploaded file to persist.
        display_name: Optional label; blank or whitespace-only values fall
            back to the uploaded file name.
        description: Optional free-text description (stripped).
        is_active: Whether the document is created active (and indexed).

    Returns:
        The created ``KnowledgeBaseDocument``.

    Raises:
        Exception: Whatever writing the file or creating the row raises; the
            partially written file is removed before the error propagates.
    """
    # An upload may arrive without a name; fall back instead of crashing in Path().
    original_name = uploaded_file.name or "document"
    root = Path(settings.MEDIA_ROOT) / "knowledge_base" / "users" / str(user.pk)
    root.mkdir(parents=True, exist_ok=True)
    target = _unique_target_path(root, original_name)
    status = KnowledgeBaseDocument.Status.ACTIVE if is_active else KnowledgeBaseDocument.Status.DISABLED
    try:
        with target.open("wb") as handle:
            for chunk in uploaded_file.chunks():
                handle.write(chunk)
        document = KnowledgeBaseDocument.objects.create(
            user=user,
            # Strip first, then fall back: a whitespace-only label must not
            # end up as an empty display name.
            display_name=(display_name or "").strip() or original_name.strip(),
            original_name=original_name,
            storage_path=str(target),
            file_size=target.stat().st_size,
            content_type=getattr(uploaded_file, "content_type", "") or "",
            description=(description or "").strip(),
            status=status,
            is_active=is_active,
        )
    except Exception:
        # Do not leave an orphaned (possibly partial) file behind when the
        # write or the database insert fails.
        try:
            target.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the original error is the one to surface
        raise
    if is_active:
        index_managed_document(document)
    return document
def update_document(document: KnowledgeBaseDocument, payload: dict[str, Any]) -> KnowledgeBaseDocument:
    """Apply the editable fields present in *payload* to *document* and save it.

    Only the keys ``display_name``, ``description`` and ``is_active`` are
    honoured. A blank display name falls back to the original file name.
    Nothing is saved when none of those keys is present.

    NOTE(review): toggling ``is_active`` here only updates the row; this
    function does not index a document that was created inactive — confirm
    that the caller takes care of indexing on activation.

    Returns:
        The same *document* instance.
    """
    changed: list[str] = []

    if "display_name" in payload:
        requested_name = str(payload.get("display_name") or "").strip()
        document.display_name = requested_name if requested_name else document.original_name
        changed.append("display_name")

    if "description" in payload:
        document.description = str(payload.get("description") or "").strip()
        changed.append("description")

    if "is_active" in payload:
        document.is_active = bool(payload.get("is_active"))
        if document.is_active:
            document.status = KnowledgeBaseDocument.Status.ACTIVE
        else:
            document.status = KnowledgeBaseDocument.Status.DISABLED
        changed += ["is_active", "status"]

    if not changed:
        return document

    document.save(update_fields=[*changed, "updated_at"])
    return document
def delete_document(document: KnowledgeBaseDocument) -> KnowledgeBaseDocument:
    """Soft-delete *document*: drop its chunks from the index and mark the row deleted.

    The row is kept; its status becomes ``DELETED``, it is deactivated, its
    chunk count is reset and the index bookkeeping in ``metadata`` is updated.

    Returns:
        The same *document* instance.
    """
    # Removal is best-effort; on failure it records the error in document.metadata,
    # which is carried over into the metadata written below.
    remove_managed_document_from_index(document)

    carried_metadata = document.metadata or {}
    document.metadata = {**carried_metadata, "index_status": "deleted", "index_error": ""}
    document.indexed_chunk_count = 0
    document.is_active = False
    document.status = KnowledgeBaseDocument.Status.DELETED

    persisted_fields = ["status", "is_active", "indexed_chunk_count", "metadata", "updated_at"]
    document.save(update_fields=persisted_fields)
    return document
def serialize_document(document: KnowledgeBaseDocument) -> dict[str, Any]:
    """Convert *document* into a plain dict suitable for templates or JSON.

    Timestamps are rendered with ``isoformat()`` (empty string when unset) and
    ``indexed_label`` is a display label derived from the chunk count.
    """

    def _iso(moment) -> str:
        return moment.isoformat() if moment else ""

    chunk_total = document.indexed_chunk_count
    if chunk_total:
        label = f"已入库 {chunk_total}"
    else:
        label = "未入库"

    serialized: dict[str, Any] = {
        "id": document.pk,
        "display_name": document.display_name,
        "original_name": document.original_name,
        "description": document.description,
        "file_size": document.file_size,
        "content_type": document.content_type,
        "status": document.status,
        "is_active": document.is_active,
        "indexed_chunk_count": chunk_total,
        "indexed_label": label,
        "created_at": _iso(document.created_at),
        "updated_at": _iso(document.updated_at),
    }
    return serialized
def index_managed_document(document: KnowledgeBaseDocument) -> int:
    """Extract, chunk, embed and upsert *document* into the Chroma collection.

    The outcome is written back to the row: ``indexed_chunk_count`` plus
    ``index_status`` (``"indexed"``, ``"empty"`` or ``"failed"``) and
    ``index_error`` inside ``metadata``. Errors raised while extracting,
    chunking, embedding or upserting are caught and recorded rather than
    propagated.

    Returns:
        The number of chunks written, or 0 when the document produced no
        chunks or indexing failed.
    """

    def _record(chunk_total: int, index_status: str, index_error: str = "") -> int:
        # Single place for the bookkeeping that every outcome shares.
        document.indexed_chunk_count = chunk_total
        document.metadata = {
            **(document.metadata or {}),
            "index_status": index_status,
            "index_error": index_error,
        }
        document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"])
        return chunk_total

    file_path = Path(document.storage_path)
    if not file_path.is_absolute():
        # Relative storage paths are resolved against MEDIA_ROOT.
        file_path = Path(settings.MEDIA_ROOT) / document.storage_path

    try:
        extracted = extract_text_from_path(file_path)
        source_label = f"用户知识库/{document.user_id}/{document.pk}/{document.original_name}"
        pieces = chunk_text(extracted, source=source_label)
        if not pieces:
            return _record(0, "empty")

        collection = _load_chroma_collection()
        piece_texts = [piece.text for piece in pieces]
        vectors = DeterministicEmbeddingProvider()(piece_texts)

        # Ids are derived from the document pk and chunk index, so indexing the
        # same document again overwrites its chunks instead of duplicating them.
        piece_ids = []
        for piece in pieces:
            id_seed = f"managed:{document.pk}:{piece.metadata['chunk_index']}"
            piece_ids.append(hashlib.sha256(id_seed.encode("utf-8")).hexdigest())

        ownership = {
            "source_type": "managed_document",
            "document_id": document.pk,
            "user_id": document.user_id,
            "original_name": document.original_name,
        }
        piece_metadatas = [{**piece.metadata, **ownership} for piece in pieces]

        collection.upsert(
            ids=piece_ids,
            documents=piece_texts,
            metadatas=piece_metadatas,
            embeddings=vectors,
        )
        return _record(len(pieces), "indexed")
    except Exception as failure:
        return _record(0, "failed", str(failure))
def remove_managed_document_from_index(document: KnowledgeBaseDocument) -> None:
    """Best-effort removal of every indexed chunk that belongs to *document*.

    Failures are not raised; the error text is stored under
    ``index_delete_error`` in ``document.metadata``. The document is NOT saved
    here — persisting that change is left to the caller.
    """
    try:
        _load_chroma_collection().delete(where={"document_id": document.pk})
    except Exception as failure:
        existing_metadata = document.metadata or {}
        document.metadata = {**existing_metadata, "index_delete_error": str(failure)}
def _parse_managed_document_id(metadata: dict[str, Any]) -> int | None:
    """Return the integer ``document_id`` from chunk metadata, or None if missing or malformed."""
    try:
        return int(metadata.get("document_id"))
    except (TypeError, ValueError):
        return None


def filter_active_knowledge_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop search results that come from managed documents which are not active.

    Results whose metadata ``source_type`` is not ``"managed_document"`` are
    always kept. A managed result is kept only when its ``document_id`` refers
    to a ``KnowledgeBaseDocument`` that is ACTIVE with ``is_active=True``.
    Managed results whose ``document_id`` is missing or not an integer are
    dropped — consistently, and without raising.

    Args:
        results: Retrieved citations; each may carry a ``metadata`` dict.

    Returns:
        A new list with the surviving results in their original order.
    """
    # Classify once so every id is parsed (safely) exactly one time.
    classified: list[tuple[dict[str, Any], bool, int | None]] = []
    for item in results:
        metadata = item.get("metadata") or {}
        if metadata.get("source_type") == "managed_document":
            classified.append((item, True, _parse_managed_document_id(metadata)))
        else:
            classified.append((item, False, None))

    managed_ids = {document_id for _, _, document_id in classified if document_id is not None}

    active_ids: set[int] = set()
    if managed_ids:
        # Query only when there is something to look up.
        active_ids = set(
            KnowledgeBaseDocument.objects.filter(
                pk__in=managed_ids,
                status=KnowledgeBaseDocument.Status.ACTIVE,
                is_active=True,
            ).values_list("pk", flat=True)
        )

    return [
        item
        for item, is_managed, document_id in classified
        if not is_managed or document_id in active_ids
    ]
def _load_chroma_collection():
    """Open the configured Chroma collection, creating it if it does not exist.

    The persistence directory is created when missing.

    Raises:
        RuntimeError: If the ``chromadb`` package cannot be imported.
    """
    try:
        import chromadb
    except ImportError as missing:
        raise RuntimeError("chromadb 未安装。") from missing

    storage_dir = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    storage_dir.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(storage_dir))
    return client.get_or_create_collection(settings.REGULATORY_RAG_COLLECTION)
def get_chroma_collection_state() -> ChromaCollectionState:
    """Inspect the configured Chroma collection without creating anything.

    Never raises for a missing directory, a missing ``chromadb`` package or an
    unreadable collection; those cases are reported through
    ``exists=False`` and ``error_message``.

    Returns:
        A snapshot with the entry count, up to ten sample metadata dicts and
        the per-source chunk counts.
    """
    index_dir = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not index_dir.exists():
        return ChromaCollectionState(exists=False, error_message="法规 RAG 索引目录不存在。")

    try:
        import chromadb
    except ImportError:
        return ChromaCollectionState(exists=False, error_message="chromadb 未安装。")

    try:
        client = chromadb.PersistentClient(path=str(index_dir))
        collection = client.get_collection(settings.REGULATORY_RAG_COLLECTION)
        total = collection.count()
        # Every metadata entry is read so chunks can be counted per source.
        all_metadatas = _load_collection_metadatas(collection, total)
        sample = all_metadatas[:10]
        per_source = _count_chunks_by_source(all_metadatas)
    except Exception as failure:
        return ChromaCollectionState(exists=False, error_message=f"法规 RAG collection 不可用:{failure}")

    return ChromaCollectionState(
        exists=True,
        count=total,
        sample_metadatas=sample,
        source_chunk_counts=per_source,
    )
def _load_collection_metadatas(collection, count: int) -> list[dict[str, Any]]:
    """Fetch the metadata of *count* entries from *collection* in pages of 500.

    Returns an empty list when *count* is zero or negative.
    """
    page_size = 500
    collected: list[dict[str, Any]] = []
    offset = 0
    while offset < count:
        page = collection.get(
            include=["metadatas"],
            limit=min(page_size, count - offset),
            offset=offset,
        )
        collected.extend(page.get("metadatas") or [])
        offset += page_size
    return collected
def _count_chunks_by_source(metadatas: list[dict[str, Any]]) -> dict[str, int]:
    """Tally how many metadata entries share each ``source`` value.

    Entries that are falsy or have no (or an empty) ``source`` are ignored;
    source values are converted to ``str`` before counting.
    """
    tally: dict[str, int] = {}
    labels = (str((entry or {}).get("source") or "") for entry in metadatas)
    for label in labels:
        if not label:
            continue
        tally[label] = tally.get(label, 0) + 1
    return tally
def _rule_info() -> dict[str, Any]:
    """Summarize the structured rule file for display.

    Returns:
        A dict whose ``status`` is ``"ok"`` with the rule code, name, path,
        file hash, requirement/chapter counts and a per-severity tally, or
        ``"failed"`` with empty values plus ``error_message`` when loading or
        summarizing the rule file raised.
    """
    # Start from the failure shape; a successful load overwrites these values
    # in place, so both outcomes expose the keys in the same order.
    info: dict[str, Any] = {
        "status": "failed",
        "code": "",
        "name": "",
        "path": str(DEFAULT_RULE_PATH),
        "hash": "",
        "rag_collection": "",
        "source_material_dir": "docs/0.原始材料",
        "requirement_count": 0,
        "chapter_count": 0,
        "severity_counts": {},
    }
    try:
        payload = load_rule_file()
        requirements = payload.get("requirements") or []
        severity_tally: dict[str, int] = {}
        chapters = set()
        for entry in requirements:
            level = str(entry.get("severity") or "unknown")
            severity_tally[level] = severity_tally.get(level, 0) + 1
            code = str(entry.get("attachment4_code") or "")
            if code:
                # The chapter is the part of the code before the first dot.
                chapters.add(code.split(".")[0])
        loaded = {
            "status": "ok",
            "code": payload.get("code", ""),
            "name": payload.get("name", ""),
            "path": str(DEFAULT_RULE_PATH),
            "hash": compute_file_sha256(DEFAULT_RULE_PATH),
            "rag_collection": payload.get("rag_collection", ""),
            "source_material_dir": payload.get("source_material_dir", "docs/0.原始材料"),
            "requirement_count": len(requirements),
            "chapter_count": len(chapters),
            "severity_counts": severity_tally,
        }
    except Exception as failure:
        info["error_message"] = str(failure)
        return info
    info.update(loaded)
    return info
def _status_label(collection: ChromaCollectionState) -> dict[str, str]:
    """Map a collection snapshot to a status code, label and message.

    ``missing`` when the collection does not exist (the message is the
    snapshot's error text), ``thin`` when it holds fewer than 20 entries,
    ``ready`` otherwise.
    """
    if collection.exists and collection.count >= 20:
        return {"code": "ready", "label": "可用", "message": "RAG 索引已构建,可用于法规依据辅助检索。"}
    if collection.exists:
        return {
            "code": "thin",
            "label": "索引过少",
            "message": "RAG 能力已打通,但当前索引内容较少,建议补齐材料后重建。",
        }
    return {"code": "missing", "label": "未构建", "message": collection.error_message}
def _unique_target_path(root: Path, original_name: str) -> Path:
    """Pick a path inside *root* for *original_name* that does not exist yet.

    Only the final path component of *original_name* is used (``"document"``
    when that is empty). On a name clash, ``-2``, ``-3``, ... is appended to
    the stem until an unused name is found.

    NOTE(review): the existence check and the later file creation are separate
    steps, so two concurrent uploads could pick the same path — confirm whether
    callers need stronger guarantees.
    """
    base_name = Path(original_name).name or "document"
    candidate = root / base_name
    stem, suffix = candidate.stem, candidate.suffix
    attempt = 1
    while candidate.exists():
        attempt += 1
        candidate = root / f"{stem}-{attempt}{suffix}"
    return candidate