package com.bruce.agent; import com.bruce.agent.dto.request.AgentChatRequest; import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; import com.bruce.agent.dto.response.AgentChatResponse; import com.bruce.agent.entity.AgentDefinition; import com.bruce.agent.service.impl.AgentDefinitionServiceImpl; import com.bruce.modelprovider.entity.ModelCallLog; import com.bruce.modelprovider.entity.RagStoreModelConfig; import com.bruce.modelprovider.gateway.ChatRequest; import com.bruce.modelprovider.gateway.ChatResult; import com.bruce.modelprovider.gateway.EmbeddingRequest; import com.bruce.modelprovider.gateway.EmbeddingResult; import com.bruce.modelprovider.service.IRagStoreModelConfigService; import com.bruce.rag.dto.response.RagChunkRecallResponse; import com.bruce.rag.entity.RagStore; import com.bruce.rag.mapper.RagChunkEmbeddingMapper; import com.bruce.rag.service.IRagStoreService; import com.bruce.modelprovider.gateway.ChatModelGateway; import com.bruce.modelprovider.gateway.EmbeddingModelGateway; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class AgentDefinitionServiceImplTests { @Spy @InjectMocks private AgentDefinitionServiceImpl agentDefinitionService; @Mock private IRagStoreService ragStoreService; @Mock private IRagStoreModelConfigService ragStoreModelConfigService; @Mock private RagChunkEmbeddingMapper ragChunkEmbeddingMapper; @Mock private EmbeddingModelGateway embeddingModelGateway; @Mock private ChatModelGateway chatModelGateway; @Test void saveOrUpdateShouldValidateBoundStoreExists() { AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest(); request.setAgentCode("A_1"); request.setAgentName("Agent 1"); request.setStoreId(1001L); when(ragStoreService.getById(1001L)).thenReturn(null); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request)); assertTrue(exception.getMessage().contains("绑定知识库不存在")); } @Test void chatShouldRejectDisabledAgent() { AgentDefinition agent = new AgentDefinition(); agent.setId(1001L); agent.setStoreId(2001L); agent.setStatus("DISABLED"); doReturn(agent).when(agentDefinitionService).getById(1001L); AgentChatRequest request = new AgentChatRequest(); AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); message.setRole("user"); message.setContent("你好"); request.setMessages(List.of(message)); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request)); assertTrue(exception.getMessage().contains("停用")); } @Test void chatShouldRejectAgentWithoutStore() { AgentDefinition agent = new AgentDefinition(); agent.setId(1001L); agent.setStatus("ENABLED"); agent.setStoreId(null); doReturn(agent).when(agentDefinitionService).getById(1001L); AgentChatRequest request = new AgentChatRequest(); AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); message.setRole("user"); message.setContent("你好"); request.setMessages(List.of(message)); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request)); assertTrue(exception.getMessage().contains("未绑定知识库")); } @Test void chatShouldUseStoreScopedRecallAndReturnAnswer() { AgentDefinition agent = new AgentDefinition(); agent.setId(1001L); agent.setAgentCode("AGENT_1"); agent.setAgentName("知识助手"); agent.setSystemPrompt("你是企业知识助手"); agent.setStoreId(2001L); agent.setStatus("ENABLED"); doReturn(agent).when(agentDefinitionService).getById(1001L); RagStore store = new RagStore(); store.setId(2001L); store.setStoreName("企业知识库"); when(ragStoreService.getById(2001L)).thenReturn(store); RagStoreModelConfig modelConfig = new RagStoreModelConfig(); modelConfig.setStoreId(2001L); modelConfig.setEmbeddingModelId(3001L); modelConfig.setEmbeddingDimension(1024); when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig); EmbeddingResult embeddingResult = new EmbeddingResult(); embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56))); when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult); RagChunkRecallResponse recall = new RagChunkRecallResponse(); recall.setChunkId(4001L); recall.setDocumentId(5001L); recall.setChunkContent("公司请假流程:先提交审批单。"); recall.setScore(0.91); when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt())) .thenReturn(List.of(recall)); ModelCallLog callLog = new ModelCallLog(); callLog.setRequestId("req_001"); ChatResult chatResult = new ChatResult(); chatResult.setContent("根据知识库,先在OA提交请假审批。"); chatResult.setCallLog(callLog); when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult); AgentChatRequest request = new AgentChatRequest(); AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); message.setRole("user"); message.setContent("请假流程是什么?"); request.setMessages(List.of(message)); AgentChatResponse response = agentDefinitionService.chat(1001L, request); assertEquals(1001L, response.getAgentId()); assertEquals(2001L, response.getStoreId()); assertEquals("企业知识库", response.getStoreName()); assertEquals("根据知识库,先在OA提交请假审批。", response.getAnswer()); assertEquals("req_001", response.getModelRequestId()); assertEquals(1, response.getReferences().size()); assertEquals(4001L, response.getReferences().getFirst().getChunkId()); ArgumentCaptor embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class); verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture()); EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue(); assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType()); assertEquals("RAG_STORE", embeddingRequest.getMatchScope()); assertEquals(2001L, embeddingRequest.getScopeId()); verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt()); } @Test void chatShouldSupportSimpleModeWithoutRagRecall() { AgentDefinition agent = new AgentDefinition(); agent.setId(1001L); agent.setAgentCode("AGENT_1"); agent.setAgentName("知识助手"); agent.setStoreId(2001L); agent.setStatus("ENABLED"); doReturn(agent).when(agentDefinitionService).getById(1001L); RagStore store = new RagStore(); store.setId(2001L); store.setStoreName("企业知识库"); when(ragStoreService.getById(2001L)).thenReturn(store); ModelCallLog callLog = new ModelCallLog(); callLog.setRequestId("req_simple_001"); ChatResult chatResult = new ChatResult(); chatResult.setContent("这是普通对话回答。"); chatResult.setCallLog(callLog); when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult); AgentChatRequest request = new AgentChatRequest(); request.setRagEnabled(false); AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); message.setRole("user"); message.setContent("直接聊聊今天安排"); request.setMessages(List.of(message)); AgentChatResponse response = agentDefinitionService.chat(1001L, request); assertEquals("这是普通对话回答。", response.getAnswer()); assertTrue(response.getReferences().isEmpty()); verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class)); verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt()); ArgumentCaptor chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class); verify(chatModelGateway).chat(chatRequestCaptor.capture()); assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType()); } }