Files
DEMO-AGENT/review_agent/regulatory_review/services/rag_embedding.py

83 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import hashlib
import random
from typing import Callable, Iterable
import httpx
from django.conf import settings
# Type alias for an embedding callable: it takes a list of texts and returns
# a list of float vectors.
EmbeddingFunction = Callable[[list[str]], list[list[float]]]
class EmbeddingConfigurationError(RuntimeError):
    """Signals that the embedding provider configuration is missing or unsupported."""
class SiliconFlowEmbeddingProvider:
    """Embedding provider that POSTs texts to a remote ``/embeddings`` endpoint.

    Instances are callable, so one can be passed wherever a function taking a
    list of texts and returning a list of vectors is expected.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        dimensions: int,
        timeout: float = 60.0,
    ):
        """Store the connection settings for later requests.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; any trailing ``/`` is stripped.
            model: Model name sent in the request body.
            dimensions: Value sent as ``dimensions`` in the request body.
            timeout: Timeout passed to ``httpx.post``.

        Raises:
            EmbeddingConfigurationError: If ``api_key`` is empty.
        """
        if not api_key:
            raise EmbeddingConfigurationError("SILICONFLOW_API_KEY 未配置。")
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.dimensions = dimensions
        self.timeout = timeout

    def embed(self, texts: Iterable[str]) -> list[list[float]]:
        """Request embeddings for ``texts`` and return one vector per text.

        Args:
            texts: Texts to embed. A bare string is treated as a single text
                rather than being split into characters.

        Returns:
            The vectors, in the same order as the inputs. An empty input
            yields ``[]`` without contacting the endpoint.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error
                status (via ``raise_for_status``).
            RuntimeError: If the response does not contain exactly one
                embedding per input text.
        """
        # A str is itself an iterable of characters; list() would explode it.
        inputs = [texts] if isinstance(texts, str) else list(texts)
        if not inputs:
            # Nothing to embed: skip the network round trip entirely.
            return []
        response = httpx.post(
            f"{self.base_url}/embeddings",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "input": inputs,
                "dimensions": self.dimensions,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_embeddings(response.json(), len(inputs))

    @staticmethod
    def _parse_embeddings(payload: dict, expected: int) -> list[list[float]]:
        """Extract the vectors from a decoded response body.

        Args:
            payload: Decoded JSON body; its ``data`` entry is read as a list
                of items that each hold an ``embedding`` and, optionally, an
                integer ``index``.
            expected: Number of texts that were sent.

        Returns:
            The ``embedding`` values ordered by ``index`` where present,
            otherwise by their position in ``data``.

        Raises:
            RuntimeError: If the number of items differs from ``expected``.
        """
        items = payload.get("data") or []
        if len(items) != expected:
            # A short or empty result would silently misalign vectors and texts.
            raise RuntimeError(
                f"Embedding response returned {len(items)} vectors "
                f"for {expected} input texts."
            )

        def sort_key(pair):
            position, item = pair
            index = item.get("index")
            return index if isinstance(index, int) else position

        ordered = sorted(enumerate(items), key=sort_key)
        return [item["embedding"] for _, item in ordered]

    def __call__(self, texts: list[str]) -> list[list[float]]:
        """Delegate to :meth:`embed` so the instance can be used as a function."""
        return self.embed(texts)
class DeterministicEmbeddingProvider:
    """Local stand-in for a real embedding backend, for tests and explicit dry runs.

    Each text is hashed with SHA-256; the first 16 hex digits of the digest
    seed a private ``random.Random``, so equal texts always map to equal
    vectors.
    """

    def __init__(self, dimensions: int = 16):
        """Remember how many components each generated vector should have."""
        self.dimensions = dimensions

    def __call__(self, texts: list[str]) -> list[list[float]]:
        """Return one pseudo-random vector per text, in input order."""
        return [self._vector_for(text) for text in texts]

    def _vector_for(self, text: str) -> list[float]:
        """Build the vector for a single text, with components in [-1, 1]."""
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        generator = random.Random(int(digest[:16], 16))
        return [generator.uniform(-1, 1) for _ in range(self.dimensions)]
def get_embedding_provider(provider_name: str | None = None) -> EmbeddingFunction:
    """Build the embedding callable selected by name.

    Args:
        provider_name: ``"siliconflow"``, ``"deterministic"`` or ``"local"``.
            When falsy, ``settings.REGULATORY_RAG_PROVIDER`` is used instead.

    Returns:
        A ``SiliconFlowEmbeddingProvider`` configured from the
        ``SILICONFLOW_*`` settings, or a ``DeterministicEmbeddingProvider``
        for the two local names.

    Raises:
        EmbeddingConfigurationError: If the resolved name is not one of the
            supported providers.
    """
    provider = provider_name or settings.REGULATORY_RAG_PROVIDER
    if provider == "siliconflow":
        return SiliconFlowEmbeddingProvider(
            api_key=settings.SILICONFLOW_API_KEY,
            base_url=settings.SILICONFLOW_BASE_URL,
            model=settings.SILICONFLOW_EMBEDDING_MODEL,
            dimensions=settings.SILICONFLOW_EMBEDDING_DIMENSIONS,
        )
    if provider in {"deterministic", "local"}:
        # NOTE(review): built with its default dimensions rather than
        # SILICONFLOW_EMBEDDING_DIMENSIONS — confirm this is intended.
        return DeterministicEmbeddingProvider()
    # Separate the label from the value so the offending name is readable.
    raise EmbeddingConfigurationError(f"不支持的 embedding provider: {provider}")