"""Knowledge-base service helpers for the regulatory review agent.

Builds the context describing the shared regulatory knowledge base (rule file
summary, source-material listing, Chroma collection status), manages per-user
uploaded documents, and indexes / filters those documents in the Chroma
collection named by ``settings.REGULATORY_RAG_COLLECTION``.
"""

from __future__ import annotations

import hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from django.conf import settings
from django.core.files.uploadedfile import UploadedFile

from review_agent.models import KnowledgeBaseDocument
from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path, is_excluded_source_path
from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file

SUPPORTED_SOURCE_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".pptx", ".xlsx"}

# Fallback source-material directory (relative to BASE_DIR), used when the rule
# file does not provide ``source_material_dir`` or cannot be loaded.
_DEFAULT_SOURCE_MATERIAL_DIR = "docs/0.原始材料"

# Collections holding fewer chunks than this are reported as "thin" by _status_label.
_THIN_INDEX_CHUNK_THRESHOLD = 20


@dataclass(frozen=True)
class ChromaCollectionState:
    """Snapshot of the regulatory RAG Chroma collection taken at read time."""

    # Whether the collection could be opened and counted.
    exists: bool
    # Number of chunks reported by ``collection.count()``.
    count: int = 0
    # Human-readable reason when the collection is unavailable.
    error_message: str = ""
    # Up to the first 10 chunk metadata dicts, for display.
    sample_metadatas: list[dict[str, Any]] | None = None
    # Number of chunks per value of each chunk's ``source`` metadata field.
    source_chunk_counts: dict[str, int] | None = None


def build_knowledge_base_context() -> dict[str, Any]:
    """Return the context dict describing the shared regulatory knowledge base.

    The Chroma collection is inspected exactly once and the resulting state is
    reused for the per-file index counts. Each inspection pages through the
    metadata of every chunk, so doing it twice doubled the cost of the page.

    ``managed_documents`` is always empty here; use
    :func:`build_knowledge_base_context_for_user` to include a user's uploads.
    """
    rule_info = _rule_info()
    source_dir = Path(settings.BASE_DIR) / str(rule_info.get("source_material_dir") or _DEFAULT_SOURCE_MATERIAL_DIR)
    collection = get_chroma_collection_state()
    sources = list_source_documents(source_dir, collection=collection)
    return {
        "name": "NMPA IVD 注册资料法规库",
        "description": "用于体外诊断试剂注册资料法规核查的结构化规则和 RAG 依据检索。",
        "provider": settings.REGULATORY_RAG_PROVIDER,
        "collection_name": settings.REGULATORY_RAG_COLLECTION,
        "chroma_path": settings.REGULATORY_RAG_CHROMA_PATH,
        "rule": rule_info,
        "source_dir": str(source_dir),
        "sources": sources,
        "source_count": len(sources),
        "supported_source_count": sum(1 for item in sources if item["supported"]),
        "collection": {
            "exists": collection.exists,
            "count": collection.count,
            "error_message": collection.error_message,
            "sample_metadatas": collection.sample_metadatas or [],
        },
        "status": _status_label(collection),
        "build_commands": [
            "python manage.py regulatory_rag_build --provider deterministic",
            "python manage.py regulatory_rag_build --provider siliconflow",
        ],
        "managed_documents": [],
    }


def build_knowledge_base_context_for_user(user) -> dict[str, Any]:
    """Return :func:`build_knowledge_base_context` extended with *user*'s uploaded documents.

    Adds ``managed_documents`` (serialized, non-deleted documents owned by
    *user*), ``managed_document_count`` and ``active_managed_document_count``.
    """
    context = build_knowledge_base_context()
    documents = list_documents_for_user(user)
    context["managed_documents"] = documents
    context["managed_document_count"] = len(documents)
    context["active_managed_document_count"] = sum(1 for item in documents if item["is_active"])
    return context


def list_source_documents(
    source_dir: Path,
    *,
    collection: ChromaCollectionState | None = None,
) -> list[dict[str, Any]]:
    """List every non-excluded file under *source_dir* with its index status.

    Args:
        source_dir: Root of the regulatory source material. Files are walked
            recursively in sorted path order.
        collection: Already-loaded collection state to take per-file chunk
            counts from. When omitted, the state is loaded here, which scans
            the whole collection's metadata.

    Returns:
        One dict per file. ``indexed_chunk_count`` is looked up by the file's
        path relative to *source_dir*; this assumes chunk ``source`` metadata
        uses that same relative path (TODO confirm against the index builder).
        Returns ``[]`` if *source_dir* does not exist.
    """
    if not source_dir.exists():
        return []
    if collection is None:
        collection = get_chroma_collection_state()
    source_chunk_counts = collection.source_chunk_counts or {}
    documents: list[dict[str, Any]] = []
    for path in sorted(source_dir.rglob("*")):
        if not path.is_file():
            continue
        suffix = path.suffix.lower()
        relative_path = str(path.relative_to(source_dir))
        if is_excluded_source_path(relative_path):
            continue
        indexed_chunk_count = source_chunk_counts.get(relative_path, 0)
        documents.append(
            {
                "name": path.name,
                "relative_path": relative_path,
                "suffix": suffix.lstrip(".") or "unknown",
                "size": path.stat().st_size,
                "supported": suffix in SUPPORTED_SOURCE_SUFFIXES,
                "indexed": indexed_chunk_count > 0,
                "indexed_chunk_count": indexed_chunk_count,
                "indexed_label": _indexed_label(indexed_chunk_count),
            }
        )
    return documents


def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]:
    """Run a RAG citation search and drop hits from inactive uploaded documents.

    Returns a dict with ``query`` (the stripped query), ``results`` and
    ``error_message``. Retrieval errors are reported through
    ``error_message`` rather than raised.

    NOTE(review): results are not scoped to a user, so an active uploaded
    document owned by one user can appear in another user's search. Confirm
    whether this is intended.
    """
    normalized = (query or "").strip()
    if not normalized:
        return {"query": normalized, "results": [], "error_message": "请输入检索问题。"}
    try:
        results = retrieve_citations(
            normalized,
            embedding_provider=get_embedding_provider(),
            n_results=n_results,
        )
    except RagIndexUnavailable as exc:
        return {"query": normalized, "results": [], "error_message": str(exc)}
    except Exception as exc:
        return {"query": normalized, "results": [], "error_message": f"检索失败:{exc}"}
    return {"query": normalized, "results": filter_active_knowledge_results(results), "error_message": ""}


def list_documents_for_user(user) -> list[dict[str, Any]]:
    """Return serialized documents owned by *user*, excluding soft-deleted ones."""
    return [
        serialize_document(document)
        for document in KnowledgeBaseDocument.objects.filter(user=user).exclude(status=KnowledgeBaseDocument.Status.DELETED)
    ]


def create_document_from_upload(
    *,
    user,
    uploaded_file: UploadedFile,
    display_name: str = "",
    description: str = "",
    is_active: bool = True,
) -> KnowledgeBaseDocument:
    """Store an uploaded file under MEDIA_ROOT, create its record, and index it if active.

    The file is written to ``MEDIA_ROOT/knowledge_base/users/<user pk>/``,
    using a ``-N`` suffix to avoid overwriting existing files. If writing the
    file or creating the database row fails, the partially written file is
    removed and the error is re-raised.

    A blank or whitespace-only *display_name* falls back to the uploaded
    file's name, matching :func:`update_document`. Indexing failures do not
    raise; they are recorded in ``document.metadata`` by
    :func:`index_managed_document`.
    """
    root = Path(settings.MEDIA_ROOT) / "knowledge_base" / "users" / str(user.pk)
    root.mkdir(parents=True, exist_ok=True)
    target = _unique_target_path(root, uploaded_file.name)
    status = KnowledgeBaseDocument.Status.ACTIVE if is_active else KnowledgeBaseDocument.Status.DISABLED
    try:
        with target.open("wb") as handle:
            for chunk in uploaded_file.chunks():
                handle.write(chunk)
        document = KnowledgeBaseDocument.objects.create(
            user=user,
            display_name=(display_name or "").strip() or uploaded_file.name.strip(),
            original_name=uploaded_file.name,
            storage_path=str(target),
            file_size=target.stat().st_size,
            content_type=getattr(uploaded_file, "content_type", "") or "",
            description=(description or "").strip(),
            status=status,
            is_active=is_active,
        )
    except Exception:
        # Do not leave an orphaned file on disk without a database record.
        target.unlink(missing_ok=True)
        raise
    if is_active:
        index_managed_document(document)
    return document


def update_document(document: KnowledgeBaseDocument, payload: dict[str, Any]) -> KnowledgeBaseDocument:
    """Apply the ``display_name`` / ``description`` / ``is_active`` keys present in *payload*.

    Only keys present in *payload* are changed, and only those fields (plus
    ``updated_at``) are saved. A blank ``display_name`` falls back to
    ``original_name``.

    When ``is_active`` switches a document on and it has no indexed chunks
    (for example, it was uploaded disabled and so never indexed), the document
    is indexed now. Otherwise it would never appear in search results.
    """
    update_fields: list[str] = []
    if "display_name" in payload:
        document.display_name = str(payload.get("display_name") or "").strip() or document.original_name
        update_fields.append("display_name")
    if "description" in payload:
        document.description = str(payload.get("description") or "").strip()
        update_fields.append("description")
    if "is_active" in payload:
        document.is_active = bool(payload.get("is_active"))
        document.status = KnowledgeBaseDocument.Status.ACTIVE if document.is_active else KnowledgeBaseDocument.Status.DISABLED
        update_fields.extend(["is_active", "status"])
    if update_fields:
        update_fields.append("updated_at")
        document.save(update_fields=update_fields)
    if "is_active" in payload and document.is_active and not document.indexed_chunk_count:
        index_managed_document(document)
    return document


def delete_document(document: KnowledgeBaseDocument) -> KnowledgeBaseDocument:
    """Soft-delete *document*: remove its chunks from the index and mark it DELETED.

    The stored file is not removed from disk. Any ``index_delete_error``
    recorded by :func:`remove_managed_document_from_index` is kept in the
    metadata that is saved here.
    """
    remove_managed_document_from_index(document)
    document.status = KnowledgeBaseDocument.Status.DELETED
    document.is_active = False
    document.indexed_chunk_count = 0
    document.metadata = {**(document.metadata or {}), "index_status": "deleted", "index_error": ""}
    document.save(update_fields=["status", "is_active", "indexed_chunk_count", "metadata", "updated_at"])
    return document


def serialize_document(document: KnowledgeBaseDocument) -> dict[str, Any]:
    """Return a JSON-friendly dict for *document*; timestamps are ISO-8601 strings or ``""``."""
    return {
        "id": document.pk,
        "display_name": document.display_name,
        "original_name": document.original_name,
        "description": document.description,
        "file_size": document.file_size,
        "content_type": document.content_type,
        "status": document.status,
        "is_active": document.is_active,
        "indexed_chunk_count": document.indexed_chunk_count,
        "indexed_label": _indexed_label(document.indexed_chunk_count),
        "created_at": document.created_at.isoformat() if document.created_at else "",
        "updated_at": document.updated_at.isoformat() if document.updated_at else "",
    }


def index_managed_document(document: KnowledgeBaseDocument) -> int:
    """Extract, chunk, embed and upsert *document* into the Chroma collection.

    Chunk ids are derived deterministically from the document pk and chunk
    index. Any chunks left over from a previous indexing run are deleted
    before upserting. Because ``upsert`` only overwrites ids that still exist,
    a re-index producing fewer chunks would otherwise leave stale chunks
    behind.

    Never raises. On any error, ``indexed_chunk_count`` is set to 0 and the
    error is recorded in ``metadata["index_error"]`` with
    ``index_status="failed"``.

    Returns:
        The number of chunks indexed (0 on failure or when no text was extracted).
    """
    path = Path(document.storage_path)
    if not path.is_absolute():
        path = Path(settings.MEDIA_ROOT) / document.storage_path
    try:
        text = extract_text_from_path(path)
        source = f"用户知识库/{document.user_id}/{document.pk}/{document.original_name}"
        chunks = chunk_text(text, source=source)
        if not chunks:
            document.indexed_chunk_count = 0
            document.metadata = {**(document.metadata or {}), "index_status": "empty", "index_error": ""}
            document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"])
            return 0
        collection = _load_chroma_collection()
        texts = [chunk.text for chunk in chunks]
        embeddings = get_embedding_provider()(texts)
        ids = [
            hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
            for chunk in chunks
        ]
        metadatas = [
            {
                **chunk.metadata,
                "source_type": "managed_document",
                "document_id": document.pk,
                "user_id": document.user_id,
                "original_name": document.original_name,
            }
            for chunk in chunks
        ]
        # Embeddings are computed before deleting, so a failed embedding call
        # leaves the previously indexed chunks untouched.
        collection.delete(where={"document_id": document.pk})
        collection.upsert(ids=ids, documents=texts, metadatas=metadatas, embeddings=embeddings)
        document.indexed_chunk_count = len(chunks)
        document.metadata = {**(document.metadata or {}), "index_status": "indexed", "index_error": ""}
        document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"])
        return len(chunks)
    except Exception as exc:
        document.indexed_chunk_count = 0
        document.metadata = {**(document.metadata or {}), "index_status": "failed", "index_error": str(exc)}
        document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"])
        return 0


def remove_managed_document_from_index(document: KnowledgeBaseDocument) -> None:
    """Best-effort deletion of all chunks whose ``document_id`` metadata equals ``document.pk``.

    On failure, the error is stored in ``document.metadata["index_delete_error"]``
    in memory only. The caller must save the document to persist it, as
    :func:`delete_document` does.
    """
    try:
        collection = _load_chroma_collection()
        collection.delete(where={"document_id": document.pk})
    except Exception as exc:
        document.metadata = {**(document.metadata or {}), "index_delete_error": str(exc)}


def filter_active_knowledge_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop search hits that come from uploaded documents which are not currently active.

    Hits whose metadata ``source_type`` is not ``"managed_document"`` are
    always kept. Managed hits are kept only if their ``document_id`` parses
    as an int and refers to a document with ACTIVE status and
    ``is_active=True``. Managed hits with a missing or malformed id are
    always dropped, rather than crashing the search or being kept whenever no
    valid managed id happens to be present.
    """
    managed_ids: set[int] = set()
    for item in results:
        metadata = item.get("metadata") or {}
        if metadata.get("source_type") != "managed_document":
            continue
        document_id = _coerce_document_id(metadata)
        if document_id is not None:
            managed_ids.add(document_id)

    active_ids: set[int] = set()
    if managed_ids:
        active_ids = set(
            KnowledgeBaseDocument.objects.filter(
                pk__in=managed_ids,
                status=KnowledgeBaseDocument.Status.ACTIVE,
                is_active=True,
            ).values_list("pk", flat=True)
        )

    filtered = []
    for item in results:
        metadata = item.get("metadata") or {}
        if metadata.get("source_type") != "managed_document":
            filtered.append(item)
            continue
        if _coerce_document_id(metadata) in active_ids:
            filtered.append(item)
    return filtered


def _coerce_document_id(metadata: dict[str, Any]) -> int | None:
    """Return ``metadata["document_id"]`` as an int, or None if it is missing or not numeric."""
    try:
        return int(metadata.get("document_id"))
    except (TypeError, ValueError):
        return None


def _indexed_label(chunk_count: int) -> str:
    """Return the display label for a chunk count (a "not indexed" label when zero)."""
    return f"已入库 {chunk_count} 片" if chunk_count else "未入库"


def _load_chroma_collection():
    """Open (creating if needed) the configured Chroma collection.

    The persist directory is also created if missing.

    Raises:
        RuntimeError: if ``chromadb`` is not installed.
    """
    try:
        import chromadb
    except ImportError as exc:
        raise RuntimeError("chromadb 未安装。") from exc
    persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    persist_path.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(path=str(persist_path)).get_or_create_collection(
        settings.REGULATORY_RAG_COLLECTION
    )


def get_chroma_collection_state() -> ChromaCollectionState:
    """Inspect the configured collection without creating it.

    Reads all chunk metadata in order to count chunks per ``source``. Returns
    ``exists=False`` with an error message if the directory is missing,
    chromadb is not installed, or the collection cannot be opened or read.
    """
    persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
    if not persist_path.exists():
        return ChromaCollectionState(exists=False, error_message="法规 RAG 索引目录不存在。")
    try:
        import chromadb
    except ImportError:
        return ChromaCollectionState(exists=False, error_message="chromadb 未安装。")
    try:
        collection = chromadb.PersistentClient(path=str(persist_path)).get_collection(settings.REGULATORY_RAG_COLLECTION)
        count = collection.count()
        metadatas = _load_collection_metadatas(collection, count)
        return ChromaCollectionState(
            exists=True,
            count=count,
            sample_metadatas=metadatas[:10],
            source_chunk_counts=_count_chunks_by_source(metadatas),
        )
    except Exception as exc:
        return ChromaCollectionState(exists=False, error_message=f"法规 RAG collection 不可用:{exc}")


def _load_collection_metadatas(collection, count: int) -> list[dict[str, Any]]:
    """Fetch the metadata of the first *count* chunks in pages of 500."""
    metadatas: list[dict[str, Any]] = []
    if count <= 0:
        return metadatas
    page_size = 500
    for offset in range(0, count, page_size):
        payload = collection.get(
            include=["metadatas"],
            limit=min(page_size, count - offset),
            offset=offset,
        )
        metadatas.extend(payload.get("metadatas") or [])
    return metadatas


def _count_chunks_by_source(metadatas: list[dict[str, Any]]) -> dict[str, int]:
    """Count chunks per non-empty ``source`` metadata value."""
    counts: dict[str, int] = {}
    for metadata in metadatas:
        source = str((metadata or {}).get("source") or "")
        if source:
            counts[source] = counts.get(source, 0) + 1
    return counts


def _rule_info() -> dict[str, Any]:
    """Summarise the default rule file.

    The summary includes the requirement count, the number of distinct
    chapters (the part of ``attachment4_code`` before the first dot) and
    counts per ``severity``. On any failure a placeholder dict with
    ``status="failed"`` and ``error_message`` is returned instead of raising.
    """
    try:
        payload = load_rule_file()
        requirements = payload.get("requirements") or []
        severity_counts: dict[str, int] = {}
        chapter_codes = set()
        for requirement in requirements:
            severity = str(requirement.get("severity") or "unknown")
            severity_counts[severity] = severity_counts.get(severity, 0) + 1
            attachment4_code = str(requirement.get("attachment4_code") or "")
            if attachment4_code:
                chapter_codes.add(attachment4_code.split(".")[0])
        return {
            "status": "ok",
            "code": payload.get("code", ""),
            "name": payload.get("name", ""),
            "path": str(DEFAULT_RULE_PATH),
            "hash": compute_file_sha256(DEFAULT_RULE_PATH),
            "rag_collection": payload.get("rag_collection", ""),
            "source_material_dir": payload.get("source_material_dir", _DEFAULT_SOURCE_MATERIAL_DIR),
            "requirement_count": len(requirements),
            "chapter_count": len(chapter_codes),
            "severity_counts": severity_counts,
        }
    except Exception as exc:
        return {
            "status": "failed",
            "code": "",
            "name": "",
            "path": str(DEFAULT_RULE_PATH),
            "hash": "",
            "rag_collection": "",
            "source_material_dir": _DEFAULT_SOURCE_MATERIAL_DIR,
            "requirement_count": 0,
            "chapter_count": 0,
            "severity_counts": {},
            "error_message": str(exc),
        }


def _status_label(collection: ChromaCollectionState) -> dict[str, str]:
    """Classify the collection as ``missing``, ``thin`` (below the chunk threshold) or ``ready``."""
    if not collection.exists:
        return {"code": "missing", "label": "未构建", "message": collection.error_message}
    if collection.count < _THIN_INDEX_CHUNK_THRESHOLD:
        return {"code": "thin", "label": "索引过少", "message": "RAG 能力已打通,但当前索引内容较少,建议补齐材料后重建。"}
    return {"code": "ready", "label": "可用", "message": "RAG 索引已构建,可用于法规依据辅助检索。"}


def _unique_target_path(root: Path, original_name: str) -> Path:
    """Return a non-existing path in *root* for *original_name*.

    Only the final path component of *original_name* is used, so directory
    parts cannot escape *root*. An empty name becomes ``document``, and
    collisions get ``-2``, ``-3``, ... before the suffix. The existence check
    is not atomic with the later file creation.
    """
    safe_name = Path(original_name).name or "document"
    target = root / safe_name
    if not target.exists():
        return target
    stem = target.stem
    suffix = target.suffix
    index = 2
    while True:
        candidate = root / f"{stem}-{index}{suffix}"
        if not candidate.exists():
            return candidate
        index += 1