feat(modelprovider): 完善模型调用与RAG召回支撑
This commit is contained in:
@@ -13,6 +13,10 @@ public interface OpenAiCompatibleModelClient {
|
||||
* 方法 embeddings,用于定义接口能力契约。
|
||||
*/
|
||||
List<List<Double>> embeddings(ModelProvider provider, ModelConfig model, List<String> texts, Integer expectedDimension);
|
||||
/**
|
||||
* 方法 chatCompletions,用于定义接口能力契约。
|
||||
*/
|
||||
OpenAiChatCompletionResult chatCompletions(ModelProvider provider, ModelConfig model, List<OpenAiChatMessage> messages);
|
||||
/**
|
||||
* 方法 health,用于定义接口能力契约。
|
||||
*/
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package com.bruce.modelprovider.client;
|
||||
|
||||
import com.bruce.modelprovider.config.AiSecretProperties;
|
||||
import com.bruce.modelprovider.entity.ModelConfig;
|
||||
import com.bruce.modelprovider.entity.ModelProvider;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -22,11 +25,17 @@ import java.util.Map;
|
||||
* 4. API Key 从 `secretRef` 对应环境变量读取,不在代码中硬编码。
|
||||
*/
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
/**
|
||||
* OpenAiCompatibleModelClientImpl,负责模型平台对应层的职责。
|
||||
*/
|
||||
public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelClient {
|
||||
|
||||
/**
|
||||
* 统一读取独立 AI 配置文件中的密钥映射。
|
||||
*/
|
||||
private final AiSecretProperties aiSecretProperties;
|
||||
|
||||
/**
|
||||
* 调用上游 Embedding 接口并解析向量数组。
|
||||
*/
|
||||
@@ -74,6 +83,63 @@ public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelCli
|
||||
return vectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public OpenAiChatCompletionResult chatCompletions(ModelProvider provider, ModelConfig model, List<OpenAiChatMessage> messages) {
|
||||
if (messages == null || messages.isEmpty()) {
|
||||
throw new IllegalArgumentException("聊天消息不能为空");
|
||||
}
|
||||
RestClient client = RestClient.builder().baseUrl(provider.getBaseUrl()).build();
|
||||
|
||||
List<Map<String, String>> payloadMessages = new ArrayList<>();
|
||||
for (OpenAiChatMessage message : messages) {
|
||||
if (message == null || !StringUtils.hasText(message.getContent())) {
|
||||
continue;
|
||||
}
|
||||
Map<String, String> item = new HashMap<>();
|
||||
item.put("role", StringUtils.hasText(message.getRole()) ? message.getRole().trim() : "user");
|
||||
item.put("content", message.getContent());
|
||||
payloadMessages.add(item);
|
||||
}
|
||||
if (payloadMessages.isEmpty()) {
|
||||
throw new IllegalArgumentException("聊天消息内容不能为空");
|
||||
}
|
||||
|
||||
Map<String, Object> body = new HashMap<>();
|
||||
body.put("model", model.getUpstreamModel());
|
||||
body.put("messages", payloadMessages);
|
||||
|
||||
RestClient.RequestBodySpec request = client.post().uri("/chat/completions")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.body(body);
|
||||
String apiKey = resolveApiKey(provider);
|
||||
if (apiKey != null) {
|
||||
request = request.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey);
|
||||
}
|
||||
|
||||
Map<String, Object> response = request.retrieve().body(Map.class);
|
||||
if (response == null || !(response.get("choices") instanceof List<?> choices) || choices.isEmpty()) {
|
||||
throw new IllegalStateException("上游Chat响应缺少choices字段");
|
||||
}
|
||||
Object first = choices.getFirst();
|
||||
if (!(first instanceof Map<?, ?> firstChoice)
|
||||
|| !(firstChoice.get("message") instanceof Map<?, ?> message)
|
||||
|| !(message.get("content") instanceof String content)
|
||||
|| !StringUtils.hasText(content)) {
|
||||
throw new IllegalStateException("上游Chat响应缺少message.content");
|
||||
}
|
||||
|
||||
OpenAiChatCompletionResult result = new OpenAiChatCompletionResult();
|
||||
result.setUpstreamRequestId(String.valueOf(response.get("id")));
|
||||
result.setContent(content);
|
||||
if (response.get("usage") instanceof Map<?, ?> usage) {
|
||||
result.setPromptTokens(toInteger(usage.get("prompt_tokens")));
|
||||
result.setCompletionTokens(toInteger(usage.get("completion_tokens")));
|
||||
result.setTotalTokens(toInteger(usage.get("total_tokens")));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用 `/models` 做健康探测:成功返回 true,异常返回 false。
|
||||
*/
|
||||
@@ -98,14 +164,34 @@ public class OpenAiCompatibleModelClientImpl implements OpenAiCompatibleModelCli
|
||||
|
||||
/**
|
||||
* 读取服务商密钥:
|
||||
* 有 secretRef 时从环境变量读取;首期不使用数据库密钥明文。
|
||||
* 1) 优先读取 Spring AI 独立配置文件(ai-config.ini);
|
||||
* 2) 再读取环境变量,兼容原有部署方式;
|
||||
* 3) 最后回退数据库密文/占位字段(兼容历史数据)。
|
||||
*/
|
||||
private String resolveApiKey(ModelProvider provider) {
|
||||
if (provider.getSecretRef() != null && !provider.getSecretRef().isBlank()) {
|
||||
return System.getenv(provider.getSecretRef().trim());
|
||||
String secretRef = provider.getSecretRef().trim();
|
||||
String fromSpringConfig = aiSecretProperties.getApiKeyBySecretRef(secretRef);
|
||||
if (StringUtils.hasText(fromSpringConfig)) {
|
||||
return fromSpringConfig;
|
||||
}
|
||||
String fromEnv = System.getenv(secretRef);
|
||||
if (StringUtils.hasText(fromEnv)) {
|
||||
return fromEnv.trim();
|
||||
}
|
||||
}
|
||||
if (StringUtils.hasText(provider.getApiKeyCipher())) {
|
||||
return provider.getApiKeyCipher().trim();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private Integer toInteger(Object value) {
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
return Integer.valueOf(String.valueOf(value));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,32 @@
|
||||
package com.bruce.rag.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.bruce.rag.dto.response.RagChunkRecallResponse;
|
||||
import com.bruce.rag.entity.RagChunkEmbedding;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
import org.apache.ibatis.annotations.Select;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface RagChunkEmbeddingMapper extends BaseMapper<RagChunkEmbedding> {
|
||||
|
||||
@Select("""
|
||||
SELECT
|
||||
e.chunk_id AS chunkId,
|
||||
e.document_id AS documentId,
|
||||
c.chunk_content AS chunkContent,
|
||||
1 - (e.embedding <=> CAST(#{queryVector} AS vector)) AS score
|
||||
FROM rag_chunk_embedding e
|
||||
INNER JOIN rag_chunk c ON c.id = e.chunk_id
|
||||
WHERE e.store_id = #{storeId}
|
||||
AND e.enabled = TRUE
|
||||
AND c.enabled = TRUE
|
||||
ORDER BY e.embedding <=> CAST(#{queryVector} AS vector)
|
||||
LIMIT #{topK}
|
||||
""")
|
||||
List<RagChunkRecallResponse> queryTopKByStore(@Param("storeId") Long storeId,
|
||||
@Param("queryVector") String queryVector,
|
||||
@Param("topK") int topK);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user