Files
common_agent/src/main/java/com/bruce/modelprovider/gateway/EmbeddingModelGatewayImpl.java

141 lines
6.4 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package com.bruce.modelprovider.gateway;
import com.bruce.modelprovider.client.OpenAiCompatibleModelClient;
import com.bruce.modelprovider.entity.ModelCallLog;
import com.bruce.modelprovider.entity.ModelConfig;
import com.bruce.modelprovider.entity.ModelProvider;
import com.bruce.modelprovider.enums.ModelCallStatusEnum;
import com.bruce.modelprovider.route.ModelRouteContext;
import com.bruce.modelprovider.route.ModelRouteDecision;
import com.bruce.modelprovider.service.IModelCallLogService;
import com.bruce.modelprovider.service.IModelProviderService;
import com.bruce.modelprovider.service.IModelRouteService;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.util.DigestUtils;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
/**
 * Embedding gateway implementation.
 * <p>
 * Responsibilities:
 * <ol>
 *   <li>Convert the business request into a routing context and obtain a model decision;</li>
 *   <li>Call the upstream embedding endpoint and validate the result (vector count, dimension);</li>
 *   <li>Record a call log (success/failure, duration, error summary, request hash);</li>
 *   <li>When the primary model fails, try the fallback models in the order they were returned.</li>
 * </ol>
 */
@Component
@RequiredArgsConstructor
public class EmbeddingModelGatewayImpl implements EmbeddingModelGateway {

    /** Maximum number of characters of an exception message stored in the call log. */
    private static final int MAX_ERROR_MESSAGE_LENGTH = 1000;

    private final IModelRouteService modelRouteService;
    private final IModelProviderService modelProviderService;
    private final IModelCallLogService modelCallLogService;
    private final OpenAiCompatibleModelClient openAiCompatibleModelClient;

    /**
     * Vectors paired with the provider/model that actually produced them, so that the call log
     * and the result describe the serving model even when a fallback answered the request.
     */
    private record EmbeddingOutcome(ModelProvider provider, ModelConfig model, List<List<Double>> vectors) {
    }

    /**
     * Unified embedding entry point.
     * <p>
     * Routes the request, calls the selected model (falling back when the primary call fails),
     * validates the returned vectors and persists one call-log row for the attempt. On any
     * failure a FAILED log row is saved and the original exception is rethrown.
     *
     * @param request embedding request carrying the texts, task type and business context
     * @return embedding result (serving model info, vectors and the saved call log)
     * @throws IllegalStateException if the primary provider is missing or disabled, if the number
     *                               of vectors differs from the number of texts, or if the vector
     *                               dimension differs from the expected dimension
     */
    @Override
    public EmbeddingResult embed(EmbeddingRequest request) {
        long start = System.currentTimeMillis();
        ModelCallLog callLog = new ModelCallLog();
        callLog.setRequestId(UUID.randomUUID().toString().replace("-", ""));
        callLog.setTaskType(request.getTaskType());
        callLog.setBizType(request.getBizType());
        callLog.setBizId(request.getBizId());
        callLog.setCallType("EMBEDDING");
        try {
            // Hashing happens inside the try so that a null text list still produces a FAILED log row.
            List<String> texts = request.getTexts();
            callLog.setRequestHash(
                    DigestUtils.md5DigestAsHex(String.join("|", texts).getBytes(StandardCharsets.UTF_8)));

            ModelRouteDecision decision = modelRouteService.route(buildRouteContext(request));
            ModelConfig primaryModel = decision.getPrimaryModel();
            ModelProvider primaryProvider = modelProviderService.getById(primaryModel.getProviderId());
            if (primaryProvider == null || !Boolean.TRUE.equals(primaryProvider.getEnabled())) {
                throw new IllegalStateException("模型服务商不可用");
            }

            EmbeddingOutcome outcome = executeWithFallback(
                    primaryProvider, primaryModel, decision.getFallbackModels(), texts, request.getExpectedDimension());
            List<List<Double>> vectors = outcome.vectors();
            if (vectors.size() != texts.size()) {
                throw new IllegalStateException("向量数量与输入文本数量不一致");
            }
            // The dimension is taken from the first vector; an empty result is reported as 0.
            Integer dimension = vectors.isEmpty() ? 0 : vectors.getFirst().size();
            if (request.getExpectedDimension() != null && !request.getExpectedDimension().equals(dimension)) {
                throw new IllegalStateException(
                        "向量维度不匹配expected=" + request.getExpectedDimension() + ", actual=" + dimension);
            }

            // Attribute the call to the model that actually served it (primary or fallback).
            callLog.setProviderId(outcome.provider().getId());
            callLog.setModelId(outcome.model().getId());
            callLog.setStatus(ModelCallStatusEnum.SUCCESS.name());
            callLog.setDurationMs((int) (System.currentTimeMillis() - start));
            modelCallLogService.save(callLog);

            EmbeddingResult result = new EmbeddingResult();
            result.setModelId(outcome.model().getId());
            result.setModelName(outcome.model().getModelName());
            result.setDimension(dimension);
            result.setVectors(vectors);
            result.setCallLog(callLog);
            return result;
        } catch (Exception ex) {
            callLog.setStatus(ModelCallStatusEnum.FAILED.name());
            callLog.setDurationMs((int) (System.currentTimeMillis() - start));
            callLog.setErrorCode("EMBEDDING_FAILED");
            String msg = ex.getMessage();
            callLog.setErrorMessage(
                    msg == null ? "unknown" : msg.substring(0, Math.min(msg.length(), MAX_ERROR_MESSAGE_LENGTH)));
            saveFailureLog(callLog, ex);
            throw ex;
        }
    }

    /**
     * Builds the routing context for an embedding request by copying the task type, match scope,
     * scope id, expected dimension and business identifiers from the request.
     */
    private ModelRouteContext buildRouteContext(EmbeddingRequest request) {
        ModelRouteContext routeContext = new ModelRouteContext();
        routeContext.setTaskType(request.getTaskType());
        routeContext.setMatchScope(request.getMatchScope());
        routeContext.setScopeId(request.getScopeId());
        routeContext.setRequiredModelType("EMBEDDING");
        routeContext.setRequiredEmbeddingDimension(request.getExpectedDimension());
        routeContext.setBizType(request.getBizType());
        routeContext.setBizId(request.getBizId());
        return routeContext;
    }

    /**
     * Persists the FAILED call log without letting a persistence error replace the original
     * failure: an exception thrown by the save is attached to {@code cause} as suppressed.
     */
    private void saveFailureLog(ModelCallLog callLog, Exception cause) {
        try {
            modelCallLogService.save(callLog);
        } catch (RuntimeException logEx) {
            // addSuppressed rejects self-suppression, so guard against the same instance.
            if (logEx != cause) {
                cause.addSuppressed(logEx);
            }
        }
    }

    /**
     * Calls the primary model first; if that call throws, tries each fallback model in order and
     * returns the first successful result.
     * <p>
     * Fallbacks whose provider is missing or disabled are skipped. If every fallback fails (or
     * there are none), the primary exception is rethrown with the fallback failures attached as
     * suppressed exceptions.
     *
     * @param primaryProvider   provider of the primary model
     * @param primaryModel      primary model chosen by routing
     * @param fallbackModels    fallback models in priority order; {@code null} is treated as empty
     * @param texts             texts to embed
     * @param expectedDimension expected vector dimension passed through to the client, may be {@code null}
     * @return the vectors together with the provider and model that produced them
     */
    private EmbeddingOutcome executeWithFallback(ModelProvider primaryProvider,
                                                 ModelConfig primaryModel,
                                                 List<ModelConfig> fallbackModels,
                                                 List<String> texts,
                                                 Integer expectedDimension) {
        try {
            return new EmbeddingOutcome(primaryProvider, primaryModel,
                    openAiCompatibleModelClient.embeddings(primaryProvider, primaryModel, texts, expectedDimension));
        } catch (Exception primaryEx) {
            List<ModelConfig> candidates = fallbackModels == null ? List.of() : fallbackModels;
            for (ModelConfig fallback : candidates) {
                if (fallback == null) {
                    continue;
                }
                try {
                    ModelProvider provider = modelProviderService.getById(fallback.getProviderId());
                    if (provider == null || !Boolean.TRUE.equals(provider.getEnabled())) {
                        continue;
                    }
                    return new EmbeddingOutcome(provider, fallback,
                            openAiCompatibleModelClient.embeddings(provider, fallback, texts, expectedDimension));
                } catch (Exception fallbackEx) {
                    // Keep walking the fallback chain, but retain the failure for diagnosis.
                    if (fallbackEx != primaryEx) {
                        primaryEx.addSuppressed(fallbackEx);
                    }
                }
            }
            throw primaryEx;
        }
    }
}