From 5ecf78c5d6535a120e20d2a959ca8e682e033407 Mon Sep 17 00:00:00 2001 From: bruce Date: Mon, 8 Jun 2026 21:37:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(knowledge-base):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E7=9F=A5=E8=AF=86=E5=BA=93=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/urls.py | 3 +- review_agent/knowledge_base.py | 397 ++++++++++++++++++ .../migrations/0008_knowledgebasedocument.py | 80 ++++ review_agent/models.py | 39 ++ .../services/rag_citation.py | 1 + .../regulatory_review/services/rag_index.py | 94 +++++ review_agent/urls.py | 32 ++ review_agent/views.py | 108 +++++ static/js/knowledge_base.js | 238 +++++++++++ templates/attachment_manager.html | 2 +- templates/knowledge_base.html | 213 ++++++++++ tests/test_knowledge_base.py | 220 ++++++++++ 12 files changed, 1425 insertions(+), 2 deletions(-) create mode 100644 review_agent/knowledge_base.py create mode 100644 review_agent/migrations/0008_knowledgebasedocument.py create mode 100644 static/js/knowledge_base.js create mode 100644 templates/knowledge_base.html create mode 100644 tests/test_knowledge_base.py diff --git a/config/urls.py b/config/urls.py index 36df95c..caf51ba 100644 --- a/config/urls.py +++ b/config/urls.py @@ -2,10 +2,11 @@ from django.contrib import admin from django.contrib.auth.views import LoginView, LogoutView, PasswordChangeView from django.urls import include, path -from review_agent.views import attachment_manager, stream_chat, workspace +from review_agent.views import attachment_manager, knowledge_base_manager, stream_chat, workspace urlpatterns = [ path("", workspace, name="home"), + path("knowledge-base/", knowledge_base_manager, name="knowledge_base_manager"), path("attachments/", attachment_manager, name="attachment_manager"), path("", include("review_agent.urls")), path("chat/stream/", stream_chat, name="chat_stream"), diff --git a/review_agent/knowledge_base.py b/review_agent/knowledge_base.py new file mode 100644 index 0000000..12edff7 --- /dev/null +++ b/review_agent/knowledge_base.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from django.conf import settings +from django.core.files.uploadedfile import UploadedFile + +from review_agent.models import KnowledgeBaseDocument +from review_agent.regulatory_review.services.rag_citation import RagIndexUnavailable, retrieve_citations +from review_agent.regulatory_review.services.rag_embedding import DeterministicEmbeddingProvider +from review_agent.regulatory_review.services.rag_index import chunk_text, extract_text_from_path +from review_agent.regulatory_review.services.rule_loader import DEFAULT_RULE_PATH, compute_file_sha256, load_rule_file + + +SUPPORTED_SOURCE_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".pptx", ".xlsx"} + + +@dataclass(frozen=True) +class ChromaCollectionState: + exists: bool + count: int = 0 + error_message: str = "" + sample_metadatas: list[dict[str, Any]] | None = None + source_chunk_counts: dict[str, int] | None = None + + +def build_knowledge_base_context() -> dict[str, Any]: + rule_info = _rule_info() + source_dir = Path(settings.BASE_DIR) / str(rule_info.get("source_material_dir") or "docs/0.原始材料") + sources = list_source_documents(source_dir) + collection = get_chroma_collection_state() + return { + "name": "NMPA IVD 注册资料法规库", + "description": "用于体外诊断试剂注册资料法规核查的结构化规则和 RAG 依据检索。", + "provider": settings.REGULATORY_RAG_PROVIDER, + "collection_name": settings.REGULATORY_RAG_COLLECTION, + "chroma_path": settings.REGULATORY_RAG_CHROMA_PATH, + "rule": rule_info, + "source_dir": str(source_dir), + "sources": sources, + "source_count": len(sources), + "supported_source_count": sum(1 for item in sources if item["supported"]), + "collection": { + "exists": collection.exists, + "count": collection.count, + "error_message": collection.error_message, + "sample_metadatas": collection.sample_metadatas or [], + }, + "status": _status_label(collection), + "build_commands": [ + "python manage.py regulatory_rag_build --provider deterministic", + "python manage.py regulatory_rag_build --provider siliconflow", + ], + "managed_documents": [], + } + + +def build_knowledge_base_context_for_user(user) -> dict[str, Any]: + context = build_knowledge_base_context() + documents = list_documents_for_user(user) + context["managed_documents"] = documents + context["managed_document_count"] = len(documents) + context["active_managed_document_count"] = sum(1 for item in documents if item["is_active"]) + return context + + +def list_source_documents(source_dir: Path) -> list[dict[str, Any]]: + if not source_dir.exists(): + return [] + collection = get_chroma_collection_state() + source_chunk_counts = collection.source_chunk_counts or {} + documents: list[dict[str, Any]] = [] + for path in sorted(source_dir.rglob("*")): + if not path.is_file(): + continue + suffix = path.suffix.lower() + relative_path = str(path.relative_to(source_dir)) + indexed_chunk_count = source_chunk_counts.get(relative_path, 0) + documents.append( + { + "name": path.name, + "relative_path": relative_path, + "suffix": suffix.lstrip(".") or "unknown", + "size": path.stat().st_size, + "supported": suffix in SUPPORTED_SOURCE_SUFFIXES, + "indexed": indexed_chunk_count > 0, + "indexed_chunk_count": indexed_chunk_count, + "indexed_label": f"已入库 {indexed_chunk_count} 片" if indexed_chunk_count else "未入库", + } + ) + return documents + + +def search_knowledge_base(query: str, *, n_results: int = 3) -> dict[str, Any]: + normalized = (query or "").strip() + if not normalized: + return {"query": normalized, "results": [], "error_message": "请输入检索问题。"} + try: + results = retrieve_citations( + normalized, + embedding_provider=DeterministicEmbeddingProvider(), + n_results=n_results, + ) + except RagIndexUnavailable as exc: + return {"query": normalized, "results": [], "error_message": str(exc)} + except Exception as exc: + return {"query": normalized, "results": [], "error_message": f"检索失败:{exc}"} + return {"query": normalized, "results": filter_active_knowledge_results(results), "error_message": ""} + + +def list_documents_for_user(user) -> list[dict[str, Any]]: + return [ + serialize_document(document) + for document in KnowledgeBaseDocument.objects.filter(user=user).exclude(status=KnowledgeBaseDocument.Status.DELETED) + ] + + +def create_document_from_upload( + *, + user, + uploaded_file: UploadedFile, + display_name: str = "", + description: str = "", + is_active: bool = True, +) -> KnowledgeBaseDocument: + root = Path(settings.MEDIA_ROOT) / "knowledge_base" / "users" / str(user.pk) + root.mkdir(parents=True, exist_ok=True) + target = _unique_target_path(root, uploaded_file.name) + with target.open("wb") as handle: + for chunk in uploaded_file.chunks(): + handle.write(chunk) + status = KnowledgeBaseDocument.Status.ACTIVE if is_active else KnowledgeBaseDocument.Status.DISABLED + document = KnowledgeBaseDocument.objects.create( + user=user, + display_name=(display_name or uploaded_file.name).strip(), + original_name=uploaded_file.name, + storage_path=str(target), + file_size=target.stat().st_size, + content_type=getattr(uploaded_file, "content_type", "") or "", + description=description.strip(), + status=status, + is_active=is_active, + ) + if is_active: + index_managed_document(document) + return document + + +def update_document(document: KnowledgeBaseDocument, payload: dict[str, Any]) -> KnowledgeBaseDocument: + update_fields = [] + if "display_name" in payload: + document.display_name = str(payload.get("display_name") or "").strip() or document.original_name + update_fields.append("display_name") + if "description" in payload: + document.description = str(payload.get("description") or "").strip() + update_fields.append("description") + if "is_active" in payload: + document.is_active = bool(payload.get("is_active")) + document.status = KnowledgeBaseDocument.Status.ACTIVE if document.is_active else KnowledgeBaseDocument.Status.DISABLED + update_fields.extend(["is_active", "status"]) + if update_fields: + update_fields.append("updated_at") + document.save(update_fields=update_fields) + return document + + +def delete_document(document: KnowledgeBaseDocument) -> KnowledgeBaseDocument: + remove_managed_document_from_index(document) + document.status = KnowledgeBaseDocument.Status.DELETED + document.is_active = False + document.indexed_chunk_count = 0 + document.metadata = {**(document.metadata or {}), "index_status": "deleted", "index_error": ""} + document.save(update_fields=["status", "is_active", "indexed_chunk_count", "metadata", "updated_at"]) + return document + + +def serialize_document(document: KnowledgeBaseDocument) -> dict[str, Any]: + indexed_label = f"已入库 {document.indexed_chunk_count} 片" if document.indexed_chunk_count else "未入库" + return { + "id": document.pk, + "display_name": document.display_name, + "original_name": document.original_name, + "description": document.description, + "file_size": document.file_size, + "content_type": document.content_type, + "status": document.status, + "is_active": document.is_active, + "indexed_chunk_count": document.indexed_chunk_count, + "indexed_label": indexed_label, + "created_at": document.created_at.isoformat() if document.created_at else "", + "updated_at": document.updated_at.isoformat() if document.updated_at else "", + } + + +def index_managed_document(document: KnowledgeBaseDocument) -> int: + path = Path(document.storage_path) + if not path.is_absolute(): + path = Path(settings.MEDIA_ROOT) / document.storage_path + try: + text = extract_text_from_path(path) + source = f"用户知识库/{document.user_id}/{document.pk}/{document.original_name}" + chunks = chunk_text(text, source=source) + if not chunks: + document.indexed_chunk_count = 0 + document.metadata = {**(document.metadata or {}), "index_status": "empty", "index_error": ""} + document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"]) + return 0 + collection = _load_chroma_collection() + texts = [chunk.text for chunk in chunks] + embeddings = DeterministicEmbeddingProvider()(texts) + ids = [ + hashlib.sha256(f"managed:{document.pk}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest() + for chunk in chunks + ] + metadatas = [ + { + **chunk.metadata, + "source_type": "managed_document", + "document_id": document.pk, + "user_id": document.user_id, + "original_name": document.original_name, + } + for chunk in chunks + ] + collection.upsert(ids=ids, documents=texts, metadatas=metadatas, embeddings=embeddings) + document.indexed_chunk_count = len(chunks) + document.metadata = {**(document.metadata or {}), "index_status": "indexed", "index_error": ""} + document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"]) + return len(chunks) + except Exception as exc: + document.indexed_chunk_count = 0 + document.metadata = {**(document.metadata or {}), "index_status": "failed", "index_error": str(exc)} + document.save(update_fields=["indexed_chunk_count", "metadata", "updated_at"]) + return 0 + + +def remove_managed_document_from_index(document: KnowledgeBaseDocument) -> None: + try: + collection = _load_chroma_collection() + collection.delete(where={"document_id": document.pk}) + except Exception as exc: + document.metadata = {**(document.metadata or {}), "index_delete_error": str(exc)} + + +def filter_active_knowledge_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: + managed_ids = { + int((item.get("metadata") or {}).get("document_id")) + for item in results + if (item.get("metadata") or {}).get("source_type") == "managed_document" + and (item.get("metadata") or {}).get("document_id") is not None + } + if not managed_ids: + return results + active_ids = set( + KnowledgeBaseDocument.objects.filter( + pk__in=managed_ids, + status=KnowledgeBaseDocument.Status.ACTIVE, + is_active=True, + ).values_list("pk", flat=True) + ) + filtered = [] + for item in results: + metadata = item.get("metadata") or {} + if metadata.get("source_type") != "managed_document": + filtered.append(item) + continue + try: + document_id = int(metadata.get("document_id")) + except (TypeError, ValueError): + continue + if document_id in active_ids: + filtered.append(item) + return filtered + + +def _load_chroma_collection(): + try: + import chromadb + except ImportError as exc: + raise RuntimeError("chromadb 未安装。") from exc + persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH) + persist_path.mkdir(parents=True, exist_ok=True) + return chromadb.PersistentClient(path=str(persist_path)).get_or_create_collection( + settings.REGULATORY_RAG_COLLECTION + ) + + +def get_chroma_collection_state() -> ChromaCollectionState: + persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH) + if not persist_path.exists(): + return ChromaCollectionState(exists=False, error_message="法规 RAG 索引目录不存在。") + try: + import chromadb + except ImportError: + return ChromaCollectionState(exists=False, error_message="chromadb 未安装。") + try: + collection = chromadb.PersistentClient(path=str(persist_path)).get_collection(settings.REGULATORY_RAG_COLLECTION) + count = collection.count() + metadatas = _load_collection_metadatas(collection, count) + return ChromaCollectionState( + exists=True, + count=count, + sample_metadatas=metadatas[:10], + source_chunk_counts=_count_chunks_by_source(metadatas), + ) + except Exception as exc: + return ChromaCollectionState(exists=False, error_message=f"法规 RAG collection 不可用:{exc}") + + +def _load_collection_metadatas(collection, count: int) -> list[dict[str, Any]]: + metadatas: list[dict[str, Any]] = [] + if count <= 0: + return metadatas + page_size = 500 + for offset in range(0, count, page_size): + payload = collection.get( + include=["metadatas"], + limit=min(page_size, count - offset), + offset=offset, + ) + metadatas.extend(payload.get("metadatas") or []) + return metadatas + + +def _count_chunks_by_source(metadatas: list[dict[str, Any]]) -> dict[str, int]: + counts: dict[str, int] = {} + for metadata in metadatas: + source = str((metadata or {}).get("source") or "") + if source: + counts[source] = counts.get(source, 0) + 1 + return counts + + +def _rule_info() -> dict[str, Any]: + try: + payload = load_rule_file() + requirements = payload.get("requirements") or [] + severity_counts: dict[str, int] = {} + chapter_codes = set() + for requirement in requirements: + severity = str(requirement.get("severity") or "unknown") + severity_counts[severity] = severity_counts.get(severity, 0) + 1 + attachment4_code = str(requirement.get("attachment4_code") or "") + if attachment4_code: + chapter_codes.add(attachment4_code.split(".")[0]) + return { + "status": "ok", + "code": payload.get("code", ""), + "name": payload.get("name", ""), + "path": str(DEFAULT_RULE_PATH), + "hash": compute_file_sha256(DEFAULT_RULE_PATH), + "rag_collection": payload.get("rag_collection", ""), + "source_material_dir": payload.get("source_material_dir", "docs/0.原始材料"), + "requirement_count": len(requirements), + "chapter_count": len(chapter_codes), + "severity_counts": severity_counts, + } + except Exception as exc: + return { + "status": "failed", + "code": "", + "name": "", + "path": str(DEFAULT_RULE_PATH), + "hash": "", + "rag_collection": "", + "source_material_dir": "docs/0.原始材料", + "requirement_count": 0, + "chapter_count": 0, + "severity_counts": {}, + "error_message": str(exc), + } + + +def _status_label(collection: ChromaCollectionState) -> dict[str, str]: + if not collection.exists: + return {"code": "missing", "label": "未构建", "message": collection.error_message} + if collection.count < 20: + return {"code": "thin", "label": "索引过少", "message": "RAG 能力已打通,但当前索引内容较少,建议补齐材料后重建。"} + return {"code": "ready", "label": "可用", "message": "RAG 索引已构建,可用于法规依据辅助检索。"} + + +def _unique_target_path(root: Path, original_name: str) -> Path: + safe_name = Path(original_name).name or "document" + target = root / safe_name + if not target.exists(): + return target + stem = target.stem + suffix = target.suffix + index = 2 + while True: + candidate = root / f"{stem}-{index}{suffix}" + if not candidate.exists(): + return candidate + index += 1 diff --git a/review_agent/migrations/0008_knowledgebasedocument.py b/review_agent/migrations/0008_knowledgebasedocument.py new file mode 100644 index 0000000..10b205f --- /dev/null +++ b/review_agent/migrations/0008_knowledgebasedocument.py @@ -0,0 +1,80 @@ +# Generated by Django 5.2.14 on 2026-06-08 11:58 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("review_agent", "0007_feishuaccesstokencache_feishuusermapping_and_more"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="KnowledgeBaseDocument", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("display_name", models.CharField(max_length=255)), + ("original_name", models.CharField(max_length=255)), + ("storage_path", models.CharField(max_length=500)), + ("file_size", models.BigIntegerField(default=0)), + ( + "content_type", + models.CharField(blank=True, default="", max_length=120), + ), + ("description", models.TextField(blank=True, default="")), + ( + "status", + models.CharField( + choices=[ + ("active", "启用"), + ("disabled", "停用"), + ("deleted", "已删除"), + ], + default="active", + max_length=20, + ), + ), + ("is_active", models.BooleanField(default=True)), + ("indexed_chunk_count", models.PositiveIntegerField(default=0)), + ("metadata", models.JSONField(blank=True, default=dict)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="knowledge_base_documents", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "db_table": "ra_knowledge_base_document", + "ordering": ["-updated_at", "-id"], + "indexes": [ + models.Index( + fields=["user", "status"], name="idx_ra_kb_doc_user_status" + ), + models.Index( + fields=["user", "created_at"], name="idx_ra_kb_doc_user_created" + ), + models.Index( + fields=["status", "updated_at"], + name="idx_ra_kb_doc_status_updated", + ), + ], + }, + ), + ] diff --git a/review_agent/models.py b/review_agent/models.py index 357ddca..6189a69 100644 --- a/review_agent/models.py +++ b/review_agent/models.py @@ -399,6 +399,45 @@ class RegulatoryRuleVersion(models.Model): return self.code +class KnowledgeBaseDocument(models.Model): + """Stores user-managed knowledge-base source documents.""" + + class Status(models.TextChoices): + ACTIVE = "active", "启用" + DISABLED = "disabled", "停用" + DELETED = "deleted", "已删除" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="knowledge_base_documents", + ) + display_name = models.CharField(max_length=255) + original_name = models.CharField(max_length=255) + storage_path = models.CharField(max_length=500) + file_size = models.BigIntegerField(default=0) + content_type = models.CharField(max_length=120, blank=True, default="") + description = models.TextField(blank=True, default="") + status = models.CharField(max_length=20, choices=Status.choices, default=Status.ACTIVE) + is_active = models.BooleanField(default=True) + indexed_chunk_count = models.PositiveIntegerField(default=0) + metadata = models.JSONField(default=dict, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + db_table = "ra_knowledge_base_document" + ordering = ["-updated_at", "-id"] + indexes = [ + models.Index(fields=["user", "status"], name="idx_ra_kb_doc_user_status"), + models.Index(fields=["user", "created_at"], name="idx_ra_kb_doc_user_created"), + models.Index(fields=["status", "updated_at"], name="idx_ra_kb_doc_status_updated"), + ] + + def __str__(self) -> str: + return self.display_name + + class ApplicationFormFillBatch(models.Model): """Tracks one application-form auto-fill workflow run.""" diff --git a/review_agent/regulatory_review/services/rag_citation.py b/review_agent/regulatory_review/services/rag_citation.py index 8f54517..7afca0d 100644 --- a/review_agent/regulatory_review/services/rag_citation.py +++ b/review_agent/regulatory_review/services/rag_citation.py @@ -37,6 +37,7 @@ def retrieve_citations( "source": metadata.get("source", "法规材料"), "text": document, "score": distance, + "metadata": metadata, } ) return citations diff --git a/review_agent/regulatory_review/services/rag_index.py b/review_agent/regulatory_review/services/rag_index.py index c806e08..be80cf8 100644 --- a/review_agent/regulatory_review/services/rag_index.py +++ b/review_agent/regulatory_review/services/rag_index.py @@ -2,6 +2,7 @@ from __future__ import annotations import hashlib import logging +import shutil import subprocess import tempfile from dataclasses import dataclass @@ -102,6 +103,33 @@ def _iter_docx_blocks(document): def _extract_legacy_doc_with_libreoffice(path: Path) -> str: + cached = _cached_docx_path(path) + if cached.exists(): + return extract_text_from_path(cached) + try: + return _extract_legacy_doc_with_libreoffice_convert(path) + except RuntimeError as libreoffice_error: + try: + return _extract_legacy_doc_with_word_com(path) + except RuntimeError as word_error: + try: + return _extract_legacy_doc_with_powershell_word_com(path) + except RuntimeError as powershell_error: + raise RuntimeError( + f"无法转换法规 .doc 材料:{path.name};" + f"LibreOffice 错误:{libreoffice_error};" + f"Word COM 错误:{word_error};" + f"PowerShell Word COM 错误:{powershell_error}" + ) from powershell_error + + +def _cached_docx_path(path: Path) -> Path: + digest = hashlib.sha256(str(path.resolve()).encode("utf-8")).hexdigest()[:12] + cache_dir = Path(settings.MEDIA_ROOT) / "regulatory_review" / "docx_cache" + return cache_dir / f"{path.stem}-{digest}.docx" + + +def _extract_legacy_doc_with_libreoffice_convert(path: Path) -> str: with tempfile.TemporaryDirectory() as tmp_dir: target_dir = Path(tmp_dir) try: @@ -128,6 +156,72 @@ def _extract_legacy_doc_with_libreoffice(path: Path) -> str: return extract_text_from_path(converted) +def _extract_legacy_doc_with_word_com(path: Path) -> str: + with tempfile.TemporaryDirectory() as tmp_dir: + target_dir = Path(tmp_dir) + converted = target_dir / f"{path.stem}.docx" + word = None + try: + import pythoncom + import win32com.client + + pythoncom.CoInitialize() + word = win32com.client.DispatchEx("Word.Application") + word.Visible = False + document = word.Documents.Open(str(path.resolve()), ReadOnly=True) + document.SaveAs(str(converted.resolve()), FileFormat=16) + document.Close(False) + except Exception as exc: + raise RuntimeError(f"无法通过 Word COM 转换法规 .doc 材料:{path.name}") from exc + finally: + if word is not None: + try: + word.Quit() + except Exception: + pass + try: + pythoncom.CoUninitialize() + except Exception: + pass + if not converted.exists(): + raise RuntimeError(f"Word COM 未生成 docx:{path.name}") + return extract_text_from_path(converted) + + +def _extract_legacy_doc_with_powershell_word_com(path: Path) -> str: + with tempfile.TemporaryDirectory() as tmp_dir: + target_dir = Path(tmp_dir) + converted = target_dir / f"{path.stem}.docx" + source_path = str(path.resolve()).replace("'", "''") + target_path = str(converted.resolve()).replace("'", "''") + script = ( + "$ErrorActionPreference = 'Stop';" + "$word = New-Object -ComObject Word.Application;" + "$word.Visible = $false;" + "try {" + f"$doc = $word.Documents.Open('{source_path}', $false, $true);" + f"$doc.SaveAs([ref]'{target_path}', [ref]16);" + "$doc.Close([ref]$false);" + "} finally { $word.Quit() }" + ) + powershell = shutil.which("powershell") or shutil.which("pwsh") + if not powershell: + raise RuntimeError("PowerShell 不可用,无法调用 Word COM。") + try: + subprocess.run( + [powershell, "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", script], + check=True, + capture_output=True, + text=True, + timeout=90, + ) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: + raise RuntimeError(f"无法通过 PowerShell Word COM 转换法规 .doc 材料:{path.name}") from exc + if not converted.exists(): + raise RuntimeError(f"PowerShell Word COM 未生成 docx:{path.name}") + return extract_text_from_path(converted) + + def collect_source_chunks(source_dir: Path) -> list[TextChunk]: chunks: list[TextChunk] = [] for path in sorted(source_dir.rglob("*")): diff --git a/review_agent/urls.py b/review_agent/urls.py index 44deeb7..a2b1d24 100644 --- a/review_agent/urls.py +++ b/review_agent/urls.py @@ -20,6 +20,13 @@ from .application_form_fill.views import ( batch_status as application_form_fill_batch_status, start as application_form_fill_start, ) +from .views import ( + knowledge_base_document_detail, + knowledge_base_document_index, + knowledge_base_documents, + knowledge_base_search, + knowledge_base_status, +) urlpatterns = [ @@ -98,4 +105,29 @@ urlpatterns = [ application_form_fill_batch_status, name="application_form_fill_batch_status", ), + path( + "api/review-agent/knowledge-base/status/", + knowledge_base_status, + name="knowledge_base_status", + ), + path( + "api/review-agent/knowledge-base/search/", + knowledge_base_search, + name="knowledge_base_search", + ), + path( + "api/review-agent/knowledge-base/documents/", + knowledge_base_documents, + name="knowledge_base_document_list", + ), + path( + "api/review-agent/knowledge-base/documents//", + knowledge_base_document_detail, + name="knowledge_base_document_detail", + ), + path( + "api/review-agent/knowledge-base/documents//index/", + knowledge_base_document_index, + name="knowledge_base_document_index", + ), ] diff --git a/review_agent/views.py b/review_agent/views.py index 4b0d3da..f629cfb 100644 --- a/review_agent/views.py +++ b/review_agent/views.py @@ -1,5 +1,7 @@ from django.contrib.auth.decorators import login_required from django.db.models import Count, Q +import json + from django.http import HttpRequest, HttpResponse, JsonResponse, StreamingHttpResponse from django.shortcuts import redirect, render from django.views.decorators.http import require_http_methods @@ -12,6 +14,17 @@ from .services import ( stream_message, ) from .models import ApplicationFormFillBatch, Conversation, FileAttachment, FileSummaryBatch, RegulatoryReviewBatch, WorkflowNodeRun +from .knowledge_base import build_knowledge_base_context, search_knowledge_base +from .knowledge_base import ( + build_knowledge_base_context_for_user, + create_document_from_upload, + delete_document, + index_managed_document, + list_documents_for_user, + serialize_document, + update_document, +) +from .models import KnowledgeBaseDocument from .regulatory_review.services.info_extract import ensure_regulatory_condition_candidates @@ -94,6 +107,101 @@ def attachment_manager(request: HttpRequest) -> HttpResponse: ) +@login_required +@require_http_methods(["GET"]) +def knowledge_base_manager(request: HttpRequest) -> HttpResponse: + context = build_knowledge_base_context_for_user(request.user) + return render( + request, + "knowledge_base.html", + { + "page_title": "知识库管理", + "knowledge_base": context, + }, + ) + + +@login_required +@require_http_methods(["GET"]) +def knowledge_base_status(request: HttpRequest) -> JsonResponse: + return JsonResponse(build_knowledge_base_context_for_user(request.user)) + + +@login_required +@require_http_methods(["POST"]) +def knowledge_base_search(request: HttpRequest) -> JsonResponse: + if request.content_type == "application/json": + try: + payload = json.loads(request.body.decode("utf-8") or "{}") + except json.JSONDecodeError: + payload = {} + query = payload.get("query", "") + else: + query = request.POST.get("query", "") + return JsonResponse(search_knowledge_base(str(query))) + + +@login_required +@require_http_methods(["GET", "POST"]) +def knowledge_base_documents(request: HttpRequest) -> JsonResponse: + if request.method == "GET": + return JsonResponse({"documents": list_documents_for_user(request.user)}) + uploaded_file = request.FILES.get("file") + if uploaded_file is None: + return JsonResponse({"error": "请上传知识库材料。"}, status=400) + is_active = str(request.POST.get("is_active", "true")).lower() not in {"0", "false", "off"} + document = create_document_from_upload( + user=request.user, + uploaded_file=uploaded_file, + display_name=request.POST.get("display_name", ""), + description=request.POST.get("description", ""), + is_active=is_active, + ) + return JsonResponse({"document": serialize_document(document)}) + + +@login_required +@require_http_methods(["GET", "PATCH", "DELETE"]) +def knowledge_base_document_detail(request: HttpRequest, document_id: int) -> JsonResponse: + try: + document = KnowledgeBaseDocument.objects.get( + pk=document_id, + user=request.user, + ) + except KnowledgeBaseDocument.DoesNotExist: + return JsonResponse({"error": "知识库材料不存在。"}, status=404) + if document.status == KnowledgeBaseDocument.Status.DELETED: + return JsonResponse({"error": "知识库材料不存在。"}, status=404) + if request.method == "GET": + return JsonResponse({"document": serialize_document(document)}) + if request.method == "DELETE": + delete_document(document) + return JsonResponse({"document": serialize_document(document)}) + try: + payload = json.loads(request.body.decode("utf-8") or "{}") + except json.JSONDecodeError: + payload = {} + update_document(document, payload) + return JsonResponse({"document": serialize_document(document)}) + + +@login_required +@require_http_methods(["POST"]) +def knowledge_base_document_index(request: HttpRequest, document_id: int) -> JsonResponse: + try: + document = KnowledgeBaseDocument.objects.get( + pk=document_id, + user=request.user, + ) + except KnowledgeBaseDocument.DoesNotExist: + return JsonResponse({"error": "知识库材料不存在。"}, status=404) + if document.status == KnowledgeBaseDocument.Status.DELETED: + return JsonResponse({"error": "知识库材料不存在。"}, status=404) + chunk_count = index_managed_document(document) + document.refresh_from_db() + return JsonResponse({"document": serialize_document(document), "chunk_count": chunk_count}) + + @login_required @require_http_methods(["POST"]) def stream_chat(request: HttpRequest) -> HttpResponse: diff --git a/static/js/knowledge_base.js b/static/js/knowledge_base.js new file mode 100644 index 0000000..dd6b9d0 --- /dev/null +++ b/static/js/knowledge_base.js @@ -0,0 +1,238 @@ +(function () { + var page = document.querySelector(".knowledge-page"); + if (!page) { + return; + } + + var documentForm = document.getElementById("knowledgeDocumentForm"); + var documentStatus = document.getElementById("knowledgeDocumentStatus"); + var documentTable = document.getElementById("knowledgeDocumentTable"); + var documentSearch = document.getElementById("knowledgeDocumentSearch"); + var searchForm = document.getElementById("knowledgeSearchForm"); + var queryInput = document.getElementById("knowledgeSearchQuery"); + var results = document.getElementById("knowledgeSearchResults"); + var sourceSearch = document.getElementById("knowledgeSourceSearch"); + var sourceTable = document.getElementById("knowledgeSourceTable"); + var documentFileInput = document.getElementById("knowledgeDocumentFile"); + var uploadDropzone = document.getElementById("knowledgeUploadDropzone"); + + function csrfToken() { + var cookie = document.cookie.split("; ").find(function (item) { + return item.indexOf("csrftoken=") === 0; + }); + return cookie ? decodeURIComponent(cookie.split("=")[1]) : ""; + } + + function escapeHtml(value) { + return String(value || "") + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/"/g, """) + .replace(/'/g, "'"); + } + + async function patchDocument(row, payload) { + var response = await fetch(row.getAttribute("data-detail-url"), { + method: "PATCH", + headers: { + "Content-Type": "application/json", + "X-CSRFToken": csrfToken(), + }, + body: JSON.stringify(payload), + }); + if (!response.ok) { + throw new Error("知识库材料更新失败。"); + } + return response.json(); + } + + async function deleteDocument(row) { + var response = await fetch(row.getAttribute("data-detail-url"), { + method: "DELETE", + headers: { "X-CSRFToken": csrfToken() }, + }); + if (!response.ok) { + throw new Error("知识库材料删除失败。"); + } + } + + async function indexDocument(row) { + var response = await fetch(row.getAttribute("data-index-url"), { + method: "POST", + headers: { "X-CSRFToken": csrfToken() }, + }); + if (!response.ok) { + throw new Error("知识库材料解析入库失败。"); + } + return response.json(); + } + + function renderResults(payload) { + if (!results) { + return; + } + if (payload.error_message) { + results.innerHTML = '

' + escapeHtml(payload.error_message) + "

"; + return; + } + if (!payload.results || !payload.results.length) { + results.innerHTML = '

未检索到依据片段。

'; + return; + } + results.innerHTML = payload.results + .map(function (item, index) { + return [ + '
', + "
结果 " + (index + 1) + "" + escapeHtml(item.source || "法规材料") + "
", + "

" + escapeHtml(item.text || "").slice(0, 600) + "

", + item.score === null || item.score === undefined ? "" : "score: " + escapeHtml(item.score) + "", + "
", + ].join(""); + }) + .join(""); + } + + if (documentForm) { + documentForm.addEventListener("submit", async function (event) { + event.preventDefault(); + var formData = new FormData(documentForm); + if (documentStatus) { + documentStatus.textContent = "上传并解析入库中..."; + } + try { + var response = await fetch(page.getAttribute("data-document-url"), { + method: "POST", + headers: { "X-CSRFToken": csrfToken() }, + body: formData, + }); + if (!response.ok) { + throw new Error("新增材料失败。"); + } + window.location.reload(); + } catch (error) { + if (documentStatus) { + documentStatus.textContent = error.message || "新增材料失败。"; + } + } + }); + } + + if (documentFileInput && documentStatus) { + documentFileInput.addEventListener("change", function () { + var file = documentFileInput.files && documentFileInput.files[0]; + documentStatus.textContent = file + ? "已选择:" + file.name + : "上传后会进入当前账号的全局知识库。"; + }); + } + + if (uploadDropzone && documentFileInput) { + uploadDropzone.addEventListener("click", function () { + documentFileInput.click(); + }); + uploadDropzone.addEventListener("keydown", function (event) { + if (event.key === "Enter" || event.key === " ") { + event.preventDefault(); + documentFileInput.click(); + } + }); + ["dragenter", "dragover"].forEach(function (eventName) { + uploadDropzone.addEventListener(eventName, function (event) { + event.preventDefault(); + uploadDropzone.classList.add("dragging"); + }); + }); + ["dragleave", "drop"].forEach(function (eventName) { + uploadDropzone.addEventListener(eventName, function (event) { + event.preventDefault(); + uploadDropzone.classList.remove("dragging"); + }); + }); + uploadDropzone.addEventListener("drop", function (event) { + var files = event.dataTransfer && event.dataTransfer.files; + if (!files || !files.length) { + return; + } + documentFileInput.files = files; + documentFileInput.dispatchEvent(new Event("change", { bubbles: true })); + }); + } + + if (documentTable) { + documentTable.addEventListener("click", async function (event) { + var button = event.target.closest("[data-kb-action]"); + if (!button) { + return; + } + var row = button.closest("tr[data-document-id]"); + if (!row) { + return; + } + var action = button.getAttribute("data-kb-action"); + try { + if (action === "edit") { + var nameCell = row.querySelector(".attachment-name"); + var nextName = window.prompt("请输入新的材料名称", nameCell ? nameCell.textContent.trim() : ""); + if (nextName) { + await patchDocument(row, { display_name: nextName }); + window.location.reload(); + } + } else if (action === "toggle") { + await patchDocument(row, { is_active: button.textContent.trim() === "启用" }); + window.location.reload(); + } else if (action === "index") { + button.disabled = true; + button.textContent = "解析中"; + await indexDocument(row); + window.location.reload(); + } else if (action === "delete" && window.confirm("确认删除该知识库材料?")) { + await deleteDocument(row); + window.location.reload(); + } + } catch (error) { + window.alert(error.message || "知识库材料操作失败。"); + } + }); + } + + if (searchForm && queryInput) { + searchForm.addEventListener("submit", async function (event) { + event.preventDefault(); + var query = queryInput.value.trim(); + if (!query) { + renderResults({ error_message: "请输入检索问题。" }); + return; + } + results.innerHTML = '

检索中...

'; + try { + var response = await fetch(page.getAttribute("data-search-url"), { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-CSRFToken": csrfToken(), + }, + body: JSON.stringify({ query: query }), + }); + renderResults(await response.json()); + } catch (error) { + renderResults({ error_message: "检索失败,请稍后重试。" }); + } + }); + } + + function bindTableSearch(input, table, selector) { + if (!input || !table) { + return; + } + input.addEventListener("input", function () { + var keyword = input.value.trim().toLowerCase(); + table.querySelectorAll(selector).forEach(function (row) { + row.hidden = keyword && row.textContent.toLowerCase().indexOf(keyword) === -1; + }); + }); + } + + bindTableSearch(documentSearch, documentTable, "tbody tr[data-document-id]"); + bindTableSearch(sourceSearch, sourceTable, "tbody tr[data-source-name]"); +})(); diff --git a/templates/attachment_manager.html b/templates/attachment_manager.html index 72e55dc..5b7fd7c 100644 --- a/templates/attachment_manager.html +++ b/templates/attachment_manager.html @@ -10,8 +10,8 @@ diff --git a/templates/knowledge_base.html b/templates/knowledge_base.html new file mode 100644 index 0000000..efc87cd --- /dev/null +++ b/templates/knowledge_base.html @@ -0,0 +1,213 @@ +{% extends "base.html" %} +{% load static %} + +{% block title %}知识库管理 - DEMO-AGENT V2{% endblock %} +{% block body_class %}app-body{% endblock %} + +{% block content %} +
+
+ +
+
+ +
+
+
+ +
+
+
+

知识库管理

+

知识库管理

+

管理当前账号所有对话可调用的法规、制度、模板和审查依据。

+
+
+ {{ knowledge_base.status.label }} + 返回对话 +
+
+ +
+ + +
+
+
+

知识库材料列表

+ +
+
+ + + + + + + + + + + + + + {% for document in knowledge_base.managed_documents %} + + + + + + + + + + {% empty %} + + + + {% endfor %} + +
状态材料名称文件名大小入库状态更新时间操作
{% if document.is_active %}启用{% else %}停用{% endif %}{{ document.display_name }}{{ document.original_name }}{{ document.file_size }} bytes{{ document.indexed_label }}{{ document.updated_at|slice:":19" }} + + + + +
当前知识库暂无材料
+
+
+ +
+
+

内置法规材料

+ +
+
+ + + + + + + + + + + + {% for source in knowledge_base.sources %} + + + + + + + + {% empty %} + + + + {% endfor %} + +
状态文件类型大小索引
{% if source.supported %}可解析{% else %}暂不支持{% endif %}{{ source.relative_path }}{{ source.suffix }}{{ source.size }} bytes{{ source.indexed_label }}
暂无法规材料
+
+
+
+
+
+
+{% endblock %} + +{% block scripts %} + +{% endblock %} diff --git a/tests/test_knowledge_base.py b/tests/test_knowledge_base.py new file mode 100644 index 0000000..936d822 --- /dev/null +++ b/tests/test_knowledge_base.py @@ -0,0 +1,220 @@ +import pytest +from django.core.files.uploadedfile import SimpleUploadedFile +from django.urls import reverse + +from review_agent.knowledge_base import build_knowledge_base_context, delete_document, search_knowledge_base +from review_agent.models import KnowledgeBaseDocument + + +pytestmark = pytest.mark.django_db + + +def test_knowledge_base_context_reports_rule_and_sources(): + context = build_knowledge_base_context() + + assert context["rule"]["code"] == "nmpa_ivd_registration_v1" + assert context["rule"]["requirement_count"] > 0 + assert context["source_count"] > 0 + assert context["collection_name"] == "nmpa_ivd_registration_v1" + + +def test_knowledge_base_page_requires_login(client): + response = client.get(reverse("knowledge_base_manager")) + + assert response.status_code == 302 + + +def test_knowledge_base_page_renders_for_user(client, django_user_model): + user = django_user_model.objects.create_user(username="owner", password="pass") + client.force_login(user) + + response = client.get(reverse("knowledge_base_manager")) + + assert response.status_code == 200 + assert "知识库管理" in response.content.decode("utf-8") + assert "RAG 检索测试" in response.content.decode("utf-8") + content = response.content.decode("utf-8") + tabbar = content[content.index('
", content.index('
0 + + list_response = client.get(reverse("knowledge_base_document_list")) + assert list_response.status_code == 200 + assert list_response.json()["documents"][0]["display_name"] == "注册检验报告要求" + + detail_response = client.get(reverse("knowledge_base_document_detail", args=[document_id])) + assert detail_response.status_code == 200 + assert detail_response.json()["document"]["original_name"] == "report.md" + assert "已入库" in detail_response.json()["document"]["indexed_label"] + + patch_response = client.patch( + reverse("knowledge_base_document_detail", args=[document_id]), + data='{"display_name": "更新后的法规材料", "is_active": false}', + content_type="application/json", + ) + + assert patch_response.status_code == 200 + assert patch_response.json()["document"]["display_name"] == "更新后的法规材料" + assert patch_response.json()["document"]["is_active"] is False + + delete_response = client.delete(reverse("knowledge_base_document_detail", args=[document_id])) + + assert delete_response.status_code == 200 + assert KnowledgeBaseDocument.objects.get(pk=document_id).status == KnowledgeBaseDocument.Status.DELETED + + +def test_delete_document_removes_managed_chunks_from_index(monkeypatch, django_user_model): + user = django_user_model.objects.create_user(username="owner", password="pass") + document = KnowledgeBaseDocument.objects.create( + user=user, + display_name="孙之烨简历", + original_name="孙之烨-260510.pdf", + storage_path="knowledge_base/resume.pdf", + file_size=1, + indexed_chunk_count=7, + metadata={"index_status": "indexed", "index_error": ""}, + ) + deleted_filters = [] + + class FakeCollection: + def delete(self, where): + deleted_filters.append(where) + + monkeypatch.setattr("review_agent.knowledge_base._load_chroma_collection", lambda: FakeCollection()) + + delete_document(document) + + document.refresh_from_db() + assert document.status == KnowledgeBaseDocument.Status.DELETED + assert document.is_active is False + assert document.indexed_chunk_count == 0 + assert document.metadata["index_status"] == "deleted" + assert deleted_filters == [{"document_id": document.pk}] + + +def test_knowledge_base_document_api_is_scoped_to_owner(client, django_user_model): + owner = django_user_model.objects.create_user(username="owner", password="pass") + other = django_user_model.objects.create_user(username="other", password="pass") + document = KnowledgeBaseDocument.objects.create( + user=owner, + display_name="法规材料", + original_name="a.md", + storage_path="knowledge_base/a.md", + file_size=1, + ) + client.force_login(other) + + response = client.patch( + reverse("knowledge_base_document_detail", args=[document.pk]), + data='{"display_name": "越权修改"}', + content_type="application/json", + ) + + assert response.status_code == 404 + + +def test_knowledge_base_document_manual_index_api(client, settings, tmp_path, django_user_model): + settings.MEDIA_ROOT = tmp_path + user = django_user_model.objects.create_user(username="owner", password="pass") + client.force_login(user) + source_path = tmp_path / "manual.md" + source_path.write_text("# manual\n注册检验报告要求", encoding="utf-8") + document = KnowledgeBaseDocument.objects.create( + user=user, + display_name="manual.md", + original_name="manual.md", + storage_path=str(source_path), + file_size=source_path.stat().st_size, + indexed_chunk_count=0, + ) + + response = client.post(reverse("knowledge_base_document_index", args=[document.pk])) + + assert response.status_code == 200 + document.refresh_from_db() + assert document.indexed_chunk_count > 0 + assert "已入库" in response.json()["document"]["indexed_label"]