diff --git a/config/settings.py b/config/settings.py index a2260fa..b8dfc9d 100644 --- a/config/settings.py +++ b/config/settings.py @@ -105,6 +105,23 @@ LLM_API_KEY = os.environ.get("LLM_API_KEY", "") LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.siliconflow.cn/v1") LLM_MODEL = os.environ.get("LLM_MODEL", "") +REGULATORY_RAG_PROVIDER = os.environ.get("REGULATORY_RAG_PROVIDER", "siliconflow") +REGULATORY_RAG_CHROMA_PATH = os.environ.get( + "REGULATORY_RAG_CHROMA_PATH", + str(MEDIA_ROOT / "regulatory_review" / "rag" / "chroma"), +) +REGULATORY_RAG_COLLECTION = os.environ.get( + "REGULATORY_RAG_COLLECTION", + "nmpa_ivd_registration_v1", +) +SILICONFLOW_BASE_URL = os.environ.get("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1") +SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "") +SILICONFLOW_EMBEDDING_MODEL = os.environ.get( + "SILICONFLOW_EMBEDDING_MODEL", + "Qwen/Qwen3-Embedding-4B", +) +SILICONFLOW_EMBEDDING_DIMENSIONS = int(os.environ.get("SILICONFLOW_EMBEDDING_DIMENSIONS", "1024")) + LOGGING = { "version": 1, "disable_existing_loggers": False, diff --git a/requirements.txt b/requirements.txt index a04423d..0c4aaa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ olefile>=0.47 py7zr>=0.21 playwright>=1.60 PyYAML>=6.0 +chromadb>=0.5 +httpx>=0.27 diff --git a/review_agent/management/commands/regulatory_rag_build.py b/review_agent/management/commands/regulatory_rag_build.py new file mode 100644 index 0000000..b8be556 --- /dev/null +++ b/review_agent/management/commands/regulatory_rag_build.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from pathlib import Path + +from django.conf import settings +from django.core.management.base import BaseCommand, CommandError + +from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider +from review_agent.regulatory_review.services.rag_index import build_chroma_index +from review_agent.regulatory_review.services.rule_loader import 
class Command(BaseCommand):
    """Build the local ChromaDB RAG index from the regulatory source material.

    The source directory is ``settings.BASE_DIR`` joined with the
    ``source_material_dir`` entry of the rule set returned by
    ``load_rule_file()``.
    """

    help = "构建 NMPA 法规材料本地 ChromaDB RAG 索引。"

    def add_arguments(self, parser):
        """Register ``--provider``; when omitted, the provider factory falls back to settings."""
        parser.add_argument("--provider", default=None, help="覆盖 REGULATORY_RAG_PROVIDER。")

    def handle(self, *args, **options):
        """Resolve the source directory, build the index and print a summary.

        Raises:
            CommandError: if the source directory is missing (or is not a
                directory), or if provider construction / index building
                fails for any reason.
        """
        rule_set = load_rule_file()
        source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
        # is_dir() rather than exists(): a regular file at this path would pass
        # an existence check, produce zero chunks and still report success.
        if not source_dir.is_dir():
            raise CommandError(f"法规材料目录不存在:{source_dir}")
        try:
            provider = get_embedding_provider(options["provider"])
            count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
        except Exception as exc:
            # Command boundary: report any failure as a clean CLI error while
            # keeping the original exception chained for debugging.
            raise CommandError(str(exc)) from exc
        summary = f"已构建法规 RAG 索引:collection={settings.REGULATORY_RAG_COLLECTION}, chunks={count}"
        # An empty index is almost certainly not what the operator wanted, so
        # the same summary is highlighted as a warning instead of a success.
        paint = self.style.SUCCESS if count else self.style.WARNING
        self.stdout.write(paint(summary))
class SiliconFlowEmbeddingProvider:
    """Embedding provider that POSTs texts to ``{base_url}/embeddings``.

    Instances are callable, so they satisfy the ``EmbeddingFunction`` contract
    (``list[str] -> list[list[float]]``).
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        dimensions: int,
        timeout: float = 60.0,
    ):
        """Store the connection settings.

        Raises:
            EmbeddingConfigurationError: if ``api_key`` is empty.
        """
        if not api_key:
            raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
        self.api_key = api_key
        # Trailing slashes are dropped so the endpoint URL never contains "//".
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.dimensions = dimensions
        self.timeout = timeout

    def embed(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding per input text, in input order.

        An empty input returns ``[]`` without issuing a request.

        Raises:
            httpx.HTTPStatusError: if the service answers with an error status.
            RuntimeError: if the number of returned embeddings does not match
                the number of inputs.
        """
        inputs = list(texts)
        if not inputs:
            # Nothing to embed; avoid a pointless (and likely rejected) request.
            return []
        response = httpx.post(
            f"{self.base_url}/embeddings",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "input": inputs,
                "dimensions": self.dimensions,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_embeddings(response.json(), expected=len(inputs))

    @staticmethod
    def _parse_embeddings(payload: dict, expected: int) -> list[list[float]]:
        """Extract the vectors from a response payload, aligned with the inputs.

        Items are ordered by their ``index`` field when present (OpenAI-style
        responses carry one); items without it keep their response position.
        """
        items = payload.get("data") or []
        if len(items) != expected:
            # A silent mismatch would pair vectors with the wrong documents.
            raise RuntimeError(
                f"embedding 返回数量({len(items)})与输入数量({expected})不一致。"
            )
        ordered = sorted(
            enumerate(items),
            key=lambda pair: pair[1].get("index", pair[0]),
        )
        return [item["embedding"] for _, item in ordered]

    def __call__(self, texts: list[str]) -> list[list[float]]:
        """Alias for :meth:`embed` so the instance can be used as a function."""
        return self.embed(texts)
@dataclass(frozen=True)
class TextChunk:
    """One chunk of extracted text plus the metadata stored alongside it."""

    # Chunk content after whitespace normalization.
    text: str
    # chunk_text() fills this with the "source" and "chunk_index" keys.
    metadata: dict[str, object]


def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
    """Split ``text`` into overlapping character windows.

    Each line is stripped and blank lines are dropped before chunking.
    Consecutive windows start ``chunk_size - overlap`` characters apart (at
    least 1). ``chunk_index`` numbers the emitted chunks from 0.

    Args:
        text: Raw extracted text.
        source: Value stored under the ``"source"`` metadata key.
        chunk_size: Maximum number of characters per chunk; must be >= 1.
        overlap: Characters shared by neighbouring chunks; must be >= 0.

    Returns:
        The chunks in document order; ``[]`` when the text has no content.

    Raises:
        ValueError: if ``chunk_size`` < 1 or ``overlap`` < 0.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size 必须为正整数:{chunk_size}")
    if overlap < 0:
        # A negative overlap would make the stride exceed the window and
        # silently skip text between chunks.
        raise ValueError(f"overlap 不能为负数:{overlap}")

    stripped_lines = (line.strip() for line in text.splitlines())
    normalized = "\n".join(line for line in stripped_lines if line)
    if not normalized:
        return []

    total = len(normalized)
    step = max(1, chunk_size - overlap)
    chunks: list[TextChunk] = []
    for start in range(0, total, step):
        part = normalized[start : start + chunk_size].strip()
        if part:
            chunks.append(
                TextChunk(text=part, metadata={"source": source, "chunk_index": len(chunks)})
            )
        # Once a window reaches the end of the text, every later window would
        # be a pure suffix of it — stop instead of indexing duplicate content.
        if start + chunk_size >= total:
            break
    return chunks
def build_chroma_index(
    *,
    source_dir: Path,
    embedding_provider: EmbeddingFunction,
    persist_path: Path | None = None,
    collection_name: str | None = None,
    batch_size: int = 32,
) -> int:
    """Embed every chunk found under ``source_dir`` and upsert it into Chroma.

    Args:
        source_dir: Directory scanned by ``collect_source_chunks``.
        embedding_provider: Callable mapping a list of texts to their vectors.
        persist_path: Chroma storage directory; defaults to
            ``settings.REGULATORY_RAG_CHROMA_PATH``.
        collection_name: Target collection; defaults to
            ``settings.REGULATORY_RAG_COLLECTION``.
        batch_size: Number of chunks per embedding call and per upsert.
            NOTE(review): 32 is a conservative guess — confirm against the
            embedding service's per-request input limit.

    Returns:
        The number of chunks written (0 when no text could be extracted; the
        collection is still created in that case).

    Raises:
        ValueError: if ``batch_size`` < 1.
        RuntimeError: if chromadb is not installed, or the provider returns a
            different number of vectors than texts.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size 必须为正整数:{batch_size}")
    try:
        import chromadb
    except ImportError as exc:
        raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc

    persist_path = Path(persist_path or settings.REGULATORY_RAG_CHROMA_PATH)
    collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
    persist_path.mkdir(parents=True, exist_ok=True)
    chunks = collect_source_chunks(source_dir)
    client = chromadb.PersistentClient(path=str(persist_path))
    collection = client.get_or_create_collection(collection_name)
    if not chunks:
        return 0

    texts = [chunk.text for chunk in chunks]
    metadatas = [chunk.metadata for chunk in chunks]
    # Ids derive from source + chunk index, so a rebuild overwrites the same
    # records. NOTE(review): chunks left over from files that were removed or
    # became shorter are never deleted by this function.
    ids = [
        hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
        for chunk in chunks
    ]

    # Embed everything first (in bounded batches) so that a provider failure
    # leaves the existing collection untouched.
    embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        vectors = embedding_provider(batch)
        if len(vectors) != len(batch):
            raise RuntimeError(
                f"embedding 返回数量({len(vectors)})与文本数量({len(batch)})不一致。"
            )
        embeddings.extend(vectors)

    for start in range(0, len(chunks), batch_size):
        stop = start + batch_size
        collection.upsert(
            ids=ids[start:stop],
            documents=texts[start:stop],
            metadatas=metadatas[start:stop],
            embeddings=embeddings[start:stop],
        )
    return len(chunks)
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
    """The provider sends exactly one request carrying the texts, model and dimensions."""
    calls = []

    class FakeResponse:
        def raise_for_status(self):
            return None

        def json(self):
            return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}

    def fake_post(url, headers, json, timeout):
        calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
        return FakeResponse()

    monkeypatch.setattr("review_agent.regulatory_review.services.rag_embedding.httpx.post", fake_post)

    provider = SiliconFlowEmbeddingProvider(
        api_key="secret",
        base_url="https://api.siliconflow.cn/v1",
        model="Qwen/Qwen3-Embedding-4B",
        dimensions=1024,
    )
    texts = ["法规依据", "注册检验报告"]

    assert provider.embed(texts) == [[0.1, 0.2], [0.3, 0.4]]
    assert len(calls) == 1
    request = calls[0]
    assert request["url"] == "https://api.siliconflow.cn/v1/embeddings"
    assert request["headers"]["Authorization"] == "Bearer secret"
    assert request["json"]["model"] == "Qwen/Qwen3-Embedding-4B"
    # The texts themselves are the point of the request; pin them too.
    assert request["json"]["input"] == texts
    assert request["json"]["dimensions"] == 1024
    # No timeout was passed to the constructor, so the default must be used.
    assert request["timeout"] == 60.0
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
    """Pointing the Chroma path at a non-existent directory raises RagIndexUnavailable."""
    missing_index_dir = tmp_path / "missing"
    settings.REGULATORY_RAG_CHROMA_PATH = missing_index_dir

    def constant_embedding(texts):
        return [[0.1]]

    with pytest.raises(RagIndexUnavailable):
        retrieve_citations("注册检验报告", embedding_provider=constant_embedding)