141 lines
6.4 KiB
Java
141 lines
6.4 KiB
Java
package com.bruce.modelprovider.gateway;
|
||
|
||
import com.bruce.modelprovider.client.OpenAiCompatibleModelClient;
|
||
import com.bruce.modelprovider.entity.ModelCallLog;
|
||
import com.bruce.modelprovider.entity.ModelConfig;
|
||
import com.bruce.modelprovider.entity.ModelProvider;
|
||
import com.bruce.modelprovider.enums.ModelCallStatusEnum;
|
||
import com.bruce.modelprovider.route.ModelRouteContext;
|
||
import com.bruce.modelprovider.route.ModelRouteDecision;
|
||
import com.bruce.modelprovider.service.IModelCallLogService;
|
||
import com.bruce.modelprovider.service.IModelProviderService;
|
||
import com.bruce.modelprovider.service.IModelRouteService;
|
||
import lombok.RequiredArgsConstructor;
|
||
import org.springframework.stereotype.Component;
|
||
import org.springframework.util.DigestUtils;
|
||
|
||
import java.nio.charset.StandardCharsets;
|
||
import java.util.List;
|
||
import java.util.UUID;
|
||
|
||
/**
|
||
* Embedding 网关实现。
|
||
* <p>
|
||
* 主要职责:
|
||
* 1. 将业务请求转换为统一路由上下文并完成模型决策;
|
||
* 2. 调用上游 Embedding 接口并执行结果校验(数量、维度);
|
||
* 3. 记录调用日志(成功/失败、耗时、错误摘要、请求哈希);
|
||
* 4. 在主模型失败时按备用模型顺序执行兜底调用。
|
||
*/
|
||
@Component
|
||
@RequiredArgsConstructor
|
||
/**
|
||
* EmbeddingModelGatewayImpl,负责模型平台对应层的职责。
|
||
*/
|
||
public class EmbeddingModelGatewayImpl implements EmbeddingModelGateway {
|
||
|
||
private final IModelRouteService modelRouteService;
|
||
private final IModelProviderService modelProviderService;
|
||
private final IModelCallLogService modelCallLogService;
|
||
private final OpenAiCompatibleModelClient openAiCompatibleModelClient;
|
||
|
||
/**
|
||
* 统一 Embedding 调用入口。
|
||
*
|
||
* @param request 向量化请求,包含文本、任务类型和业务上下文
|
||
* @return 向量化结果(模型信息 + 向量数组 + 调用日志)
|
||
*/
|
||
@Override
|
||
/**
|
||
* 方法 embed,用于执行业务逻辑处理。
|
||
*/
|
||
public EmbeddingResult embed(EmbeddingRequest request) {
|
||
long start = System.currentTimeMillis();
|
||
ModelCallLog callLog = new ModelCallLog();
|
||
callLog.setRequestId(UUID.randomUUID().toString().replace("-", ""));
|
||
callLog.setTaskType(request.getTaskType());
|
||
callLog.setBizType(request.getBizType());
|
||
callLog.setBizId(request.getBizId());
|
||
callLog.setCallType("EMBEDDING");
|
||
callLog.setRequestHash(DigestUtils.md5DigestAsHex(String.join("|", request.getTexts()).getBytes(StandardCharsets.UTF_8)));
|
||
|
||
try {
|
||
ModelRouteContext routeContext = new ModelRouteContext();
|
||
routeContext.setTaskType(request.getTaskType());
|
||
routeContext.setMatchScope(request.getMatchScope());
|
||
routeContext.setScopeId(request.getScopeId());
|
||
routeContext.setRequiredModelType("EMBEDDING");
|
||
routeContext.setRequiredEmbeddingDimension(request.getExpectedDimension());
|
||
routeContext.setBizType(request.getBizType());
|
||
routeContext.setBizId(request.getBizId());
|
||
ModelRouteDecision decision = modelRouteService.route(routeContext);
|
||
|
||
ModelConfig model = decision.getPrimaryModel();
|
||
ModelProvider provider = modelProviderService.getById(model.getProviderId());
|
||
if (provider == null || !Boolean.TRUE.equals(provider.getEnabled())) {
|
||
throw new IllegalStateException("模型服务商不可用");
|
||
}
|
||
|
||
List<List<Double>> vectors = executeWithFallback(provider, model, decision.getFallbackModels(), request.getTexts(), request.getExpectedDimension());
|
||
if (vectors.size() != request.getTexts().size()) {
|
||
throw new IllegalStateException("向量数量与输入文本数量不一致");
|
||
}
|
||
Integer dimension = vectors.isEmpty() ? 0 : vectors.getFirst().size();
|
||
if (request.getExpectedDimension() != null && !request.getExpectedDimension().equals(dimension)) {
|
||
throw new IllegalStateException("向量维度不匹配,expected=" + request.getExpectedDimension() + ", actual=" + dimension);
|
||
}
|
||
|
||
callLog.setProviderId(provider.getId());
|
||
callLog.setModelId(model.getId());
|
||
callLog.setStatus(ModelCallStatusEnum.SUCCESS.name());
|
||
callLog.setDurationMs((int) (System.currentTimeMillis() - start));
|
||
modelCallLogService.save(callLog);
|
||
|
||
EmbeddingResult result = new EmbeddingResult();
|
||
result.setModelId(model.getId());
|
||
result.setModelName(model.getModelName());
|
||
result.setDimension(dimension);
|
||
result.setVectors(vectors);
|
||
result.setCallLog(callLog);
|
||
return result;
|
||
} catch (Exception ex) {
|
||
callLog.setStatus(ModelCallStatusEnum.FAILED.name());
|
||
callLog.setDurationMs((int) (System.currentTimeMillis() - start));
|
||
callLog.setErrorCode("EMBEDDING_FAILED");
|
||
String msg = ex.getMessage();
|
||
callLog.setErrorMessage(msg == null ? "unknown" : msg.substring(0, Math.min(msg.length(), 1000)));
|
||
modelCallLogService.save(callLog);
|
||
throw ex;
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 主模型优先调用,失败后按备用模型顺序重试。
|
||
*/
|
||
private List<List<Double>> executeWithFallback(ModelProvider primaryProvider,
|
||
ModelConfig primaryModel,
|
||
List<ModelConfig> fallbackModels,
|
||
List<String> texts,
|
||
Integer expectedDimension) {
|
||
try {
|
||
return openAiCompatibleModelClient.embeddings(primaryProvider, primaryModel, texts, expectedDimension);
|
||
} catch (Exception primaryEx) {
|
||
for (ModelConfig fallback : fallbackModels) {
|
||
try {
|
||
ModelProvider provider = modelProviderService.getById(fallback.getProviderId());
|
||
if (provider == null || !Boolean.TRUE.equals(provider.getEnabled())) {
|
||
continue;
|
||
}
|
||
return openAiCompatibleModelClient.embeddings(provider, fallback, texts, expectedDimension);
|
||
} catch (Exception ignored) {
|
||
// continue fallback chain
|
||
}
|
||
}
|
||
throw primaryEx;
|
||
}
|
||
}
|
||
}
|
||
|
||
|
||
|