feat(modelprovider): 完善模型调用与RAG召回支撑

This commit is contained in:
2026-05-31 23:56:31 +08:00
parent 1e004f1a83
commit ab9b099e9b
9 changed files with 295 additions and 58 deletions

View File

@@ -13,6 +13,10 @@ public interface OpenAiCompatibleModelClient {
* 方法 embeddings用于定义接口能力契约。
*/
List<List<Double>> embeddings(ModelProvider provider, ModelConfig model, List<String> texts, Integer expectedDimension);
/**
* 方法 chatCompletions用于定义接口能力契约。
*/
OpenAiChatCompletionResult chatCompletions(ModelProvider provider, ModelConfig model, List<OpenAiChatMessage> messages);
/**
* 方法 health用于定义接口能力契约。
*/

View File

@@ -1,10 +1,13 @@
package com.bruce.modelprovider.client;
import com.bruce.modelprovider.config.AiSecretProperties;
import com.bruce.modelprovider.entity.ModelConfig;
import com.bruce.modelprovider.entity.ModelProvider;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import java.util.ArrayList;
@@ -22,11 +25,17 @@ import java.util.Map;
* 4. API Key 从 `secretRef` 对应环境变量读取,不在代码中硬编码。
*/
@Component
@RequiredArgsConstructor
/**
* OpenAiCompatibleModelClientImpl负责模型平台对应层的职责。
*/
public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelClient {
/**
* 统一读取独立 AI 配置文件中的密钥映射。
*/
private final AiSecretProperties aiSecretProperties;
/**
* 调用上游 Embedding 接口并解析向量数组。
*/
@@ -74,6 +83,63 @@ public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelCli
return vectors;
}
@Override
@SuppressWarnings("unchecked")
public OpenAiChatCompletionResult chatCompletions(ModelProvider provider, ModelConfig model, List<OpenAiChatMessage> messages) {
if (messages == null || messages.isEmpty()) {
throw new IllegalArgumentException("聊天消息不能为空");
}
RestClient client = RestClient.builder().baseUrl(provider.getBaseUrl()).build();
List<Map<String, String>> payloadMessages = new ArrayList<>();
for (OpenAiChatMessage message : messages) {
if (message == null || !StringUtils.hasText(message.getContent())) {
continue;
}
Map<String, String> item = new HashMap<>();
item.put("role", StringUtils.hasText(message.getRole()) ? message.getRole().trim() : "user");
item.put("content", message.getContent());
payloadMessages.add(item);
}
if (payloadMessages.isEmpty()) {
throw new IllegalArgumentException("聊天消息内容不能为空");
}
Map<String, Object> body = new HashMap<>();
body.put("model", model.getUpstreamModel());
body.put("messages", payloadMessages);
RestClient.RequestBodySpec request = client.post().uri("/chat/completions")
.contentType(MediaType.APPLICATION_JSON)
.body(body);
String apiKey = resolveApiKey(provider);
if (apiKey != null) {
request = request.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey);
}
Map<String, Object> response = request.retrieve().body(Map.class);
if (response == null || !(response.get("choices") instanceof List<?> choices) || choices.isEmpty()) {
throw new IllegalStateException("上游Chat响应缺少choices字段");
}
Object first = choices.getFirst();
if (!(first instanceof Map<?, ?> firstChoice)
|| !(firstChoice.get("message") instanceof Map<?, ?> message)
|| !(message.get("content") instanceof String content)
|| !StringUtils.hasText(content)) {
throw new IllegalStateException("上游Chat响应缺少message.content");
}
OpenAiChatCompletionResult result = new OpenAiChatCompletionResult();
result.setUpstreamRequestId(String.valueOf(response.get("id")));
result.setContent(content);
if (response.get("usage") instanceof Map<?, ?> usage) {
result.setPromptTokens(toInteger(usage.get("prompt_tokens")));
result.setCompletionTokens(toInteger(usage.get("completion_tokens")));
result.setTotalTokens(toInteger(usage.get("total_tokens")));
}
return result;
}
/**
* 调用 `/models` 做健康探测:成功返回 true异常返回 false。
*/
@@ -98,14 +164,34 @@ public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelCli
/**
* 读取服务商密钥:
* 有 secretRef 时从环境变量读取;首期不使用数据库密钥明文。
* 1) 优先读取 Spring AI 独立配置文件ai-config.ini
* 2) 再读取环境变量,兼容原有部署方式;
* 3) 最后回退数据库密文/占位字段(兼容历史数据)。
*/
private String resolveApiKey(ModelProvider provider) {
if (provider.getSecretRef() != null && !provider.getSecretRef().isBlank()) {
return System.getenv(provider.getSecretRef().trim());
String secretRef = provider.getSecretRef().trim();
String fromSpringConfig = aiSecretProperties.getApiKeyBySecretRef(secretRef);
if (StringUtils.hasText(fromSpringConfig)) {
return fromSpringConfig;
}
String fromEnv = System.getenv(secretRef);
if (StringUtils.hasText(fromEnv)) {
return fromEnv.trim();
}
}
if (StringUtils.hasText(provider.getApiKeyCipher())) {
return provider.getApiKeyCipher().trim();
}
return null;
}
private Integer toInteger(Object value) {
if (value == null) {
return null;
}
return Integer.valueOf(String.valueOf(value));
}
}

View File

@@ -1,9 +1,32 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagChunkEmbedding;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import java.util.List;
@Mapper
public interface RagChunkEmbeddingMapper extends BaseMapper<RagChunkEmbedding> {
@Select("""
SELECT
e.chunk_id AS chunkId,
e.document_id AS documentId,
c.chunk_content AS chunkContent,
1 - (e.embedding <=> CAST(#{queryVector} AS vector)) AS score
FROM rag_chunk_embedding e
INNER JOIN rag_chunk c ON c.id = e.chunk_id
WHERE e.store_id = #{storeId}
AND e.enabled = TRUE
AND c.enabled = TRUE
ORDER BY e.embedding <=> CAST(#{queryVector} AS vector)
LIMIT #{topK}
""")
List<RagChunkRecallResponse> queryTopKByStore(@Param("storeId") Long storeId,
@Param("queryVector") String queryVector,
@Param("topK") int topK);
}