feat(regulatory): 增加本地法规RAG索引检索
This commit is contained in:
@@ -105,6 +105,23 @@ LLM_API_KEY = os.environ.get("LLM_API_KEY", "")
|
|||||||
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.siliconflow.cn/v1")
|
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.siliconflow.cn/v1")
|
||||||
LLM_MODEL = os.environ.get("LLM_MODEL", "")
|
LLM_MODEL = os.environ.get("LLM_MODEL", "")
|
||||||
|
|
||||||
|
REGULATORY_RAG_PROVIDER = os.environ.get("REGULATORY_RAG_PROVIDER", "siliconflow")
|
||||||
|
REGULATORY_RAG_CHROMA_PATH = os.environ.get(
|
||||||
|
"REGULATORY_RAG_CHROMA_PATH",
|
||||||
|
str(MEDIA_ROOT / "regulatory_review" / "rag" / "chroma"),
|
||||||
|
)
|
||||||
|
REGULATORY_RAG_COLLECTION = os.environ.get(
|
||||||
|
"REGULATORY_RAG_COLLECTION",
|
||||||
|
"nmpa_ivd_registration_v1",
|
||||||
|
)
|
||||||
|
SILICONFLOW_BASE_URL = os.environ.get("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
|
||||||
|
SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
|
||||||
|
SILICONFLOW_EMBEDDING_MODEL = os.environ.get(
|
||||||
|
"SILICONFLOW_EMBEDDING_MODEL",
|
||||||
|
"Qwen/Qwen3-Embedding-4B",
|
||||||
|
)
|
||||||
|
SILICONFLOW_EMBEDDING_DIMENSIONS = int(os.environ.get("SILICONFLOW_EMBEDDING_DIMENSIONS", "1024"))
|
||||||
|
|
||||||
LOGGING = {
|
LOGGING = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"disable_existing_loggers": False,
|
"disable_existing_loggers": False,
|
||||||
|
|||||||
@@ -8,3 +8,5 @@ olefile>=0.47
|
|||||||
py7zr>=0.21
|
py7zr>=0.21
|
||||||
playwright>=1.60
|
playwright>=1.60
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
chromadb>=0.5
|
||||||
|
httpx>=0.27
|
||||||
|
|||||||
33
review_agent/management/commands/regulatory_rag_build.py
Normal file
33
review_agent/management/commands/regulatory_rag_build.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.management.base import BaseCommand, CommandError
|
||||||
|
|
||||||
|
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
|
||||||
|
from review_agent.regulatory_review.services.rag_index import build_chroma_index
|
||||||
|
from review_agent.regulatory_review.services.rule_loader import load_rule_file
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
help = "构建 NMPA 法规材料本地 ChromaDB RAG 索引。"
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
parser.add_argument("--provider", default=None, help="覆盖 REGULATORY_RAG_PROVIDER。")
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
rule_set = load_rule_file()
|
||||||
|
source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
|
||||||
|
if not source_dir.exists():
|
||||||
|
raise CommandError(f"法规材料目录不存在:{source_dir}")
|
||||||
|
try:
|
||||||
|
provider = get_embedding_provider(options["provider"])
|
||||||
|
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
|
||||||
|
except Exception as exc:
|
||||||
|
raise CommandError(str(exc)) from exc
|
||||||
|
self.stdout.write(
|
||||||
|
self.style.SUCCESS(
|
||||||
|
f"已构建法规 RAG 索引:collection={settings.REGULATORY_RAG_COLLECTION}, chunks={count}"
|
||||||
|
)
|
||||||
|
)
|
||||||
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
from .rag_embedding import EmbeddingFunction, get_embedding_provider
|
||||||
|
|
||||||
|
|
||||||
|
class RagIndexUnavailable(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_citations(
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
embedding_provider: EmbeddingFunction | None = None,
|
||||||
|
collection=None,
|
||||||
|
n_results: int = 3,
|
||||||
|
) -> list[dict[str, object]]:
|
||||||
|
provider = embedding_provider or get_embedding_provider()
|
||||||
|
if collection is None:
|
||||||
|
collection = _load_collection()
|
||||||
|
embeddings = provider([query])
|
||||||
|
result = collection.query(query_embeddings=embeddings, n_results=n_results)
|
||||||
|
documents = (result.get("documents") or [[]])[0]
|
||||||
|
metadatas = (result.get("metadatas") or [[]])[0]
|
||||||
|
distances = (result.get("distances") or [[]])[0]
|
||||||
|
if not documents:
|
||||||
|
return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]
|
||||||
|
citations = []
|
||||||
|
for index, document in enumerate(documents):
|
||||||
|
metadata = metadatas[index] if index < len(metadatas) else {}
|
||||||
|
distance = distances[index] if index < len(distances) else None
|
||||||
|
citations.append(
|
||||||
|
{
|
||||||
|
"source": metadata.get("source", "法规材料"),
|
||||||
|
"text": document,
|
||||||
|
"score": distance,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return citations
|
||||||
|
|
||||||
|
|
||||||
|
def _load_collection():
|
||||||
|
persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
|
||||||
|
if not persist_path.exists():
|
||||||
|
raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from exc
|
||||||
|
client = chromadb.PersistentClient(path=str(persist_path))
|
||||||
|
try:
|
||||||
|
return client.get_collection(settings.REGULATORY_RAG_COLLECTION)
|
||||||
|
except Exception as exc:
|
||||||
|
raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from exc
|
||||||
82
review_agent/regulatory_review/services/rag_embedding.py
Normal file
82
review_agent/regulatory_review/services/rag_embedding.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
|
from typing import Callable, Iterable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
|
||||||
|
EmbeddingFunction = Callable[[list[str]], list[list[float]]]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfigurationError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowEmbeddingProvider:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
dimensions: int,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
):
|
||||||
|
if not api_key:
|
||||||
|
raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.model = model
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def embed(self, texts: Iterable[str]) -> list[list[float]]:
|
||||||
|
inputs = list(texts)
|
||||||
|
response = httpx.post(
|
||||||
|
f"{self.base_url}/embeddings",
|
||||||
|
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||||
|
json={
|
||||||
|
"model": self.model,
|
||||||
|
"input": inputs,
|
||||||
|
"dimensions": self.dimensions,
|
||||||
|
},
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
payload = response.json()
|
||||||
|
return [item["embedding"] for item in payload.get("data", [])]
|
||||||
|
|
||||||
|
def __call__(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
return self.embed(texts)
|
||||||
|
|
||||||
|
|
||||||
|
class DeterministicEmbeddingProvider:
|
||||||
|
"""Small local embedding substitute for tests and explicit dry runs."""
|
||||||
|
|
||||||
|
def __init__(self, dimensions: int = 16):
|
||||||
|
self.dimensions = dimensions
|
||||||
|
|
||||||
|
def __call__(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
vectors = []
|
||||||
|
for text in texts:
|
||||||
|
seed = int(hashlib.sha256(text.encode("utf-8")).hexdigest()[:16], 16)
|
||||||
|
rng = random.Random(seed)
|
||||||
|
vectors.append([rng.uniform(-1, 1) for _ in range(self.dimensions)])
|
||||||
|
return vectors
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_provider(provider_name: str | None = None) -> EmbeddingFunction:
|
||||||
|
provider = provider_name or settings.REGULATORY_RAG_PROVIDER
|
||||||
|
if provider == "siliconflow":
|
||||||
|
return SiliconFlowEmbeddingProvider(
|
||||||
|
api_key=settings.SILICONFLOW_API_KEY,
|
||||||
|
base_url=settings.SILICONFLOW_BASE_URL,
|
||||||
|
model=settings.SILICONFLOW_EMBEDDING_MODEL,
|
||||||
|
dimensions=settings.SILICONFLOW_EMBEDDING_DIMENSIONS,
|
||||||
|
)
|
||||||
|
if provider in {"deterministic", "local"}:
|
||||||
|
return DeterministicEmbeddingProvider()
|
||||||
|
raise EmbeddingConfigurationError(f"不支持的 embedding provider:{provider}")
|
||||||
148
review_agent/regulatory_review/services/rag_index.py
Normal file
148
review_agent/regulatory_review/services/rag_index.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from docx import Document
|
||||||
|
from openpyxl import load_workbook
|
||||||
|
from pypdf import PdfReader
|
||||||
|
from pptx import Presentation
|
||||||
|
|
||||||
|
from .rag_embedding import EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TextChunk:
|
||||||
|
text: str
|
||||||
|
metadata: dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
|
||||||
|
normalized = "\n".join(line.strip() for line in text.splitlines() if line.strip())
|
||||||
|
if not normalized:
|
||||||
|
return []
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
index = 0
|
||||||
|
step = max(1, chunk_size - overlap)
|
||||||
|
while start < len(normalized):
|
||||||
|
part = normalized[start : start + chunk_size].strip()
|
||||||
|
if part:
|
||||||
|
chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": index}))
|
||||||
|
index += 1
|
||||||
|
start += step
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_from_path(path: Path) -> str:
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix in {".txt", ".md"}:
|
||||||
|
return path.read_text(encoding="utf-8", errors="ignore")
|
||||||
|
if suffix == ".pdf":
|
||||||
|
return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages)
|
||||||
|
if suffix == ".docx":
|
||||||
|
return "\n".join(paragraph.text for paragraph in Document(str(path)).paragraphs)
|
||||||
|
if suffix == ".pptx":
|
||||||
|
presentation = Presentation(str(path))
|
||||||
|
lines = []
|
||||||
|
for slide in presentation.slides:
|
||||||
|
for shape in slide.shapes:
|
||||||
|
if hasattr(shape, "text"):
|
||||||
|
lines.append(shape.text)
|
||||||
|
return "\n".join(lines)
|
||||||
|
if suffix == ".xlsx":
|
||||||
|
workbook = load_workbook(path, data_only=True, read_only=True)
|
||||||
|
lines = []
|
||||||
|
for sheet in workbook.worksheets:
|
||||||
|
for row in sheet.iter_rows(values_only=True):
|
||||||
|
values = [str(cell) for cell in row if cell not in {None, ""}]
|
||||||
|
if values:
|
||||||
|
lines.append("\t".join(values))
|
||||||
|
return "\n".join(lines)
|
||||||
|
if suffix == ".doc":
|
||||||
|
return _extract_legacy_doc_with_libreoffice(path)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_legacy_doc_with_libreoffice(path: Path) -> str:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
target_dir = Path(tmp_dir)
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"soffice",
|
||||||
|
"--headless",
|
||||||
|
"--convert-to",
|
||||||
|
"docx",
|
||||||
|
"--outdir",
|
||||||
|
str(target_dir),
|
||||||
|
str(path),
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc:
|
||||||
|
raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc
|
||||||
|
converted = target_dir / f"{path.stem}.docx"
|
||||||
|
if not converted.exists():
|
||||||
|
raise RuntimeError(f"LibreOffice 未生成 docx:{path.name}")
|
||||||
|
return extract_text_from_path(converted)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||||
|
chunks: list[TextChunk] = []
|
||||||
|
for path in sorted(source_dir.rglob("*")):
|
||||||
|
if not path.is_file():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
text = extract_text_from_path(path)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.warning("Regulatory source extraction skipped", extra={"path": str(path), "error": str(exc)})
|
||||||
|
continue
|
||||||
|
chunks.extend(chunk_text(text, source=str(path.relative_to(source_dir))))
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_chroma_index(
|
||||||
|
*,
|
||||||
|
source_dir: Path,
|
||||||
|
embedding_provider: EmbeddingFunction,
|
||||||
|
persist_path: Path | None = None,
|
||||||
|
collection_name: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc
|
||||||
|
|
||||||
|
persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH)
|
||||||
|
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
||||||
|
persist_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
chunks = collect_source_chunks(source_dir)
|
||||||
|
client = chromadb.PersistentClient(path=str(persist_path))
|
||||||
|
collection = client.get_or_create_collection(collection_name)
|
||||||
|
if not chunks:
|
||||||
|
return 0
|
||||||
|
texts = [chunk.text for chunk in chunks]
|
||||||
|
embeddings = embedding_provider(texts)
|
||||||
|
ids = [
|
||||||
|
hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
|
||||||
|
for chunk in chunks
|
||||||
|
]
|
||||||
|
collection.upsert(
|
||||||
|
ids=ids,
|
||||||
|
documents=texts,
|
||||||
|
metadatas=[chunk.metadata for chunk in chunks],
|
||||||
|
embeddings=embeddings,
|
||||||
|
)
|
||||||
|
return len(chunks)
|
||||||
72
tests/test_regulatory_rag.py
Normal file
72
tests/test_regulatory_rag.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from review_agent.regulatory_review.services.rag_citation import (
|
||||||
|
RagIndexUnavailable,
|
||||||
|
retrieve_citations,
|
||||||
|
)
|
||||||
|
from review_agent.regulatory_review.services.rag_embedding import SiliconFlowEmbeddingProvider
|
||||||
|
from review_agent.regulatory_review.services.rag_index import chunk_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_siliconflow_embedding_provider_posts_expected_payload(monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
def raise_for_status(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}
|
||||||
|
|
||||||
|
def fake_post(url, headers, json, timeout):
|
||||||
|
calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
|
||||||
|
return FakeResponse()
|
||||||
|
|
||||||
|
monkeypatch.setattr("review_agent.regulatory_review.services.rag_embedding.httpx.post", fake_post)
|
||||||
|
|
||||||
|
provider = SiliconFlowEmbeddingProvider(
|
||||||
|
api_key="secret",
|
||||||
|
base_url="https://api.siliconflow.cn/v1",
|
||||||
|
model="Qwen/Qwen3-Embedding-4B",
|
||||||
|
dimensions=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.embed(["法规依据", "注册检验报告"]) == [[0.1, 0.2], [0.3, 0.4]]
|
||||||
|
assert calls[0]["url"] == "https://api.siliconflow.cn/v1/embeddings"
|
||||||
|
assert calls[0]["headers"]["Authorization"] == "Bearer secret"
|
||||||
|
assert calls[0]["json"]["model"] == "Qwen/Qwen3-Embedding-4B"
|
||||||
|
assert calls[0]["json"]["dimensions"] == 1024
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_text_preserves_source_metadata():
|
||||||
|
chunks = chunk_text(
|
||||||
|
"第一段法规内容。\n" * 20,
|
||||||
|
source="法规.doc",
|
||||||
|
chunk_size=30,
|
||||||
|
overlap=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 1
|
||||||
|
assert chunks[0].metadata["source"] == "法规.doc"
|
||||||
|
assert chunks[0].text
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_citations_returns_placeholder_when_no_hits():
|
||||||
|
class EmptyCollection:
|
||||||
|
def query(self, query_embeddings, n_results):
|
||||||
|
return {"documents": [[]], "metadatas": [[]], "distances": [[]]}
|
||||||
|
|
||||||
|
citations = retrieve_citations(
|
||||||
|
"注册检验报告",
|
||||||
|
embedding_provider=lambda texts: [[0.1, 0.2]],
|
||||||
|
collection=EmptyCollection(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert citations[0]["source"] == "原文依据待补充"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_citations_raises_when_index_missing(settings, tmp_path):
|
||||||
|
settings.REGULATORY_RAG_CHROMA_PATH = tmp_path / "missing"
|
||||||
|
|
||||||
|
with pytest.raises(RagIndexUnavailable):
|
||||||
|
retrieve_citations("注册检验报告", embedding_provider=lambda texts: [[0.1]])
|
||||||
Reference in New Issue
Block a user