feat(agent): 接入Agent调试与RAG召回链路
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
package com.bruce.agent;
|
||||
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
|
||||
import com.bruce.modelprovider.entity.ModelCallLog;
|
||||
import com.bruce.modelprovider.entity.RagStoreModelConfig;
|
||||
import com.bruce.modelprovider.gateway.ChatRequest;
|
||||
import com.bruce.modelprovider.gateway.ChatResult;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingRequest;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingResult;
|
||||
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
|
||||
import com.bruce.rag.dto.response.RagChunkRecallResponse;
|
||||
import com.bruce.rag.entity.RagStore;
|
||||
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
|
||||
import com.bruce.rag.service.IRagStoreService;
|
||||
import com.bruce.modelprovider.gateway.ChatModelGateway;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Spy;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class AgentDefinitionServiceImplTests {
|
||||
|
||||
@Spy
|
||||
@InjectMocks
|
||||
private AgentDefinitionServiceImpl agentDefinitionService;
|
||||
|
||||
@Mock
|
||||
private IRagStoreService ragStoreService;
|
||||
|
||||
@Mock
|
||||
private IRagStoreModelConfigService ragStoreModelConfigService;
|
||||
|
||||
@Mock
|
||||
private RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
|
||||
|
||||
@Mock
|
||||
private EmbeddingModelGateway embeddingModelGateway;
|
||||
|
||||
@Mock
|
||||
private ChatModelGateway chatModelGateway;
|
||||
|
||||
@Test
|
||||
void saveOrUpdateShouldValidateBoundStoreExists() {
|
||||
AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest();
|
||||
request.setAgentCode("A_1");
|
||||
request.setAgentName("Agent 1");
|
||||
request.setStoreId(1001L);
|
||||
|
||||
when(ragStoreService.getById(1001L)).thenReturn(null);
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request));
|
||||
assertTrue(exception.getMessage().contains("绑定知识库不存在"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldRejectDisabledAgent() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("DISABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("你好");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
|
||||
assertTrue(exception.getMessage().contains("停用"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldRejectAgentWithoutStore() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setStatus("ENABLED");
|
||||
agent.setStoreId(null);
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("你好");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
|
||||
assertTrue(exception.getMessage().contains("未绑定知识库"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldUseStoreScopedRecallAndReturnAnswer() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setAgentCode("AGENT_1");
|
||||
agent.setAgentName("知识助手");
|
||||
agent.setSystemPrompt("你是企业知识助手");
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("ENABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
RagStore store = new RagStore();
|
||||
store.setId(2001L);
|
||||
store.setStoreName("企业知识库");
|
||||
when(ragStoreService.getById(2001L)).thenReturn(store);
|
||||
|
||||
RagStoreModelConfig modelConfig = new RagStoreModelConfig();
|
||||
modelConfig.setStoreId(2001L);
|
||||
modelConfig.setEmbeddingModelId(3001L);
|
||||
modelConfig.setEmbeddingDimension(1024);
|
||||
when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig);
|
||||
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56)));
|
||||
when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult);
|
||||
|
||||
RagChunkRecallResponse recall = new RagChunkRecallResponse();
|
||||
recall.setChunkId(4001L);
|
||||
recall.setDocumentId(5001L);
|
||||
recall.setChunkContent("公司请假流程:先提交审批单。");
|
||||
recall.setScore(0.91);
|
||||
when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt()))
|
||||
.thenReturn(List.of(recall));
|
||||
|
||||
ModelCallLog callLog = new ModelCallLog();
|
||||
callLog.setRequestId("req_001");
|
||||
ChatResult chatResult = new ChatResult();
|
||||
chatResult.setContent("根据知识库,先在OA提交请假审批。");
|
||||
chatResult.setCallLog(callLog);
|
||||
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("请假流程是什么?");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
|
||||
|
||||
assertEquals(1001L, response.getAgentId());
|
||||
assertEquals(2001L, response.getStoreId());
|
||||
assertEquals("企业知识库", response.getStoreName());
|
||||
assertEquals("根据知识库,先在OA提交请假审批。", response.getAnswer());
|
||||
assertEquals("req_001", response.getModelRequestId());
|
||||
assertEquals(1, response.getReferences().size());
|
||||
assertEquals(4001L, response.getReferences().getFirst().getChunkId());
|
||||
|
||||
ArgumentCaptor<EmbeddingRequest> embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class);
|
||||
verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture());
|
||||
EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue();
|
||||
assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType());
|
||||
assertEquals("RAG_STORE", embeddingRequest.getMatchScope());
|
||||
assertEquals(2001L, embeddingRequest.getScopeId());
|
||||
|
||||
verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldSupportSimpleModeWithoutRagRecall() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setAgentCode("AGENT_1");
|
||||
agent.setAgentName("知识助手");
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("ENABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
RagStore store = new RagStore();
|
||||
store.setId(2001L);
|
||||
store.setStoreName("企业知识库");
|
||||
when(ragStoreService.getById(2001L)).thenReturn(store);
|
||||
|
||||
ModelCallLog callLog = new ModelCallLog();
|
||||
callLog.setRequestId("req_simple_001");
|
||||
ChatResult chatResult = new ChatResult();
|
||||
chatResult.setContent("这是普通对话回答。");
|
||||
chatResult.setCallLog(callLog);
|
||||
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
request.setRagEnabled(false);
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("直接聊聊今天安排");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
|
||||
|
||||
assertEquals("这是普通对话回答。", response.getAnswer());
|
||||
assertTrue(response.getReferences().isEmpty());
|
||||
verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class));
|
||||
verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt());
|
||||
|
||||
ArgumentCaptor<ChatRequest> chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
|
||||
verify(chatModelGateway).chat(chatRequestCaptor.capture());
|
||||
assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user