feat(agent): 接入Agent调试与RAG召回链路

This commit is contained in:
2026-05-31 23:51:55 +08:00
parent 21c9eaa44d
commit 1e004f1a83
29 changed files with 1859 additions and 0 deletions

View File

@@ -0,0 +1,221 @@
package com.bruce.agent;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.entity.AgentDefinition;
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
import com.bruce.modelprovider.entity.ModelCallLog;
import com.bruce.modelprovider.entity.RagStoreModelConfig;
import com.bruce.modelprovider.gateway.ChatRequest;
import com.bruce.modelprovider.gateway.ChatResult;
import com.bruce.modelprovider.gateway.EmbeddingRequest;
import com.bruce.modelprovider.gateway.EmbeddingResult;
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.service.IRagStoreService;
import com.bruce.modelprovider.gateway.ChatModelGateway;
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class AgentDefinitionServiceImplTests {
@Spy
@InjectMocks
private AgentDefinitionServiceImpl agentDefinitionService;
@Mock
private IRagStoreService ragStoreService;
@Mock
private IRagStoreModelConfigService ragStoreModelConfigService;
@Mock
private RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
@Mock
private EmbeddingModelGateway embeddingModelGateway;
@Mock
private ChatModelGateway chatModelGateway;
@Test
void saveOrUpdateShouldValidateBoundStoreExists() {
AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest();
request.setAgentCode("A_1");
request.setAgentName("Agent 1");
request.setStoreId(1001L);
when(ragStoreService.getById(1001L)).thenReturn(null);
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request));
assertTrue(exception.getMessage().contains("绑定知识库不存在"));
}
@Test
void chatShouldRejectDisabledAgent() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setStoreId(2001L);
agent.setStatus("DISABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("你好");
request.setMessages(List.of(message));
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
assertTrue(exception.getMessage().contains("停用"));
}
@Test
void chatShouldRejectAgentWithoutStore() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setStatus("ENABLED");
agent.setStoreId(null);
doReturn(agent).when(agentDefinitionService).getById(1001L);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("你好");
request.setMessages(List.of(message));
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
assertTrue(exception.getMessage().contains("未绑定知识库"));
}
@Test
void chatShouldUseStoreScopedRecallAndReturnAnswer() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setAgentCode("AGENT_1");
agent.setAgentName("知识助手");
agent.setSystemPrompt("你是企业知识助手");
agent.setStoreId(2001L);
agent.setStatus("ENABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
RagStore store = new RagStore();
store.setId(2001L);
store.setStoreName("企业知识库");
when(ragStoreService.getById(2001L)).thenReturn(store);
RagStoreModelConfig modelConfig = new RagStoreModelConfig();
modelConfig.setStoreId(2001L);
modelConfig.setEmbeddingModelId(3001L);
modelConfig.setEmbeddingDimension(1024);
when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig);
EmbeddingResult embeddingResult = new EmbeddingResult();
embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56)));
when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult);
RagChunkRecallResponse recall = new RagChunkRecallResponse();
recall.setChunkId(4001L);
recall.setDocumentId(5001L);
recall.setChunkContent("公司请假流程:先提交审批单。");
recall.setScore(0.91);
when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt()))
.thenReturn(List.of(recall));
ModelCallLog callLog = new ModelCallLog();
callLog.setRequestId("req_001");
ChatResult chatResult = new ChatResult();
chatResult.setContent("根据知识库先在OA提交请假审批。");
chatResult.setCallLog(callLog);
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
AgentChatRequest request = new AgentChatRequest();
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("请假流程是什么?");
request.setMessages(List.of(message));
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
assertEquals(1001L, response.getAgentId());
assertEquals(2001L, response.getStoreId());
assertEquals("企业知识库", response.getStoreName());
assertEquals("根据知识库先在OA提交请假审批。", response.getAnswer());
assertEquals("req_001", response.getModelRequestId());
assertEquals(1, response.getReferences().size());
assertEquals(4001L, response.getReferences().getFirst().getChunkId());
ArgumentCaptor<EmbeddingRequest> embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class);
verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture());
EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue();
assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType());
assertEquals("RAG_STORE", embeddingRequest.getMatchScope());
assertEquals(2001L, embeddingRequest.getScopeId());
verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt());
}
@Test
void chatShouldSupportSimpleModeWithoutRagRecall() {
AgentDefinition agent = new AgentDefinition();
agent.setId(1001L);
agent.setAgentCode("AGENT_1");
agent.setAgentName("知识助手");
agent.setStoreId(2001L);
agent.setStatus("ENABLED");
doReturn(agent).when(agentDefinitionService).getById(1001L);
RagStore store = new RagStore();
store.setId(2001L);
store.setStoreName("企业知识库");
when(ragStoreService.getById(2001L)).thenReturn(store);
ModelCallLog callLog = new ModelCallLog();
callLog.setRequestId("req_simple_001");
ChatResult chatResult = new ChatResult();
chatResult.setContent("这是普通对话回答。");
chatResult.setCallLog(callLog);
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
AgentChatRequest request = new AgentChatRequest();
request.setRagEnabled(false);
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
message.setRole("user");
message.setContent("直接聊聊今天安排");
request.setMessages(List.of(message));
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
assertEquals("这是普通对话回答。", response.getAnswer());
assertTrue(response.getReferences().isEmpty());
verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class));
verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt());
ArgumentCaptor<ChatRequest> chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModelGateway).chat(chatRequestCaptor.capture());
assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType());
}
}