feat(regulatory): 增加本地法规RAG索引检索

This commit is contained in:
2026-06-07 00:30:53 +08:00
parent 2a4dd6cfab
commit 26490f7c46
7 changed files with 411 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
from review_agent.regulatory_review.services.rag_index import build_chroma_index
from review_agent.regulatory_review.services.rule_loader import load_rule_file
class Command(BaseCommand):
help = "构建 NMPA 法规材料本地 ChromaDB RAG 索引。"
def add_arguments(self, parser):
parser.add_argument("--provider", default=None, help="覆盖 REGULATORY_RAG_PROVIDER。")
def handle(self, *args, **options):
rule_set = load_rule_file()
source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
if not source_dir.exists():
raise CommandError(f"法规材料目录不存在:{source_dir}")
try:
provider = get_embedding_provider(options["provider"])
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
except Exception as exc:
raise CommandError(str(exc)) from exc
self.stdout.write(
self.style.SUCCESS(
f"已构建法规 RAG 索引collection={settings.REGULATORY_RAG_COLLECTION}, chunks={count}"
)
)

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from pathlib import Path
from django.conf import settings
from .rag_embedding import EmbeddingFunction, get_embedding_provider
class RagIndexUnavailable(RuntimeError):
pass
def retrieve_citations(
query: str,
*,
embedding_provider: EmbeddingFunction | None = None,
collection=None,
n_results: int = 3,
) -> list[dict[str, object]]:
provider = embedding_provider or get_embedding_provider()
if collection is None:
collection = _load_collection()
embeddings = provider([query])
result = collection.query(query_embeddings=embeddings, n_results=n_results)
documents = (result.get("documents") or [[]])[0]
metadatas = (result.get("metadatas") or [[]])[0]
distances = (result.get("distances") or [[]])[0]
if not documents:
return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]
citations = []
for index, document in enumerate(documents):
metadata = metadatas[index] if index < len(metadatas) else {}
distance = distances[index] if index < len(distances) else None
citations.append(
{
"source": metadata.get("source", "法规材料"),
"text": document,
"score": distance,
}
)
return citations
def _load_collection():
persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
if not persist_path.exists():
raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")
try:
import chromadb
except ImportError as exc:
raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from exc
client = chromadb.PersistentClient(path=str(persist_path))
try:
return client.get_collection(settings.REGULATORY_RAG_COLLECTION)
except Exception as exc:
raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from exc

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import hashlib
import random
from typing import Callable, Iterable
import httpx
from django.conf import settings
EmbeddingFunction = Callable[[list[str]], list[list[float]]]
class EmbeddingConfigurationError(RuntimeError):
pass
class SiliconFlowEmbeddingProvider:
def __init__(
self,
*,
api_key: str,
base_url: str,
model: str,
dimensions: int,
timeout: float = 60.0,
):
if not api_key:
raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.model = model
self.dimensions = dimensions
self.timeout = timeout
def embed(self, texts: Iterable[str]) -> list[list[float]]:
inputs = list(texts)
response = httpx.post(
f"{self.base_url}/embeddings",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"input": inputs,
"dimensions": self.dimensions,
},
timeout=self.timeout,
)
response.raise_for_status()
payload = response.json()
return [item["embedding"] for item in payload.get("data", [])]
def __call__(self, texts: list[str]) -> list[list[float]]:
return self.embed(texts)
class DeterministicEmbeddingProvider:
"""Small local embedding substitute for tests and explicit dry runs."""
def __init__(self, dimensions: int = 16):
self.dimensions = dimensions
def __call__(self, texts: list[str]) -> list[list[float]]:
vectors = []
for text in texts:
seed = int(hashlib.sha256(text.encode("utf-8")).hexdigest()[:16], 16)
rng = random.Random(seed)
vectors.append([rng.uniform(-1, 1) for _ in range(self.dimensions)])
return vectors
def get_embedding_provider(provider_name: str | None = None) -> EmbeddingFunction:
provider = provider_name or settings.REGULATORY_RAG_PROVIDER
if provider == "siliconflow":
return SiliconFlowEmbeddingProvider(
api_key=settings.SILICONFLOW_API_KEY,
base_url=settings.SILICONFLOW_BASE_URL,
model=settings.SILICONFLOW_EMBEDDING_MODEL,
dimensions=settings.SILICONFLOW_EMBEDDING_DIMENSIONS,
)
if provider in {"deterministic", "local"}:
return DeterministicEmbeddingProvider()
raise EmbeddingConfigurationError(f"不支持的 embedding provider{provider}")

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
import hashlib
import logging
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from django.conf import settings
from docx import Document
from openpyxl import load_workbook
from pypdf import PdfReader
from pptx import Presentation
from .rag_embedding import EmbeddingFunction
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
@dataclass(frozen=True)
class TextChunk:
text: str
metadata: dict[str, object]
def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
normalized = "\n".join(line.strip() for line in text.splitlines() if line.strip())
if not normalized:
return []
chunks = []
start = 0
index = 0
step = max(1, chunk_size - overlap)
while start < len(normalized):
part = normalized[start : start + chunk_size].strip()
if part:
chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": index}))
index += 1
start += step
return chunks
def extract_text_from_path(path: Path) -> str:
suffix = path.suffix.lower()
if suffix in {".txt", ".md"}:
return path.read_text(encoding="utf-8", errors="ignore")
if suffix == ".pdf":
return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages)
if suffix == ".docx":
return "\n".join(paragraph.text for paragraph in Document(str(path)).paragraphs)
if suffix == ".pptx":
presentation = Presentation(str(path))
lines = []
for slide in presentation.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
lines.append(shape.text)
return "\n".join(lines)
if suffix == ".xlsx":
workbook = load_workbook(path, data_only=True, read_only=True)
lines = []
for sheet in workbook.worksheets:
for row in sheet.iter_rows(values_only=True):
values = [str(cell) for cell in row if cell not in {None, ""}]
if values:
lines.append("\t".join(values))
return "\n".join(lines)
if suffix == ".doc":
return _extract_legacy_doc_with_libreoffice(path)
return ""
def _extract_legacy_doc_with_libreoffice(path: Path) -> str:
with tempfile.TemporaryDirectory() as tmp_dir:
target_dir = Path(tmp_dir)
try:
subprocess.run(
[
"soffice",
"--headless",
"--convert-to",
"docx",
"--outdir",
str(target_dir),
str(path),
],
check=True,
capture_output=True,
text=True,
timeout=60,
)
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc:
raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc
converted = target_dir / f"{path.stem}.docx"
if not converted.exists():
raise RuntimeError(f"LibreOffice 未生成 docx{path.name}")
return extract_text_from_path(converted)
def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
chunks: list[TextChunk] = []
for path in sorted(source_dir.rglob("*")):
if not path.is_file():
continue
try:
text = extract_text_from_path(path)
except RuntimeError as exc:
logger.warning("Regulatory source extraction skipped", extra={"path": str(path), "error": str(exc)})
continue
chunks.extend(chunk_text(text, source=str(path.relative_to(source_dir))))
return chunks
def build_chroma_index(
*,
source_dir: Path,
embedding_provider: EmbeddingFunction,
persist_path: Path | None = None,
collection_name: str | None = None,
) -> int:
try:
import chromadb
except ImportError as exc:
raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc
persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH)
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
persist_path.mkdir(parents=True, exist_ok=True)
chunks = collect_source_chunks(source_dir)
client = chromadb.PersistentClient(path=str(persist_path))
collection = client.get_or_create_collection(collection_name)
if not chunks:
return 0
texts = [chunk.text for chunk in chunks]
embeddings = embedding_provider(texts)
ids = [
hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
for chunk in chunks
]
collection.upsert(
ids=ids,
documents=texts,
metadatas=[chunk.metadata for chunk in chunks],
embeddings=embeddings,
)
return len(chunks)