feat(agent): 接入Agent调试与RAG召回链路

This commit is contained in:
2026-05-31 23:51:55 +08:00
parent 21c9eaa44d
commit 1e004f1a83
29 changed files with 1859 additions and 0 deletions

View File

@@ -0,0 +1,58 @@
package com.bruce.agent.controller;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.dto.response.AgentDefinitionResponse;
import com.bruce.agent.service.IAgentDefinitionService;
import com.bruce.common.domain.model.RequestResult;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@RestController
@RequestMapping("/api/agents")
@RequiredArgsConstructor
public class AgentDefinitionController {
private final IAgentDefinitionService agentDefinitionService;
@PostMapping("/list")
public RequestResult<List<AgentDefinitionResponse>> list() {
return RequestResult.success(agentDefinitionService.listResponses());
}
@PostMapping("/query")
public RequestResult<List<AgentDefinitionResponse>> query(@RequestBody(required = false) AgentDefinitionQueryRequest request) {
return RequestResult.success(agentDefinitionService.query(request));
}
@GetMapping("/detail")
public RequestResult<AgentDefinitionResponse> detail(@RequestParam("id") Long id) {
return RequestResult.success(agentDefinitionService.getResponseById(id));
}
@PostMapping("/save")
public RequestResult<Boolean> save(@RequestBody AgentDefinitionSaveRequest request) {
return RequestResult.success(agentDefinitionService.saveOrUpdate(request));
}
@PostMapping("/delete")
public RequestResult<Boolean> delete(@RequestParam("id") Long id) {
return RequestResult.success(agentDefinitionService.removeById(id));
}
@PostMapping("/{agentId}/chat")
public RequestResult<AgentChatResponse> chat(@PathVariable("agentId") Long agentId,
@RequestBody AgentChatRequest request) {
return RequestResult.success(agentDefinitionService.chat(agentId, request));
}
}

View File

@@ -0,0 +1,17 @@
package com.bruce.agent.dto.request;
import lombok.Data;
import java.util.List;
@Data
public class AgentChatRequest {
private List<AgentMessage> messages;
private Boolean ragEnabled;
@Data
public static class AgentMessage {
private String role;
private String content;
}
}

View File

@@ -0,0 +1,11 @@
package com.bruce.agent.dto.request;
import lombok.Data;
@Data
public class AgentDefinitionQueryRequest {
private String agentCode;
private String agentName;
private String status;
private Long storeId;
}

View File

@@ -0,0 +1,14 @@
package com.bruce.agent.dto.request;
import lombok.Data;
@Data
public class AgentDefinitionSaveRequest {
private Long id;
private String agentCode;
private String agentName;
private String systemPrompt;
private Long storeId;
private String status;
private String remark;
}

View File

@@ -0,0 +1,25 @@
package com.bruce.agent.dto.response;
import lombok.Data;
import java.util.List;
@Data
public class AgentChatResponse {
private Long agentId;
private String agentCode;
private String agentName;
private Long storeId;
private String storeName;
private String answer;
private String modelRequestId;
private List<ReferenceChunk> references;
@Data
public static class ReferenceChunk {
private Long chunkId;
private Long documentId;
private String chunkContent;
private Double score;
}
}

View File

@@ -0,0 +1,25 @@
package com.bruce.agent.dto.response;
import com.bruce.agent.entity.AgentDefinition;
import lombok.Data;
import org.springframework.beans.BeanUtils;
@Data
public class AgentDefinitionResponse {
private Long id;
private String agentCode;
private String agentName;
private String systemPrompt;
private Long storeId;
private String status;
private String remark;
public static AgentDefinitionResponse fromEntity(AgentDefinition entity) {
if (entity == null) {
return null;
}
AgentDefinitionResponse response = new AgentDefinitionResponse();
BeanUtils.copyProperties(entity, response);
return response;
}
}

View File

@@ -0,0 +1,29 @@
package com.bruce.agent.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("agent_definition")
public class AgentDefinition extends BaseEntity {
@TableField("agent_code")
private String agentCode;
@TableField("agent_name")
private String agentName;
@TableField("system_prompt")
private String systemPrompt;
@TableField("store_id")
private Long storeId;
private String status;
private String remark;
}

View File

@@ -0,0 +1,9 @@
package com.bruce.agent.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.agent.entity.AgentDefinition;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface AgentDefinitionMapper extends BaseMapper<AgentDefinition> {
}

View File

@@ -0,0 +1,23 @@
package com.bruce.agent.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.dto.response.AgentDefinitionResponse;
import com.bruce.agent.entity.AgentDefinition;
import java.util.List;
public interface IAgentDefinitionService extends IService<AgentDefinition> {
List<AgentDefinitionResponse> listResponses();
List<AgentDefinitionResponse> query(AgentDefinitionQueryRequest request);
AgentDefinitionResponse getResponseById(Long id);
boolean saveOrUpdate(AgentDefinitionSaveRequest request);
AgentChatResponse chat(Long agentId, AgentChatRequest request);
}

View File

@@ -0,0 +1,304 @@
package com.bruce.agent.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.dto.response.AgentDefinitionResponse;
import com.bruce.agent.entity.AgentDefinition;
import com.bruce.agent.mapper.AgentDefinitionMapper;
import com.bruce.agent.service.IAgentDefinitionService;
import com.bruce.common.enums.EnableStatusEnum;
import com.bruce.modelprovider.client.OpenAiChatMessage;
import com.bruce.modelprovider.entity.RagStoreModelConfig;
import com.bruce.modelprovider.gateway.ChatModelGateway;
import com.bruce.modelprovider.gateway.ChatRequest;
import com.bruce.modelprovider.gateway.ChatResult;
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
import com.bruce.modelprovider.gateway.EmbeddingRequest;
import com.bruce.modelprovider.gateway.EmbeddingResult;
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.service.IRagStoreService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@Service
@RequiredArgsConstructor
public class AgentDefinitionServiceImpl extends ServiceImpl<AgentDefinitionMapper, AgentDefinition>
implements IAgentDefinitionService {
private static final int DEFAULT_TOP_K = 5;
private final IRagStoreService ragStoreService;
private final IRagStoreModelConfigService ragStoreModelConfigService;
private final RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
private final EmbeddingModelGateway embeddingModelGateway;
private final ChatModelGateway chatModelGateway;
@Override
public List<AgentDefinitionResponse> listResponses() {
return lambdaQuery()
.orderByAsc(AgentDefinition::getAgentCode)
.list()
.stream()
.map(AgentDefinitionResponse::fromEntity)
.toList();
}
@Override
public List<AgentDefinitionResponse> query(AgentDefinitionQueryRequest request) {
AgentDefinitionQueryRequest queryRequest = request == null ? new AgentDefinitionQueryRequest() : request;
return lambdaQuery()
.eq(StringUtils.hasText(queryRequest.getAgentCode()), AgentDefinition::getAgentCode, trimToNull(queryRequest.getAgentCode()))
.like(StringUtils.hasText(queryRequest.getAgentName()), AgentDefinition::getAgentName, trimToNull(queryRequest.getAgentName()))
.eq(StringUtils.hasText(queryRequest.getStatus()), AgentDefinition::getStatus, trimToNull(queryRequest.getStatus()))
.eq(queryRequest.getStoreId() != null, AgentDefinition::getStoreId, queryRequest.getStoreId())
.orderByAsc(AgentDefinition::getAgentCode)
.list()
.stream()
.map(AgentDefinitionResponse::fromEntity)
.toList();
}
@Override
public AgentDefinitionResponse getResponseById(Long id) {
return AgentDefinitionResponse.fromEntity(getById(id));
}
@Override
public boolean saveOrUpdate(AgentDefinitionSaveRequest request) {
validateSaveRequest(request);
if (ragStoreService.getById(request.getStoreId()) == null) {
throw new IllegalArgumentException("绑定知识库不存在ID: " + request.getStoreId());
}
AgentDefinition duplicate = lambdaQuery()
.eq(AgentDefinition::getAgentCode, request.getAgentCode().trim())
.ne(request.getId() != null, AgentDefinition::getId, request.getId())
.one();
if (duplicate != null) {
throw new IllegalArgumentException("Agent编码已存在: " + request.getAgentCode().trim());
}
AgentDefinition entity = request.getId() == null ? new AgentDefinition() : getById(request.getId());
if (entity == null) {
throw new IllegalArgumentException("Agent不存在ID: " + request.getId());
}
entity.setAgentCode(request.getAgentCode().trim());
entity.setAgentName(request.getAgentName().trim());
entity.setSystemPrompt(trimToNull(request.getSystemPrompt()));
entity.setStoreId(request.getStoreId());
entity.setStatus(StringUtils.hasText(request.getStatus())
? request.getStatus().trim()
: EnableStatusEnum.ENABLED.name());
entity.setRemark(trimToNull(request.getRemark()));
return request.getId() == null ? save(entity) : updateById(entity);
}
@Override
public AgentChatResponse chat(Long agentId, AgentChatRequest request) {
if (agentId == null) {
throw new IllegalArgumentException("Agent ID不能为空");
}
if (request == null || request.getMessages() == null || request.getMessages().isEmpty()) {
throw new IllegalArgumentException("对话消息不能为空");
}
AgentDefinition agent = getById(agentId);
if (agent == null) {
throw new IllegalArgumentException("Agent不存在ID: " + agentId);
}
if (!EnableStatusEnum.ENABLED.name().equals(agent.getStatus())) {
throw new IllegalArgumentException("Agent已停用暂不支持对话");
}
if (agent.getStoreId() == null) {
throw new IllegalArgumentException("Agent未绑定知识库请先保存知识库配置");
}
RagStore store = ragStoreService.getById(agent.getStoreId());
if (store == null) {
throw new IllegalArgumentException("绑定知识库不存在ID: " + agent.getStoreId());
}
String queryText = resolveLatestUserMessage(request.getMessages());
boolean ragEnabled = request.getRagEnabled() == null || request.getRagEnabled();
List<RagChunkRecallResponse> recalls = List.of();
if (ragEnabled) {
RagStoreModelConfig storeModelConfig = ragStoreModelConfigService.getActiveEntity(agent.getStoreId());
if (storeModelConfig == null || storeModelConfig.getEmbeddingModelId() == null) {
throw new IllegalArgumentException("当前知识库未配置Embedding模型无法执行检索对话");
}
EmbeddingRequest embeddingRequest = new EmbeddingRequest();
embeddingRequest.setTexts(List.of(queryText));
embeddingRequest.setTaskType("RAG_QUERY_EMBEDDING");
embeddingRequest.setMatchScope("RAG_STORE");
embeddingRequest.setScopeId(agent.getStoreId());
embeddingRequest.setBizType("AGENT_CHAT");
embeddingRequest.setBizId(String.valueOf(agentId));
embeddingRequest.setExpectedDimension(storeModelConfig.getEmbeddingDimension());
EmbeddingResult queryEmbedding = embeddingModelGateway.embed(embeddingRequest);
if (queryEmbedding.getVectors() == null || queryEmbedding.getVectors().isEmpty()) {
throw new IllegalArgumentException("查询向量生成失败请检查Embedding模型配置");
}
String queryVector = toVectorLiteral(queryEmbedding.getVectors().getFirst());
recalls = ragChunkEmbeddingMapper.queryTopKByStore(
agent.getStoreId(),
queryVector,
DEFAULT_TOP_K
);
if (recalls.isEmpty()) {
throw new IllegalArgumentException("未召回到可用知识切片,请先完成知识库切片与向量化");
}
}
ChatRequest chatRequest = new ChatRequest();
chatRequest.setTaskType(ragEnabled ? "RAG_ANSWER" : "CHAT_SIMPLE");
chatRequest.setMatchScope("AGENT");
chatRequest.setScopeId(agentId);
chatRequest.setBizType("AGENT_CHAT");
chatRequest.setBizId(String.valueOf(agentId));
chatRequest.setMessages(buildChatMessages(agent, recalls, request.getMessages(), ragEnabled));
ChatResult chatResult = chatModelGateway.chat(chatRequest);
AgentChatResponse response = new AgentChatResponse();
response.setAgentId(agent.getId());
response.setAgentCode(agent.getAgentCode());
response.setAgentName(agent.getAgentName());
response.setStoreId(agent.getStoreId());
response.setStoreName(store.getStoreName());
response.setAnswer(chatResult.getContent());
response.setModelRequestId(chatResult.getCallLog().getRequestId());
response.setReferences(toReferenceChunks(recalls));
return response;
}
private void validateSaveRequest(AgentDefinitionSaveRequest request) {
if (request == null) {
throw new IllegalArgumentException("保存请求不能为空");
}
if (!StringUtils.hasText(request.getAgentCode())) {
throw new IllegalArgumentException("Agent编码不能为空");
}
if (!StringUtils.hasText(request.getAgentName())) {
throw new IllegalArgumentException("Agent名称不能为空");
}
if (request.getStoreId() == null) {
throw new IllegalArgumentException("绑定知识库不能为空");
}
}
private String resolveLatestUserMessage(List<AgentChatRequest.AgentMessage> messages) {
for (int index = messages.size() - 1; index >= 0; index--) {
AgentChatRequest.AgentMessage message = messages.get(index);
if (message != null
&& "user".equalsIgnoreCase(message.getRole())
&& StringUtils.hasText(message.getContent())) {
return message.getContent();
}
}
throw new IllegalArgumentException("缺少用户提问内容");
}
private List<OpenAiChatMessage> buildChatMessages(AgentDefinition agent,
List<RagChunkRecallResponse> recalls,
List<AgentChatRequest.AgentMessage> rawMessages,
boolean ragEnabled) {
List<OpenAiChatMessage> messages = new ArrayList<>();
OpenAiChatMessage instructionMessage = new OpenAiChatMessage();
instructionMessage.setRole("system");
instructionMessage.setContent(buildSystemInstruction(agent));
messages.add(instructionMessage);
if (ragEnabled) {
OpenAiChatMessage contextMessage = new OpenAiChatMessage();
contextMessage.setRole("system");
contextMessage.setContent(buildContextText(recalls));
messages.add(contextMessage);
}
for (AgentChatRequest.AgentMessage rawMessage : rawMessages) {
if (rawMessage == null || !StringUtils.hasText(rawMessage.getContent())) {
continue;
}
OpenAiChatMessage message = new OpenAiChatMessage();
message.setRole(normalizeRole(rawMessage.getRole()));
message.setContent(rawMessage.getContent());
messages.add(message);
}
return messages;
}
private String buildSystemInstruction(AgentDefinition agent) {
StringBuilder builder = new StringBuilder();
if (StringUtils.hasText(agent.getSystemPrompt())) {
builder.append(agent.getSystemPrompt().trim()).append("\n\n");
}
builder.append("请优先基于已给出的知识库引用片段回答。");
builder.append("如果引用无法支持结论,请明确告知“知识库中暂无直接依据”。");
return builder.toString();
}
private String buildContextText(List<RagChunkRecallResponse> recalls) {
StringBuilder builder = new StringBuilder("以下是知识库召回片段:\n");
for (int i = 0; i < recalls.size(); i++) {
RagChunkRecallResponse recall = recalls.get(i);
builder.append(i + 1)
.append(". [chunkId=")
.append(recall.getChunkId())
.append(", score=")
.append(String.format("%.4f", recall.getScore() == null ? 0D : recall.getScore()))
.append("] ")
.append(recall.getChunkContent())
.append("\n");
}
return builder.toString();
}
private List<AgentChatResponse.ReferenceChunk> toReferenceChunks(List<RagChunkRecallResponse> recalls) {
return recalls.stream().map(recall -> {
AgentChatResponse.ReferenceChunk chunk = new AgentChatResponse.ReferenceChunk();
chunk.setChunkId(recall.getChunkId());
chunk.setDocumentId(recall.getDocumentId());
chunk.setChunkContent(recall.getChunkContent());
chunk.setScore(recall.getScore());
return chunk;
}).toList();
}
private String normalizeRole(String role) {
if (!StringUtils.hasText(role)) {
return "user";
}
String normalized = role.trim().toLowerCase();
if ("system".equals(normalized) || "assistant".equals(normalized) || "user".equals(normalized)) {
return normalized;
}
return "user";
}
private String toVectorLiteral(List<Double> vector) {
StringBuilder builder = new StringBuilder("[");
for (int index = 0; index < vector.size(); index++) {
if (index > 0) {
builder.append(',');
}
builder.append(vector.get(index));
}
builder.append(']');
return builder.toString();
}
private String trimToNull(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
}

View File

@@ -0,0 +1,12 @@
package com.bruce.modelprovider.client;
import lombok.Data;
@Data
public class OpenAiChatCompletionResult {
private String upstreamRequestId;
private String content;
private Integer promptTokens;
private Integer completionTokens;
private Integer totalTokens;
}

View File

@@ -0,0 +1,9 @@
package com.bruce.modelprovider.client;
import lombok.Data;
@Data
public class OpenAiChatMessage {
private String role;
private String content;
}

View File

@@ -0,0 +1,18 @@
package com.bruce.modelprovider.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;
/**
* 加载独立 AI 配置文件。
* <p>
* 说明:
* 1. 该文件使用 INI 扩展名,但内容采用 key=value 形式Spring 可直接按 Properties 解析;
* 2. ignoreResourceNotFound=true允许某些环境不提供该文件避免启动失败
* 3. 具体键值由 {@link AiSecretProperties} 统一绑定与读取。
*/
@Configuration
@PropertySource(value = "classpath:ai-config.ini", ignoreResourceNotFound = true)
public class AiConfigFilePropertySourceConfig {
}

View File

@@ -0,0 +1,41 @@
package com.bruce.modelprovider.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import java.util.HashMap;
import java.util.Map;
/**
* AI 密钥配置绑定。
* <p>
* 支持从 ai-config.ini 读取如下配置:
* ai.secret-refs[SILICONFLOW_API_KEY]=your-key
*/
@Data
@Component
@ConfigurationProperties(prefix = "ai")
public class AiSecretProperties {
/**
* key 为 secretRef例如 SILICONFLOW_API_KEYvalue 为实际密钥。
*/
private Map<String, String> secretRefs = new HashMap<>();
/**
* 根据 secretRef 获取配置文件中的密钥,并做空白清理。
*/
public String getApiKeyBySecretRef(String secretRef) {
if (!StringUtils.hasText(secretRef)) {
return null;
}
String value = secretRefs.get(secretRef.trim());
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
}

View File

@@ -0,0 +1,5 @@
package com.bruce.modelprovider.gateway;
public interface ChatModelGateway {
ChatResult chat(ChatRequest request);
}

View File

@@ -0,0 +1,135 @@
package com.bruce.modelprovider.gateway;
import com.bruce.modelprovider.client.OpenAiChatCompletionResult;
import com.bruce.modelprovider.client.OpenAiChatMessage;
import com.bruce.modelprovider.client.OpenAiCompatibleModelClient;
import com.bruce.modelprovider.entity.ModelCallLog;
import com.bruce.modelprovider.entity.ModelConfig;
import com.bruce.modelprovider.entity.ModelProvider;
import com.bruce.modelprovider.enums.ModelCallStatusEnum;
import com.bruce.modelprovider.route.ModelRouteContext;
import com.bruce.modelprovider.route.ModelRouteDecision;
import com.bruce.modelprovider.service.IModelCallLogService;
import com.bruce.modelprovider.service.IModelProviderService;
import com.bruce.modelprovider.service.IModelRouteService;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.util.DigestUtils;
import org.springframework.util.StringUtils;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
@Component
@RequiredArgsConstructor
public class ChatModelGatewayImpl implements ChatModelGateway {
private final IModelRouteService modelRouteService;
private final IModelProviderService modelProviderService;
private final IModelCallLogService modelCallLogService;
private final OpenAiCompatibleModelClient openAiCompatibleModelClient;
@Override
public ChatResult chat(ChatRequest request) {
if (request == null || request.getMessages() == null || request.getMessages().isEmpty()) {
throw new IllegalArgumentException("聊天请求不能为空");
}
long start = System.currentTimeMillis();
ModelCallLog callLog = new ModelCallLog();
callLog.setRequestId(UUID.randomUUID().toString().replace("-", ""));
callLog.setTaskType(request.getTaskType());
callLog.setBizType(request.getBizType());
callLog.setBizId(request.getBizId());
callLog.setCallType("CHAT");
callLog.setRequestHash(buildRequestHash(request.getMessages()));
try {
ModelRouteContext routeContext = new ModelRouteContext();
routeContext.setTaskType(request.getTaskType());
routeContext.setMatchScope(request.getMatchScope());
routeContext.setScopeId(request.getScopeId());
routeContext.setRequiredModelType("CHAT");
routeContext.setBizType(request.getBizType());
routeContext.setBizId(request.getBizId());
ModelRouteDecision decision = modelRouteService.route(routeContext);
ModelCallExecution execution = executeWithFallback(
decision.getPrimaryModel(),
decision.getFallbackModels(),
request.getMessages()
);
callLog.setProviderId(execution.provider().getId());
callLog.setModelId(execution.model().getId());
callLog.setStatus(ModelCallStatusEnum.SUCCESS.name());
callLog.setPromptTokens(execution.result().getPromptTokens());
callLog.setCompletionTokens(execution.result().getCompletionTokens());
callLog.setTotalTokens(execution.result().getTotalTokens());
callLog.setDurationMs((int) (System.currentTimeMillis() - start));
modelCallLogService.save(callLog);
ChatResult result = new ChatResult();
result.setModelId(execution.model().getId());
result.setModelName(execution.model().getModelName());
result.setContent(execution.result().getContent());
result.setUpstreamRequestId(execution.result().getUpstreamRequestId());
result.setPromptTokens(execution.result().getPromptTokens());
result.setCompletionTokens(execution.result().getCompletionTokens());
result.setTotalTokens(execution.result().getTotalTokens());
result.setCallLog(callLog);
return result;
} catch (Exception ex) {
callLog.setStatus(ModelCallStatusEnum.FAILED.name());
callLog.setDurationMs((int) (System.currentTimeMillis() - start));
callLog.setErrorCode("CHAT_COMPLETION_FAILED");
String message = ex.getMessage();
callLog.setErrorMessage(message == null ? "unknown" : message.substring(0, Math.min(message.length(), 1000)));
modelCallLogService.save(callLog);
throw ex;
}
}
private ModelCallExecution executeWithFallback(ModelConfig primaryModel,
List<ModelConfig> fallbackModels,
List<OpenAiChatMessage> messages) {
ModelProvider primaryProvider = requireAvailableProvider(primaryModel.getProviderId());
try {
OpenAiChatCompletionResult result = openAiCompatibleModelClient.chatCompletions(primaryProvider, primaryModel, messages);
return new ModelCallExecution(primaryProvider, primaryModel, result);
} catch (Exception primaryEx) {
for (ModelConfig fallbackModel : fallbackModels) {
try {
ModelProvider fallbackProvider = requireAvailableProvider(fallbackModel.getProviderId());
OpenAiChatCompletionResult result = openAiCompatibleModelClient.chatCompletions(
fallbackProvider,
fallbackModel,
messages
);
return new ModelCallExecution(fallbackProvider, fallbackModel, result);
} catch (Exception ignored) {
// continue fallback chain
}
}
throw primaryEx;
}
}
private ModelProvider requireAvailableProvider(Long providerId) {
ModelProvider provider = modelProviderService.getById(providerId);
if (provider == null || !Boolean.TRUE.equals(provider.getEnabled())) {
throw new IllegalStateException("模型服务商不可用");
}
return provider;
}
private String buildRequestHash(List<OpenAiChatMessage> messages) {
String plainText = messages.stream()
.map(message -> (StringUtils.hasText(message.getRole()) ? message.getRole() : "user") + ":" + message.getContent())
.reduce((left, right) -> left + "|" + right)
.orElse("");
return DigestUtils.md5DigestAsHex(plainText.getBytes(StandardCharsets.UTF_8));
}
private record ModelCallExecution(ModelProvider provider, ModelConfig model, OpenAiChatCompletionResult result) {
}
}

View File

@@ -0,0 +1,16 @@
package com.bruce.modelprovider.gateway;
import com.bruce.modelprovider.client.OpenAiChatMessage;
import lombok.Data;
import java.util.List;
@Data
public class ChatRequest {
private List<OpenAiChatMessage> messages;
private String taskType;
private String matchScope;
private Long scopeId;
private String bizType;
private String bizId;
}

View File

@@ -0,0 +1,16 @@
package com.bruce.modelprovider.gateway;
import com.bruce.modelprovider.entity.ModelCallLog;
import lombok.Data;
@Data
public class ChatResult {
private Long modelId;
private String modelName;
private String content;
private String upstreamRequestId;
private Integer promptTokens;
private Integer completionTokens;
private Integer totalTokens;
private ModelCallLog callLog;
}

View File

@@ -0,0 +1,11 @@
package com.bruce.rag.dto.response;
import lombok.Data;
@Data
public class RagChunkRecallResponse {
private Long chunkId;
private Long documentId;
private String chunkContent;
private Double score;
}

View File

@@ -0,0 +1,5 @@
# AI 独立配置文件(建议仅本地/环境覆盖使用,不提交真实密钥)
# 格式ai.secret-refs[<secret_ref>]=<api_key>
ai.secret-refs[SILICONFLOW_API_KEY]=your-key

View File

@@ -0,0 +1,63 @@
package com.bruce.agent;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.service.IService;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.agent.controller.AgentDefinitionController;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.dto.response.AgentDefinitionResponse;
import com.bruce.agent.entity.AgentDefinition;
import com.bruce.agent.mapper.AgentDefinitionMapper;
import com.bruce.agent.service.IAgentDefinitionService;
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
import com.bruce.common.domain.model.RequestResult;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Method;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class AgentComponentStructureTests {
@Test
void agentComponentsShouldReuseMybatisPlusBaseTypes() {
assertTrue(BaseMapper.class.isAssignableFrom(AgentDefinitionMapper.class));
assertTrue(IService.class.isAssignableFrom(IAgentDefinitionService.class));
assertTrue(ServiceImpl.class.isAssignableFrom(AgentDefinitionServiceImpl.class));
}
@Test
void agentControllerShouldExposeRequestResultMethods() throws NoSuchMethodException {
Method listMethod = AgentDefinitionController.class.getMethod("list");
Method queryMethod = AgentDefinitionController.class.getMethod("query", AgentDefinitionQueryRequest.class);
Method detailMethod = AgentDefinitionController.class.getMethod("detail", Long.class);
Method saveMethod = AgentDefinitionController.class.getMethod("save", AgentDefinitionSaveRequest.class);
Method deleteMethod = AgentDefinitionController.class.getMethod("delete", Long.class);
Method chatMethod = AgentDefinitionController.class.getMethod("chat", Long.class, AgentChatRequest.class);
Method listServiceMethod = IAgentDefinitionService.class.getMethod("listResponses");
Method queryServiceMethod = IAgentDefinitionService.class.getMethod("query", AgentDefinitionQueryRequest.class);
Method detailServiceMethod = IAgentDefinitionService.class.getMethod("getResponseById", Long.class);
Method saveServiceMethod = IAgentDefinitionService.class.getMethod("saveOrUpdate", AgentDefinitionSaveRequest.class);
Method chatServiceMethod = IAgentDefinitionService.class.getMethod("chat", Long.class, AgentChatRequest.class);
assertEquals(RequestResult.class, listMethod.getReturnType());
assertEquals(RequestResult.class, queryMethod.getReturnType());
assertEquals(RequestResult.class, detailMethod.getReturnType());
assertEquals(RequestResult.class, saveMethod.getReturnType());
assertEquals(RequestResult.class, deleteMethod.getReturnType());
assertEquals(RequestResult.class, chatMethod.getReturnType());
assertEquals(List.class, listServiceMethod.getReturnType());
assertEquals(List.class, queryServiceMethod.getReturnType());
assertEquals(AgentDefinitionResponse.class, detailServiceMethod.getReturnType());
assertEquals(boolean.class, saveServiceMethod.getReturnType());
assertEquals(AgentChatResponse.class, chatServiceMethod.getReturnType());
assertEquals(AgentDefinitionResponse.class, AgentDefinitionResponse.class.getMethod("fromEntity", AgentDefinition.class).getReturnType());
}
}

View File

@@ -0,0 +1,221 @@
package com.bruce.agent;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.entity.AgentDefinition;
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
import com.bruce.modelprovider.entity.ModelCallLog;
import com.bruce.modelprovider.entity.RagStoreModelConfig;
import com.bruce.modelprovider.gateway.ChatRequest;
import com.bruce.modelprovider.gateway.ChatResult;
import com.bruce.modelprovider.gateway.EmbeddingRequest;
import com.bruce.modelprovider.gateway.EmbeddingResult;
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.service.IRagStoreService;
import com.bruce.modelprovider.gateway.ChatModelGateway;
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class AgentDefinitionServiceImplTests {
@Spy
@InjectMocks
private AgentDefinitionServiceImpl agentDefinitionService;
@Mock
private IRagStoreService ragStoreService;
@Mock
private IRagStoreModelConfigService ragStoreModelConfigService;
@Mock
private RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
@Mock
private EmbeddingModelGateway embeddingModelGateway;
@Mock
private ChatModelGateway chatModelGateway;
@Test
void saveOrUpdateShouldValidateBoundStoreExists() {
AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest();
request.setAgentCode("A_1");
request.setAgentName("Agent 1");
request.setStoreId(1001L);
when(ragStoreService.getById(1001L)).thenReturn(null);
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request));
assertTrue(exception.getMessage().contains("绑定知识库不存在"));
}
@Test
void chatShouldRejectDisabledAgent() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setStoreId(2001L);
agent.setStatus("DISABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("你好");
request.setMessages(List.of(message));
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
assertTrue(exception.getMessage().contains("停用"));
}
@Test
void chatShouldRejectAgentWithoutStore() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setStatus("ENABLED");
agent.setStoreId(null);
doReturn(agent).when(agentDefinitionService).getById(1001L);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("你好");
request.setMessages(List.of(message));
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
assertTrue(exception.getMessage().contains("未绑定知识库"));
}
@Test
void chatShouldUseStoreScopedRecallAndReturnAnswer() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setAgentCode("AGENT_1");
agent.setAgentName("知识助手");
agent.setSystemPrompt("你是企业知识助手");
agent.setStoreId(2001L);
agent.setStatus("ENABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
RagStore store = new RagStore();
store.setId(2001L);
store.setStoreName("企业知识库");
when(ragStoreService.getById(2001L)).thenReturn(store);
RagStoreModelConfig modelConfig = new RagStoreModelConfig();
modelConfig.setStoreId(2001L);
modelConfig.setEmbeddingModelId(3001L);
modelConfig.setEmbeddingDimension(1024);
when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig);
EmbeddingResult embeddingResult = new EmbeddingResult();
embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56)));
when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult);
RagChunkRecallResponse recall = new RagChunkRecallResponse();
recall.setChunkId(4001L);
recall.setDocumentId(5001L);
recall.setChunkContent("公司请假流程:先提交审批单。");
recall.setScore(0.91);
when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt()))
.thenReturn(List.of(recall));
ModelCallLog callLog = new ModelCallLog();
callLog.setRequestId("req_001");
ChatResult chatResult = new ChatResult();
chatResult.setContent("根据知识库先在OA提交请假审批。");
chatResult.setCallLog(callLog);
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("请假流程是什么?");
request.setMessages(List.of(message));
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
assertEquals(1001L, response.getAgentId());
assertEquals(2001L, response.getStoreId());
assertEquals("企业知识库", response.getStoreName());
assertEquals("根据知识库先在OA提交请假审批。", response.getAnswer());
assertEquals("req_001", response.getModelRequestId());
assertEquals(1, response.getReferences().size());
assertEquals(4001L, response.getReferences().getFirst().getChunkId());
ArgumentCaptor<EmbeddingRequest> embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class);
verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture());
EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue();
assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType());
assertEquals("RAG_STORE", embeddingRequest.getMatchScope());
assertEquals(2001L, embeddingRequest.getScopeId());
verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt());
}
@Test
void chatShouldSupportSimpleModeWithoutRagRecall() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setAgentCode("AGENT_1");
agent.setAgentName("知识助手");
agent.setStoreId(2001L);
agent.setStatus("ENABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
RagStore store = new RagStore();
store.setId(2001L);
store.setStoreName("企业知识库");
when(ragStoreService.getById(2001L)).thenReturn(store);
ModelCallLog callLog = new ModelCallLog();
callLog.setRequestId("req_simple_001");
ChatResult chatResult = new ChatResult();
chatResult.setContent("这是普通对话回答。");
chatResult.setCallLog(callLog);
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
AgentChatRequest request = new AgentChatRequest();
request.setRagEnabled(false);
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("直接聊聊今天安排");
request.setMessages(List.of(message));
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
assertEquals("这是普通对话回答。", response.getAnswer());
assertTrue(response.getReferences().isEmpty());
verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class));
verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt());
ArgumentCaptor<ChatRequest> chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModelGateway).chat(chatRequestCaptor.capture());
assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType());
}
}