feat(regulatory): 增加本地法规RAG索引检索

This commit is contained in:
2026-06-07 00:30:53 +08:00
parent 2a4dd6cfab
commit 26490f7c46
7 changed files with 411 additions and 0 deletions

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import hashlib
import random
from typing import Callable, Iterable
import httpx
from django.conf import settings
EmbeddingFunction = Callable[[list[str]], list[list[float]]]
class EmbeddingConfigurationError(RuntimeError):
pass
class SiliconFlowEmbeddingProvider:
def __init__(
self,
*,
api_key: str,
base_url: str,
model: str,
dimensions: int,
timeout: float = 60.0,
):
if not api_key:
raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.model = model
self.dimensions = dimensions
self.timeout = timeout
def embed(self, texts: Iterable[str]) -> list[list[float]]:
inputs = list(texts)
response = httpx.post(
f"{self.base_url}/embeddings",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"input": inputs,
"dimensions": self.dimensions,
},
timeout=self.timeout,
)
response.raise_for_status()
payload = response.json()
return [item["embedding"] for item in payload.get("data", [])]
def __call__(self, texts: list[str]) -> list[list[float]]:
return self.embed(texts)
class DeterministicEmbeddingProvider:
"""Small local embedding substitute for tests and explicit dry runs."""
def __init__(self, dimensions: int = 16):
self.dimensions = dimensions
def __call__(self, texts: list[str]) -> list[list[float]]:
vectors = []
for text in texts:
seed = int(hashlib.sha256(text.encode("utf-8")).hexdigest()[:16], 16)
rng = random.Random(seed)
vectors.append([rng.uniform(-1, 1) for _ in range(self.dimensions)])
return vectors
def get_embedding_provider(provider_name: str | None = None) -> EmbeddingFunction:
provider = provider_name or settings.REGULATORY_RAG_PROVIDER
if provider == "siliconflow":
return SiliconFlowEmbeddingProvider(
api_key=settings.SILICONFLOW_API_KEY,
base_url=settings.SILICONFLOW_BASE_URL,
model=settings.SILICONFLOW_EMBEDDING_MODEL,
dimensions=settings.SILICONFLOW_EMBEDDING_DIMENSIONS,
)
if provider in {"deterministic", "local"}:
return DeterministicEmbeddingProvider()
raise EmbeddingConfigurationError(f"不支持的 embedding provider{provider}")