/*
 * Source: common_agent/src/test/java/com/bruce/agent/AgentDefinitionServiceImplTests.java
 * (222 lines, 9.4 KiB, Java)
 *
 * Note from the hosting page: this file contains Unicode characters that might be
 * confused with other characters. If this is intentional, the warning can be ignored.
 */
package com.bruce.agent;
import com.bruce.agent.dto.request.AgentChatRequest;
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
import com.bruce.agent.dto.response.AgentChatResponse;
import com.bruce.agent.entity.AgentDefinition;
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
import com.bruce.modelprovider.entity.ModelCallLog;
import com.bruce.modelprovider.entity.RagStoreModelConfig;
import com.bruce.modelprovider.gateway.ChatRequest;
import com.bruce.modelprovider.gateway.ChatResult;
import com.bruce.modelprovider.gateway.EmbeddingRequest;
import com.bruce.modelprovider.gateway.EmbeddingResult;
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.service.IRagStoreService;
import com.bruce.modelprovider.gateway.ChatModelGateway;
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link AgentDefinitionServiceImpl}.
 *
 * <p>The service under test is a Mockito spy (so {@code getById} can be stubbed on the
 * service instance itself) with its collaborators injected as mocks.
 */
@ExtendWith(MockitoExtension.class)
class AgentDefinitionServiceImplTests {

    /** Id of the agent used by every chat test. */
    private static final long AGENT_ID = 1001L;

    /** Id of the knowledge base (RAG store) the agents in the chat tests are bound to. */
    private static final long STORE_ID = 2001L;

    /** Display name of the stubbed knowledge base. */
    private static final String STORE_NAME = "企业知识库";

    @Spy
    @InjectMocks
    private AgentDefinitionServiceImpl agentDefinitionService;

    @Mock
    private IRagStoreService ragStoreService;

    @Mock
    private IRagStoreModelConfigService ragStoreModelConfigService;

    @Mock
    private RagChunkEmbeddingMapper ragChunkEmbeddingMapper;

    @Mock
    private EmbeddingModelGateway embeddingModelGateway;

    @Mock
    private ChatModelGateway chatModelGateway;

    /** Saving an agent bound to a store id that does not resolve must be rejected. */
    @Test
    void saveOrUpdateShouldValidateBoundStoreExists() {
        long missingStoreId = 1001L;
        AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest();
        request.setAgentCode("A_1");
        request.setAgentName("Agent 1");
        request.setStoreId(missingStoreId);
        when(ragStoreService.getById(missingStoreId)).thenReturn(null);

        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> agentDefinitionService.saveOrUpdate(request));
        assertTrue(exception.getMessage().contains("绑定知识库不存在"));
    }

    /** Chatting with an agent whose status is DISABLED must be rejected. */
    @Test
    void chatShouldRejectDisabledAgent() {
        givenAgent("DISABLED", STORE_ID);

        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> agentDefinitionService.chat(AGENT_ID, userRequest("你好")));
        assertTrue(exception.getMessage().contains("停用"));
    }

    /** Chatting with an enabled agent that has no bound store must be rejected. */
    @Test
    void chatShouldRejectAgentWithoutStore() {
        givenAgent("ENABLED", null);

        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
                () -> agentDefinitionService.chat(AGENT_ID, userRequest("你好")));
        assertTrue(exception.getMessage().contains("未绑定知识库"));
    }

    /**
     * With RAG enabled (the request default), the query is embedded against the agent's store,
     * recalled chunks are returned as references, and the model answer is passed through.
     */
    @Test
    void chatShouldUseStoreScopedRecallAndReturnAnswer() {
        AgentDefinition agent = givenAgent("ENABLED", STORE_ID);
        agent.setAgentCode("AGENT_1");
        agent.setAgentName("知识助手");
        agent.setSystemPrompt("你是企业知识助手");
        givenStore();

        RagStoreModelConfig modelConfig = new RagStoreModelConfig();
        modelConfig.setStoreId(STORE_ID);
        modelConfig.setEmbeddingModelId(3001L);
        modelConfig.setEmbeddingDimension(1024);
        when(ragStoreModelConfigService.getActiveEntity(STORE_ID)).thenReturn(modelConfig);

        EmbeddingResult embeddingResult = new EmbeddingResult();
        embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56)));
        when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult);

        RagChunkRecallResponse recall = new RagChunkRecallResponse();
        recall.setChunkId(4001L);
        recall.setDocumentId(5001L);
        recall.setChunkContent("公司请假流程:先提交审批单。");
        recall.setScore(0.91);
        when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt()))
                .thenReturn(List.of(recall));

        when(chatModelGateway.chat(any(ChatRequest.class)))
                .thenReturn(chatResult("根据知识库先在OA提交请假审批。", "req_001"));

        AgentChatResponse response = agentDefinitionService.chat(AGENT_ID, userRequest("请假流程是什么?"));

        assertEquals(AGENT_ID, response.getAgentId());
        assertEquals(STORE_ID, response.getStoreId());
        assertEquals(STORE_NAME, response.getStoreName());
        assertEquals("根据知识库先在OA提交请假审批。", response.getAnswer());
        assertEquals("req_001", response.getModelRequestId());
        assertEquals(1, response.getReferences().size());
        assertEquals(4001L, response.getReferences().getFirst().getChunkId());

        // The query embedding must be scoped to the agent's bound store.
        ArgumentCaptor<EmbeddingRequest> embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class);
        verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture());
        EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue();
        assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType());
        assertEquals("RAG_STORE", embeddingRequest.getMatchScope());
        assertEquals(STORE_ID, embeddingRequest.getScopeId());
        verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt());
    }

    /**
     * With {@code ragEnabled=false}, neither embedding nor recall runs, no references are
     * returned, and the chat call is tagged with the {@code CHAT_SIMPLE} task type.
     */
    @Test
    void chatShouldSupportSimpleModeWithoutRagRecall() {
        AgentDefinition agent = givenAgent("ENABLED", STORE_ID);
        agent.setAgentCode("AGENT_1");
        agent.setAgentName("知识助手");
        givenStore();

        when(chatModelGateway.chat(any(ChatRequest.class)))
                .thenReturn(chatResult("这是普通对话回答。", "req_simple_001"));

        AgentChatRequest request = userRequest("直接聊聊今天安排");
        request.setRagEnabled(false);

        AgentChatResponse response = agentDefinitionService.chat(AGENT_ID, request);

        assertEquals("这是普通对话回答。", response.getAnswer());
        assertTrue(response.getReferences().isEmpty());
        verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class));
        verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt());

        ArgumentCaptor<ChatRequest> chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
        verify(chatModelGateway).chat(chatRequestCaptor.capture());
        assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType());
    }

    /**
     * Creates an agent with id {@link #AGENT_ID} and stubs the spied service so that
     * {@code getById(AGENT_ID)} returns it.
     *
     * <p>The returned instance is the same object the stub hands out, so callers may set further
     * fields on it before invoking the service.
     *
     * @param status  agent status, e.g. {@code "ENABLED"} or {@code "DISABLED"}
     * @param storeId bound store id, or {@code null} for an agent without a store
     * @return the stubbed agent
     */
    private AgentDefinition givenAgent(String status, Long storeId) {
        AgentDefinition agent = new AgentDefinition();
        agent.setId(AGENT_ID);
        agent.setStatus(status);
        agent.setStoreId(storeId);
        doReturn(agent).when(agentDefinitionService).getById(AGENT_ID);
        return agent;
    }

    /**
     * Stubs the store service so that {@code getById(STORE_ID)} returns a store named
     * {@link #STORE_NAME}.
     */
    private void givenStore() {
        RagStore store = new RagStore();
        store.setId(STORE_ID);
        store.setStoreName(STORE_NAME);
        when(ragStoreService.getById(STORE_ID)).thenReturn(store);
    }

    /**
     * Builds a chat request containing a single {@code user} message.
     *
     * @param content the message text
     * @return a new request; other request fields are left at their defaults
     */
    private static AgentChatRequest userRequest(String content) {
        AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
        message.setRole("user");
        message.setContent(content);
        AgentChatRequest request = new AgentChatRequest();
        request.setMessages(List.of(message));
        return request;
    }

    /**
     * Builds a model chat result carrying the given answer and a call log with the given request id.
     *
     * @param content   the answer text the model "returns"
     * @param requestId the request id recorded on the call log
     * @return a new chat result
     */
    private static ChatResult chatResult(String content, String requestId) {
        ModelCallLog callLog = new ModelCallLog();
        callLog.setRequestId(requestId);
        ChatResult result = new ChatResult();
        result.setContent(content);
        result.setCallLog(callLog);
        return result;
    }
}