feat(regulatory): 增加本地法规RAG索引检索
This commit is contained in:
33
review_agent/management/commands/regulatory_rag_build.py
Normal file
33
review_agent/management/commands/regulatory_rag_build.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
from review_agent.regulatory_review.services.rag_embedding import get_embedding_provider
|
||||
from review_agent.regulatory_review.services.rag_index import build_chroma_index
|
||||
from review_agent.regulatory_review.services.rule_loader import load_rule_file
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "构建 NMPA 法规材料本地 ChromaDB RAG 索引。"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("--provider", default=None, help="覆盖 REGULATORY_RAG_PROVIDER。")
|
||||
|
||||
def handle(self, *args, **options):
|
||||
rule_set = load_rule_file()
|
||||
source_dir = Path(settings.BASE_DIR) / rule_set["source_material_dir"]
|
||||
if not source_dir.exists():
|
||||
raise CommandError(f"法规材料目录不存在:{source_dir}")
|
||||
try:
|
||||
provider = get_embedding_provider(options["provider"])
|
||||
count = build_chroma_index(source_dir=source_dir, embedding_provider=provider)
|
||||
except Exception as exc:
|
||||
raise CommandError(str(exc)) from exc
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"已构建法规 RAG 索引:collection={settings.REGULATORY_RAG_COLLECTION}, chunks={count}"
|
||||
)
|
||||
)
|
||||
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
57
review_agent/regulatory_review/services/rag_citation.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from .rag_embedding import EmbeddingFunction, get_embedding_provider
|
||||
|
||||
|
||||
class RagIndexUnavailable(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def retrieve_citations(
|
||||
query: str,
|
||||
*,
|
||||
embedding_provider: EmbeddingFunction | None = None,
|
||||
collection=None,
|
||||
n_results: int = 3,
|
||||
) -> list[dict[str, object]]:
|
||||
provider = embedding_provider or get_embedding_provider()
|
||||
if collection is None:
|
||||
collection = _load_collection()
|
||||
embeddings = provider([query])
|
||||
result = collection.query(query_embeddings=embeddings, n_results=n_results)
|
||||
documents = (result.get("documents") or [[]])[0]
|
||||
metadatas = (result.get("metadatas") or [[]])[0]
|
||||
distances = (result.get("distances") or [[]])[0]
|
||||
if not documents:
|
||||
return [{"source": "原文依据待补充", "text": "RAG 无命中", "score": None}]
|
||||
citations = []
|
||||
for index, document in enumerate(documents):
|
||||
metadata = metadatas[index] if index < len(metadatas) else {}
|
||||
distance = distances[index] if index < len(distances) else None
|
||||
citations.append(
|
||||
{
|
||||
"source": metadata.get("source", "法规材料"),
|
||||
"text": document,
|
||||
"score": distance,
|
||||
}
|
||||
)
|
||||
return citations
|
||||
|
||||
|
||||
def _load_collection():
|
||||
persist_path = Path(settings.REGULATORY_RAG_CHROMA_PATH)
|
||||
if not persist_path.exists():
|
||||
raise RagIndexUnavailable("法规 RAG 索引不存在,请先运行 regulatory_rag_build。")
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError as exc:
|
||||
raise RagIndexUnavailable("chromadb 未安装,请先安装 requirements.txt。") from exc
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
try:
|
||||
return client.get_collection(settings.REGULATORY_RAG_COLLECTION)
|
||||
except Exception as exc:
|
||||
raise RagIndexUnavailable("法规 RAG collection 不存在,请先运行 regulatory_rag_build。") from exc
|
||||
82
review_agent/regulatory_review/services/rag_embedding.py
Normal file
82
review_agent/regulatory_review/services/rag_embedding.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import random
|
||||
from typing import Callable, Iterable
|
||||
|
||||
import httpx
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
EmbeddingFunction = Callable[[list[str]], list[list[float]]]
|
||||
|
||||
|
||||
class EmbeddingConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class SiliconFlowEmbeddingProvider:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
dimensions: int,
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
if not api_key:
|
||||
raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.timeout = timeout
|
||||
|
||||
def embed(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
inputs = list(texts)
|
||||
response = httpx.post(
|
||||
f"{self.base_url}/embeddings",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"input": inputs,
|
||||
"dimensions": self.dimensions,
|
||||
},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return [item["embedding"] for item in payload.get("data", [])]
|
||||
|
||||
def __call__(self, texts: list[str]) -> list[list[float]]:
|
||||
return self.embed(texts)
|
||||
|
||||
|
||||
class DeterministicEmbeddingProvider:
|
||||
"""Small local embedding substitute for tests and explicit dry runs."""
|
||||
|
||||
def __init__(self, dimensions: int = 16):
|
||||
self.dimensions = dimensions
|
||||
|
||||
def __call__(self, texts: list[str]) -> list[list[float]]:
|
||||
vectors = []
|
||||
for text in texts:
|
||||
seed = int(hashlib.sha256(text.encode("utf-8")).hexdigest()[:16], 16)
|
||||
rng = random.Random(seed)
|
||||
vectors.append([rng.uniform(-1, 1) for _ in range(self.dimensions)])
|
||||
return vectors
|
||||
|
||||
|
||||
def get_embedding_provider(provider_name: str | None = None) -> EmbeddingFunction:
|
||||
provider = provider_name or settings.REGULATORY_RAG_PROVIDER
|
||||
if provider == "siliconflow":
|
||||
return SiliconFlowEmbeddingProvider(
|
||||
api_key=settings.SILICONFLOW_API_KEY,
|
||||
base_url=settings.SILICONFLOW_BASE_URL,
|
||||
model=settings.SILICONFLOW_EMBEDDING_MODEL,
|
||||
dimensions=settings.SILICONFLOW_EMBEDDING_DIMENSIONS,
|
||||
)
|
||||
if provider in {"deterministic", "local"}:
|
||||
return DeterministicEmbeddingProvider()
|
||||
raise EmbeddingConfigurationError(f"不支持的 embedding provider:{provider}")
|
||||
148
review_agent/regulatory_review/services/rag_index.py
Normal file
148
review_agent/regulatory_review/services/rag_index.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from docx import Document
|
||||
from openpyxl import load_workbook
|
||||
from pypdf import PdfReader
|
||||
from pptx import Presentation
|
||||
|
||||
from .rag_embedding import EmbeddingFunction
|
||||
|
||||
|
||||
logger = logging.getLogger("review_agent.regulatory_review.rag_index")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextChunk:
|
||||
text: str
|
||||
metadata: dict[str, object]
|
||||
|
||||
|
||||
def chunk_text(text: str, *, source: str, chunk_size: int = 900, overlap: int = 120) -> list[TextChunk]:
|
||||
normalized = "\n".join(line.strip() for line in text.splitlines() if line.strip())
|
||||
if not normalized:
|
||||
return []
|
||||
chunks = []
|
||||
start = 0
|
||||
index = 0
|
||||
step = max(1, chunk_size - overlap)
|
||||
while start < len(normalized):
|
||||
part = normalized[start : start + chunk_size].strip()
|
||||
if part:
|
||||
chunks.append(TextChunk(text=part, metadata={"source": source, "chunk_index": index}))
|
||||
index += 1
|
||||
start += step
|
||||
return chunks
|
||||
|
||||
|
||||
def extract_text_from_path(path: Path) -> str:
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in {".txt", ".md"}:
|
||||
return path.read_text(encoding="utf-8", errors="ignore")
|
||||
if suffix == ".pdf":
|
||||
return "\n".join(page.extract_text() or "" for page in PdfReader(str(path)).pages)
|
||||
if suffix == ".docx":
|
||||
return "\n".join(paragraph.text for paragraph in Document(str(path)).paragraphs)
|
||||
if suffix == ".pptx":
|
||||
presentation = Presentation(str(path))
|
||||
lines = []
|
||||
for slide in presentation.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
lines.append(shape.text)
|
||||
return "\n".join(lines)
|
||||
if suffix == ".xlsx":
|
||||
workbook = load_workbook(path, data_only=True, read_only=True)
|
||||
lines = []
|
||||
for sheet in workbook.worksheets:
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
values = [str(cell) for cell in row if cell not in {None, ""}]
|
||||
if values:
|
||||
lines.append("\t".join(values))
|
||||
return "\n".join(lines)
|
||||
if suffix == ".doc":
|
||||
return _extract_legacy_doc_with_libreoffice(path)
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_legacy_doc_with_libreoffice(path: Path) -> str:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
target_dir = Path(tmp_dir)
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"soffice",
|
||||
"--headless",
|
||||
"--convert-to",
|
||||
"docx",
|
||||
"--outdir",
|
||||
str(target_dir),
|
||||
str(path),
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc:
|
||||
raise RuntimeError(f"无法通过 LibreOffice 转换法规 .doc 材料:{path.name}") from exc
|
||||
converted = target_dir / f"{path.stem}.docx"
|
||||
if not converted.exists():
|
||||
raise RuntimeError(f"LibreOffice 未生成 docx:{path.name}")
|
||||
return extract_text_from_path(converted)
|
||||
|
||||
|
||||
def collect_source_chunks(source_dir: Path) -> list[TextChunk]:
|
||||
chunks: list[TextChunk] = []
|
||||
for path in sorted(source_dir.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
try:
|
||||
text = extract_text_from_path(path)
|
||||
except RuntimeError as exc:
|
||||
logger.warning("Regulatory source extraction skipped", extra={"path": str(path), "error": str(exc)})
|
||||
continue
|
||||
chunks.extend(chunk_text(text, source=str(path.relative_to(source_dir))))
|
||||
return chunks
|
||||
|
||||
|
||||
def build_chroma_index(
|
||||
*,
|
||||
source_dir: Path,
|
||||
embedding_provider: EmbeddingFunction,
|
||||
persist_path: Path | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> int:
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("chromadb 未安装,请先安装 requirements.txt。") from exc
|
||||
|
||||
persist_path = persist_path or Path(settings.REGULATORY_RAG_CHROMA_PATH)
|
||||
collection_name = collection_name or settings.REGULATORY_RAG_COLLECTION
|
||||
persist_path.mkdir(parents=True, exist_ok=True)
|
||||
chunks = collect_source_chunks(source_dir)
|
||||
client = chromadb.PersistentClient(path=str(persist_path))
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
if not chunks:
|
||||
return 0
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
embeddings = embedding_provider(texts)
|
||||
ids = [
|
||||
hashlib.sha256(f"{chunk.metadata['source']}:{chunk.metadata['chunk_index']}".encode("utf-8")).hexdigest()
|
||||
for chunk in chunks
|
||||
]
|
||||
collection.upsert(
|
||||
ids=ids,
|
||||
documents=texts,
|
||||
metadatas=[chunk.metadata for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
)
|
||||
return len(chunks)
|
||||
Reference in New Issue
Block a user