Files
common_agent/src/test/java/com/bruce/rag/RagDocumentServiceImplTests.java

165 lines
7.7 KiB
Java

package com.bruce.rag;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.mock.web.MockMultipartFile;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link RagDocumentServiceImpl}.
 *
 * <p>The service under test is a Mockito spy, so its own persistence calls ({@code save},
 * {@code getById}, {@code updateById}) can be stubbed with {@code doReturn(...).when(spy)} and
 * verified with argument captors. Meanwhile {@code batchUpload} and {@code saveOrUpdate} run
 * their real logic. Collaborators are injected as mocks via {@code @InjectMocks}.
 */
@ExtendWith(MockitoExtension.class)
class RagDocumentServiceImplTests {

    @Spy
    @InjectMocks
    private RagDocumentServiceImpl ragDocumentService;

    @Mock
    private ISysAttachmentService sysAttachmentService;

    // Not referenced directly by these tests. Presumably a dependency of the service that
    // @InjectMocks must satisfy - confirm against RagDocumentServiceImpl before removing.
    @Mock
    private ApplicationEventPublisher eventPublisher;

    /**
     * Uploading one file through {@code batchUpload} should do three things. It should upload
     * the file as an attachment tagged with the RAG source type and the store id as source id.
     * It should persist a document that references the returned attachment, using UPLOADED /
     * PENDING initial statuses. And it should echo those statuses in the response.
     */
    @Test
    void batchUploadShouldUseRagSourceTypeAndStoreIdAsSourceId() {
        MockMultipartFile file = new MockMultipartFile(
                "files",
                "knowledge.txt",
                "text/plain",
                // Explicit charset: the no-arg getBytes() depends on the platform default.
                "hello rag".getBytes(StandardCharsets.UTF_8)
        );
        RagDocumentBatchUploadRequest request = new RagDocumentBatchUploadRequest();
        request.setStoreId(1001L);
        request.setSourceType(RagSystemConstants.SOURCE_TYPE_RAG);
        request.setFiles(new MockMultipartFile[]{file});
        request.setDocumentSummary("批量摘要");
        request.setRemark("批量备注");

        SysAttachment attachment = new SysAttachment();
        attachment.setId(2002L);
        when(sysAttachmentService.upload(any(SysAttachmentUploadRequest.class))).thenReturn(attachment);
        // doReturn(...).when(spy) stubs without invoking the real save() during stubbing.
        doReturn(true).when(ragDocumentService).save(any(RagDocument.class));

        var responses = ragDocumentService.batchUpload(request);

        ArgumentCaptor<SysAttachmentUploadRequest> uploadCaptor = ArgumentCaptor.forClass(SysAttachmentUploadRequest.class);
        verify(sysAttachmentService).upload(uploadCaptor.capture());
        SysAttachmentUploadRequest uploadRequest = uploadCaptor.getValue();
        assertEquals(RagSystemConstants.SOURCE_TYPE_RAG, uploadRequest.getSourceType());
        assertEquals(1001L, uploadRequest.getSourceId());
        // The incoming multipart file should be handed through unchanged (same instance).
        assertSame(file, uploadRequest.getFile());

        ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
        verify(ragDocumentService).save(documentCaptor.capture());
        RagDocument savedDocument = documentCaptor.getValue();
        assertEquals(1001L, savedDocument.getStoreId());
        assertEquals(2002L, savedDocument.getAttachmentId());
        // The document title is taken from the uploaded file's original name.
        assertEquals("knowledge.txt", savedDocument.getDocumentTitle());
        assertEquals("批量摘要", savedDocument.getDocumentSummary());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), savedDocument.getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), savedDocument.getIndexStatus());
        assertTrue(savedDocument.getEnabled());
        assertNull(savedDocument.getErrorMessage());
        assertEquals("批量备注", savedDocument.getRemark());

        assertEquals(1, responses.size());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), responses.getFirst().getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), responses.getFirst().getIndexStatus());
    }

    /**
     * {@code saveOrUpdate} with an id that {@code getById} resolves should call
     * {@code updateById} with every editable field copied from the request. Surrounding
     * whitespace on text fields should be trimmed.
     */
    @Test
    void saveOrUpdateShouldWriteAllEditableFields() {
        RagDocument existingDocument = new RagDocument();
        existingDocument.setId(3003L);

        RagDocumentSaveRequest request = new RagDocumentSaveRequest();
        request.setId(3003L);
        request.setStoreId(1001L);
        request.setAttachmentId(2002L);
        // Padded values verify that the service trims text input before persisting.
        request.setDocumentTitle(" 新标题 ");
        request.setDocumentSummary(" 新摘要 ");
        request.setParseStatus(RagParseStatusEnum.PARSED.name());
        request.setIndexStatus(RagIndexStatusEnum.INDEXED.name());
        request.setEnabled(false);
        request.setErrorMessage(" 已修复 ");
        request.setRemark(" 备注信息 ");

        doReturn(existingDocument).when(ragDocumentService).getById(3003L);
        doReturn(true).when(ragDocumentService).updateById(any(RagDocument.class));

        boolean result = ragDocumentService.saveOrUpdate(request);

        assertTrue(result);
        ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
        verify(ragDocumentService).updateById(documentCaptor.capture());
        RagDocument savedDocument = documentCaptor.getValue();
        assertEquals(3003L, savedDocument.getId());
        assertEquals(1001L, savedDocument.getStoreId());
        assertEquals(2002L, savedDocument.getAttachmentId());
        assertEquals("新标题", savedDocument.getDocumentTitle());
        assertEquals("新摘要", savedDocument.getDocumentSummary());
        assertEquals(RagParseStatusEnum.PARSED.name(), savedDocument.getParseStatus());
        assertEquals(RagIndexStatusEnum.INDEXED.name(), savedDocument.getIndexStatus());
        assertFalse(savedDocument.getEnabled());
        assertEquals("已修复", savedDocument.getErrorMessage());
        assertEquals("备注信息", savedDocument.getRemark());
    }

    /**
     * {@code saveOrUpdate} with a sparse request (only id, storeId, title and enabled set)
     * should overwrite only the supplied fields. Every field the request leaves null should
     * keep the value already stored on the existing document.
     */
    @Test
    void saveOrUpdateShouldPreserveExistingFieldsForPartialUpdate() {
        RagDocument existingDocument = new RagDocument();
        existingDocument.setId(3003L);
        existingDocument.setStoreId(1001L);
        existingDocument.setAttachmentId(2002L);
        existingDocument.setDocumentTitle("people_profiles.txt");
        existingDocument.setDocumentSummary("测试人员信息,有多条人员信息");
        existingDocument.setParseStatus(RagParseStatusEnum.UPLOADED.name());
        existingDocument.setIndexStatus(RagIndexStatusEnum.PENDING.name());
        existingDocument.setEnabled(true);
        existingDocument.setRemark("测试人员信息");

        // attachmentId, summary, statuses and remark are intentionally left unset here.
        RagDocumentSaveRequest request = new RagDocumentSaveRequest();
        request.setId(3003L);
        request.setStoreId(1001L);
        request.setDocumentTitle("people_profiles.txt");
        request.setEnabled(false);

        doReturn(existingDocument).when(ragDocumentService).getById(3003L);
        doReturn(true).when(ragDocumentService).updateById(any(RagDocument.class));

        boolean result = ragDocumentService.saveOrUpdate(request);

        assertTrue(result);
        ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
        verify(ragDocumentService).updateById(documentCaptor.capture());
        RagDocument savedDocument = documentCaptor.getValue();
        // Fields supplied by the request.
        assertEquals(3003L, savedDocument.getId());
        assertEquals(1001L, savedDocument.getStoreId());
        assertEquals("people_profiles.txt", savedDocument.getDocumentTitle());
        assertFalse(savedDocument.getEnabled());
        // Fields omitted from the request must retain their stored values.
        assertEquals(2002L, savedDocument.getAttachmentId());
        assertEquals("测试人员信息,有多条人员信息", savedDocument.getDocumentSummary());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), savedDocument.getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), savedDocument.getIndexStatus());
        assertEquals("测试人员信息", savedDocument.getRemark());
    }
}