feat(rag-document): 补全文档管理接口与页面
This commit is contained in:
@@ -6,6 +6,11 @@ public final class RagSystemConstants {
|
||||
|
||||
public static final String RAG_DOCUMENT = "RAG_DOCUMENT";
|
||||
|
||||
/**
|
||||
* 用于 sys_attachment.sourceType 标识该附件归属于 RAG 知识库业务。
|
||||
*/
|
||||
public static final String SOURCE_TYPE_RAG = "RAG";
|
||||
|
||||
private RagSystemConstants() {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
package com.bruce.rag.controller;
|
||||
|
||||
import com.bruce.common.domain.model.RequestResult;
|
||||
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentResponse;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.ModelAttribute;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Tag(name = "RAG知识库文档管理")
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("/api/rag/documents")
|
||||
public class RagDocumentController {
|
||||
@@ -24,14 +30,59 @@ public class RagDocumentController {
|
||||
private IRagDocumentService ragDocumentService;
|
||||
|
||||
@Operation(summary = "查询全部知识库文档")
|
||||
@GetMapping
|
||||
@PostMapping("/list")
|
||||
public RequestResult<List<RagDocumentResponse>> list() {
|
||||
return RequestResult.success(ragDocumentService.listResponses());
|
||||
log.info("RagDocumentController.list start");
|
||||
List<RagDocumentResponse> responses = ragDocumentService.listResponses();
|
||||
log.info("RagDocumentController.list success, count={}", responses.size());
|
||||
return RequestResult.success(responses);
|
||||
}
|
||||
|
||||
@Operation(summary = "按条件查询知识库文档")
|
||||
@PostMapping("/query")
|
||||
public RequestResult<List<RagDocumentResponse>> query(@RequestBody RagDocumentQueryRequest request) {
|
||||
return RequestResult.success(ragDocumentService.query(request));
|
||||
public RequestResult<List<RagDocumentResponse>> query(@RequestBody(required = false) RagDocumentQueryRequest request) {
|
||||
log.info("RagDocumentController.query start, request={}", request);
|
||||
List<RagDocumentResponse> responses = ragDocumentService.query(request);
|
||||
log.info("RagDocumentController.query success, count={}", responses.size());
|
||||
return RequestResult.success(responses);
|
||||
}
|
||||
|
||||
@Operation(summary = "查询知识库文档详情")
|
||||
@GetMapping("/detail")
|
||||
public RequestResult<RagDocumentResponse> getById(@RequestParam("id") Long id) {
|
||||
log.info("RagDocumentController.getById start, id={}", id);
|
||||
RagDocumentResponse response = ragDocumentService.getResponseById(id);
|
||||
log.info("RagDocumentController.getById success, id={}, found={}", id, response != null);
|
||||
return RequestResult.success(response);
|
||||
}
|
||||
|
||||
@Operation(summary = "新增或修改知识库文档")
|
||||
@PostMapping("/save")
|
||||
public RequestResult<Boolean> saveOrUpdate(@RequestBody RagDocumentSaveRequest request) {
|
||||
log.info("RagDocumentController.saveOrUpdate start, request={}", request);
|
||||
Boolean result = ragDocumentService.saveOrUpdate(request);
|
||||
log.info("RagDocumentController.saveOrUpdate success, id={}, result={}",
|
||||
request.getId(), result);
|
||||
return RequestResult.success(result);
|
||||
}
|
||||
|
||||
@Operation(summary = "删除知识库文档")
|
||||
@PostMapping("/delete")
|
||||
public RequestResult<Boolean> deleteById(@RequestParam("id") Long id) {
|
||||
log.info("RagDocumentController.deleteById start, id={}", id);
|
||||
Boolean result = ragDocumentService.removeById(id);
|
||||
log.info("RagDocumentController.deleteById success, id={}, result={}", id, result);
|
||||
return RequestResult.success(result);
|
||||
}
|
||||
|
||||
@Operation(summary = "批量上传文档到知识库")
|
||||
@PostMapping("/batchUpload")
|
||||
public RequestResult<List<RagDocumentResponse>> batchUpload(@ModelAttribute RagDocumentBatchUploadRequest request) {
|
||||
log.info("RagDocumentController.batchUpload start, storeId={}, fileCount={}",
|
||||
request.getStoreId(), request.getFiles() != null ? request.getFiles().length : 0);
|
||||
List<RagDocumentResponse> responses = ragDocumentService.batchUpload(request);
|
||||
log.info("RagDocumentController.batchUpload success, storeId={}, uploaded={}",
|
||||
request.getStoreId(), responses.size());
|
||||
return RequestResult.success(responses);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.bruce.rag.dto.request;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
@Data
|
||||
@Schema(description = "RAG知识库文档批量上传请求")
|
||||
public class RagDocumentBatchUploadRequest {
|
||||
|
||||
@Schema(description = "知识库ID")
|
||||
private Long storeId;
|
||||
|
||||
@Schema(description = "附件来源类型,RAG 文档上传固定传 RAG")
|
||||
private String sourceType;
|
||||
|
||||
@Schema(description = "上传文件列表")
|
||||
private MultipartFile[] files;
|
||||
|
||||
@Schema(description = "文档摘要(批量设置)")
|
||||
private String documentSummary;
|
||||
|
||||
@Schema(description = "备注(批量设置)")
|
||||
private String remark;
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.bruce.rag.dto.request;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Schema(description = "RAG知识库文档保存请求")
|
||||
public class RagDocumentSaveRequest {
|
||||
|
||||
@Schema(description = "主键ID(更新时必填)")
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "知识库ID")
|
||||
private Long storeId;
|
||||
|
||||
@Schema(description = "附件ID")
|
||||
private Long attachmentId;
|
||||
|
||||
@Schema(description = "文档标题")
|
||||
private String documentTitle;
|
||||
|
||||
@Schema(description = "文档摘要")
|
||||
private String documentSummary;
|
||||
|
||||
@Schema(description = "解析状态")
|
||||
private String parseStatus;
|
||||
|
||||
@Schema(description = "索引状态")
|
||||
private String indexStatus;
|
||||
|
||||
@Schema(description = "是否启用")
|
||||
private Boolean enabled;
|
||||
|
||||
@Schema(description = "失败原因")
|
||||
private String errorMessage;
|
||||
|
||||
@Schema(description = "备注")
|
||||
private String remark;
|
||||
}
|
||||
@@ -1,21 +1,30 @@
|
||||
package com.bruce.rag.dto.response;
|
||||
|
||||
import com.bruce.common.constant.CommonConsts;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.fasterxml.jackson.annotation.JsonFormat;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@Schema(description = "RAG知识库文档响应")
|
||||
public class RagDocumentResponse {
|
||||
|
||||
@Schema(description = "主键ID")
|
||||
@JsonSerialize(using = ToStringSerializer.class)
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "知识库ID")
|
||||
@JsonSerialize(using = ToStringSerializer.class)
|
||||
private Long storeId;
|
||||
|
||||
@Schema(description = "附件ID")
|
||||
@JsonSerialize(using = ToStringSerializer.class)
|
||||
private Long attachmentId;
|
||||
|
||||
@Schema(description = "文档标题")
|
||||
@@ -39,6 +48,14 @@ public class RagDocumentResponse {
|
||||
@Schema(description = "备注")
|
||||
private String remark;
|
||||
|
||||
@Schema(description = "创建时间")
|
||||
@JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
|
||||
private Date createTime;
|
||||
|
||||
@Schema(description = "更新时间")
|
||||
@JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
|
||||
private Date updateTime;
|
||||
|
||||
public static RagDocumentResponse fromEntity(RagDocument entity) {
|
||||
if (entity == null) {
|
||||
return null;
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.bruce.rag.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentResponse;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
|
||||
@@ -12,4 +14,12 @@ public interface IRagDocumentService extends IService<RagDocument> {
|
||||
List<RagDocumentResponse> listResponses();
|
||||
|
||||
List<RagDocumentResponse> query(RagDocumentQueryRequest request);
|
||||
|
||||
RagDocumentResponse getResponseById(Long id);
|
||||
|
||||
boolean saveOrUpdate(RagDocumentSaveRequest request);
|
||||
|
||||
boolean removeById(Long id);
|
||||
|
||||
List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request);
|
||||
}
|
||||
|
||||
@@ -1,36 +1,190 @@
|
||||
package com.bruce.rag.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.bruce.common.domain.entity.SysAttachment;
|
||||
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
|
||||
import com.bruce.common.service.ISysAttachmentService;
|
||||
import com.bruce.rag.constant.RagSystemConstants;
|
||||
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentResponse;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.bruce.rag.enums.RagIndexStatusEnum;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import com.bruce.rag.mapper.RagDocumentMapper;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDocument> implements IRagDocumentService {
|
||||
|
||||
@Autowired
|
||||
private ISysAttachmentService sysAttachmentService;
|
||||
|
||||
@Override
|
||||
public List<RagDocumentResponse> listResponses() {
|
||||
return toResponses(list());
|
||||
log.info("RagDocumentServiceImpl.listResponses start");
|
||||
List<RagDocumentResponse> responses = toResponses(list());
|
||||
log.info("RagDocumentServiceImpl.listResponses success, count={}", responses.size());
|
||||
return responses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RagDocumentResponse> query(RagDocumentQueryRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("查询请求不能为空");
|
||||
}
|
||||
return toResponses(lambdaQuery()
|
||||
.eq(request.getStoreId() != null, RagDocument::getStoreId, request.getStoreId())
|
||||
.eq(request.getAttachmentId() != null, RagDocument::getAttachmentId, request.getAttachmentId())
|
||||
.eq(request.getParseStatus() != null, RagDocument::getParseStatus, request.getParseStatus())
|
||||
.eq(request.getIndexStatus() != null, RagDocument::getIndexStatus, request.getIndexStatus())
|
||||
.eq(request.getEnabled() != null, RagDocument::getEnabled, request.getEnabled())
|
||||
log.info("RagDocumentServiceImpl.query start, request={}", request);
|
||||
RagDocumentQueryRequest queryRequest = request == null ? new RagDocumentQueryRequest() : request;
|
||||
String parseStatus = trimToNull(queryRequest.getParseStatus());
|
||||
String indexStatus = trimToNull(queryRequest.getIndexStatus());
|
||||
List<RagDocumentResponse> responses = toResponses(lambdaQuery()
|
||||
.eq(queryRequest.getStoreId() != null, RagDocument::getStoreId, queryRequest.getStoreId())
|
||||
.eq(queryRequest.getAttachmentId() != null, RagDocument::getAttachmentId, queryRequest.getAttachmentId())
|
||||
.eq(parseStatus != null, RagDocument::getParseStatus, parseStatus)
|
||||
.eq(indexStatus != null, RagDocument::getIndexStatus, indexStatus)
|
||||
.eq(queryRequest.getEnabled() != null, RagDocument::getEnabled, queryRequest.getEnabled())
|
||||
.orderByDesc(RagDocument::getId)
|
||||
.list());
|
||||
log.info("RagDocumentServiceImpl.query success, count={}", responses.size());
|
||||
return responses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RagDocumentResponse getResponseById(Long id) {
|
||||
log.info("RagDocumentServiceImpl.getResponseById start, id={}", id);
|
||||
RagDocumentResponse response = RagDocumentResponse.fromEntity(getById(id));
|
||||
log.info("RagDocumentServiceImpl.getResponseById success, id={}, found={}", id, response != null);
|
||||
return response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean saveOrUpdate(RagDocumentSaveRequest request) {
|
||||
log.info("RagDocumentServiceImpl.saveOrUpdate start, request={}", request);
|
||||
validateSaveRequest(request);
|
||||
|
||||
RagDocument document;
|
||||
if (request.getId() != null) {
|
||||
document = getById(request.getId());
|
||||
if (document == null) {
|
||||
log.warn("RagDocumentServiceImpl.saveOrUpdate document not found, id={}", request.getId());
|
||||
throw new IllegalArgumentException("文档不存在,ID: " + request.getId());
|
||||
}
|
||||
} else {
|
||||
document = new RagDocument();
|
||||
document.setEnabled(true);
|
||||
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
|
||||
document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
|
||||
}
|
||||
|
||||
document.setStoreId(request.getStoreId());
|
||||
document.setAttachmentId(request.getAttachmentId());
|
||||
document.setDocumentTitle(request.getDocumentTitle().trim());
|
||||
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
|
||||
if (StringUtils.hasText(request.getParseStatus())) {
|
||||
document.setParseStatus(request.getParseStatus().trim());
|
||||
}
|
||||
if (StringUtils.hasText(request.getIndexStatus())) {
|
||||
document.setIndexStatus(request.getIndexStatus().trim());
|
||||
}
|
||||
if (request.getEnabled() != null) {
|
||||
document.setEnabled(request.getEnabled());
|
||||
}
|
||||
document.setErrorMessage(trimToNull(request.getErrorMessage()));
|
||||
document.setRemark(trimToNull(request.getRemark()));
|
||||
|
||||
boolean result = saveOrUpdate(document);
|
||||
log.info("RagDocumentServiceImpl.saveOrUpdate success, requestId={}, savedId={}, result={}",
|
||||
request.getId(), document.getId(), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean removeById(Long id) {
|
||||
log.info("RagDocumentServiceImpl.removeById start, id={}", id);
|
||||
if (id == null) {
|
||||
throw new IllegalArgumentException("文档ID不能为空");
|
||||
}
|
||||
RagDocument document = getById(id);
|
||||
if (document == null) {
|
||||
log.warn("RagDocumentServiceImpl.removeById document not found, id={}", id);
|
||||
throw new IllegalArgumentException("文档不存在,ID: " + id);
|
||||
}
|
||||
boolean result = super.removeById(id);
|
||||
log.info("RagDocumentServiceImpl.removeById success, id={}, result={}", id, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request) {
|
||||
log.info("RagDocumentServiceImpl.batchUpload start, request={}", request);
|
||||
validateBatchUploadRequest(request);
|
||||
|
||||
List<RagDocumentResponse> results = new ArrayList<>();
|
||||
|
||||
for (var file : request.getFiles()) {
|
||||
if (file == null || file.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SysAttachmentUploadRequest uploadRequest = new SysAttachmentUploadRequest();
|
||||
uploadRequest.setFile(file);
|
||||
uploadRequest.setSourceType(resolveSourceType(request.getSourceType()));
|
||||
uploadRequest.setSourceId(request.getStoreId());
|
||||
|
||||
SysAttachment attachment = sysAttachmentService.upload(uploadRequest);
|
||||
|
||||
RagDocument document = new RagDocument();
|
||||
document.setStoreId(request.getStoreId());
|
||||
document.setAttachmentId(attachment.getId());
|
||||
document.setDocumentTitle(StringUtils.hasText(file.getOriginalFilename()) ? file.getOriginalFilename().trim() : null);
|
||||
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
|
||||
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
|
||||
document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
|
||||
document.setEnabled(true);
|
||||
document.setErrorMessage(null);
|
||||
document.setRemark(trimToNull(request.getRemark()));
|
||||
|
||||
save(document);
|
||||
results.add(RagDocumentResponse.fromEntity(document));
|
||||
}
|
||||
|
||||
log.info("RagDocumentServiceImpl.batchUpload success, storeId={}, uploaded={}",
|
||||
request.getStoreId(), results.size());
|
||||
return results;
|
||||
}
|
||||
|
||||
void validateSaveRequest(RagDocumentSaveRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("保存请求不能为空");
|
||||
}
|
||||
if (request.getStoreId() == null) {
|
||||
throw new IllegalArgumentException("知识库ID不能为空");
|
||||
}
|
||||
if (!StringUtils.hasText(request.getDocumentTitle())) {
|
||||
throw new IllegalArgumentException("文档标题不能为空");
|
||||
}
|
||||
}
|
||||
|
||||
void validateBatchUploadRequest(RagDocumentBatchUploadRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("上传请求不能为空");
|
||||
}
|
||||
if (request.getStoreId() == null) {
|
||||
throw new IllegalArgumentException("知识库ID不能为空");
|
||||
}
|
||||
if (StringUtils.hasText(request.getSourceType())
|
||||
&& !RagSystemConstants.SOURCE_TYPE_RAG.equals(request.getSourceType().trim())) {
|
||||
throw new IllegalArgumentException("sourceType必须为RAG");
|
||||
}
|
||||
if (request.getFiles() == null || request.getFiles().length == 0) {
|
||||
throw new IllegalArgumentException("上传文件不能为空");
|
||||
}
|
||||
}
|
||||
|
||||
private List<RagDocumentResponse> toResponses(List<RagDocument> documents) {
|
||||
@@ -38,4 +192,18 @@ public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDo
|
||||
.map(RagDocumentResponse::fromEntity)
|
||||
.toList();
|
||||
}
|
||||
|
||||
private String resolveSourceType(String sourceType) {
|
||||
if (!StringUtils.hasText(sourceType)) {
|
||||
return RagSystemConstants.SOURCE_TYPE_RAG;
|
||||
}
|
||||
return RagSystemConstants.SOURCE_TYPE_RAG;
|
||||
}
|
||||
|
||||
private String trimToNull(String value) {
|
||||
if (!StringUtils.hasText(value)) {
|
||||
return null;
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
}
|
||||
|
||||
124
src/test/java/com/bruce/rag/RagDocumentServiceImplTests.java
Normal file
124
src/test/java/com/bruce/rag/RagDocumentServiceImplTests.java
Normal file
@@ -0,0 +1,124 @@
|
||||
package com.bruce.rag;
|
||||
|
||||
import com.bruce.common.domain.entity.SysAttachment;
|
||||
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
|
||||
import com.bruce.common.service.ISysAttachmentService;
|
||||
import com.bruce.rag.constant.RagSystemConstants;
|
||||
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.bruce.rag.enums.RagIndexStatusEnum;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Spy;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class RagDocumentServiceImplTests {
|
||||
|
||||
@Spy
|
||||
@InjectMocks
|
||||
private RagDocumentServiceImpl ragDocumentService;
|
||||
|
||||
@Mock
|
||||
private ISysAttachmentService sysAttachmentService;
|
||||
|
||||
@Test
|
||||
void batchUploadShouldUseRagSourceTypeAndStoreIdAsSourceId() {
|
||||
MockMultipartFile file = new MockMultipartFile(
|
||||
"files",
|
||||
"knowledge.txt",
|
||||
"text/plain",
|
||||
"hello rag".getBytes()
|
||||
);
|
||||
RagDocumentBatchUploadRequest request = new RagDocumentBatchUploadRequest();
|
||||
request.setStoreId(1001L);
|
||||
request.setSourceType(RagSystemConstants.SOURCE_TYPE_RAG);
|
||||
request.setFiles(new MockMultipartFile[]{file});
|
||||
request.setDocumentSummary("批量摘要");
|
||||
request.setRemark("批量备注");
|
||||
|
||||
SysAttachment attachment = new SysAttachment();
|
||||
attachment.setId(2002L);
|
||||
when(sysAttachmentService.upload(any(SysAttachmentUploadRequest.class))).thenReturn(attachment);
|
||||
doAnswer(invocation -> true).when(ragDocumentService).save(any(RagDocument.class));
|
||||
|
||||
var responses = ragDocumentService.batchUpload(request);
|
||||
|
||||
ArgumentCaptor<SysAttachmentUploadRequest> uploadCaptor = ArgumentCaptor.forClass(SysAttachmentUploadRequest.class);
|
||||
verify(sysAttachmentService).upload(uploadCaptor.capture());
|
||||
SysAttachmentUploadRequest uploadRequest = uploadCaptor.getValue();
|
||||
assertEquals(RagSystemConstants.SOURCE_TYPE_RAG, uploadRequest.getSourceType());
|
||||
assertEquals(1001L, uploadRequest.getSourceId());
|
||||
assertEquals(file, uploadRequest.getFile());
|
||||
|
||||
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
|
||||
verify(ragDocumentService).save(documentCaptor.capture());
|
||||
RagDocument savedDocument = documentCaptor.getValue();
|
||||
assertEquals(1001L, savedDocument.getStoreId());
|
||||
assertEquals(2002L, savedDocument.getAttachmentId());
|
||||
assertEquals("knowledge.txt", savedDocument.getDocumentTitle());
|
||||
assertEquals("批量摘要", savedDocument.getDocumentSummary());
|
||||
assertEquals(RagParseStatusEnum.UPLOADED.name(), savedDocument.getParseStatus());
|
||||
assertEquals(RagIndexStatusEnum.PENDING.name(), savedDocument.getIndexStatus());
|
||||
assertTrue(savedDocument.getEnabled());
|
||||
assertNull(savedDocument.getErrorMessage());
|
||||
assertEquals("批量备注", savedDocument.getRemark());
|
||||
assertEquals(1, responses.size());
|
||||
assertEquals(RagParseStatusEnum.UPLOADED.name(), responses.getFirst().getParseStatus());
|
||||
assertEquals(RagIndexStatusEnum.PENDING.name(), responses.getFirst().getIndexStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveOrUpdateShouldWriteAllEditableFields() {
|
||||
RagDocument existingDocument = new RagDocument();
|
||||
existingDocument.setId(3003L);
|
||||
|
||||
RagDocumentSaveRequest request = new RagDocumentSaveRequest();
|
||||
request.setId(3003L);
|
||||
request.setStoreId(1001L);
|
||||
request.setAttachmentId(2002L);
|
||||
request.setDocumentTitle(" 新标题 ");
|
||||
request.setDocumentSummary(" 新摘要 ");
|
||||
request.setParseStatus(RagParseStatusEnum.PARSED.name());
|
||||
request.setIndexStatus(RagIndexStatusEnum.INDEXED.name());
|
||||
request.setEnabled(false);
|
||||
request.setErrorMessage(" 已修复 ");
|
||||
request.setRemark(" 备注信息 ");
|
||||
|
||||
doReturn(existingDocument).when(ragDocumentService).getById(3003L);
|
||||
doReturn(true).when(ragDocumentService).saveOrUpdate(any(RagDocument.class));
|
||||
|
||||
boolean result = ragDocumentService.saveOrUpdate(request);
|
||||
|
||||
assertTrue(result);
|
||||
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
|
||||
verify(ragDocumentService).saveOrUpdate(documentCaptor.capture());
|
||||
RagDocument savedDocument = documentCaptor.getValue();
|
||||
assertEquals(3003L, savedDocument.getId());
|
||||
assertEquals(1001L, savedDocument.getStoreId());
|
||||
assertEquals(2002L, savedDocument.getAttachmentId());
|
||||
assertEquals("新标题", savedDocument.getDocumentTitle());
|
||||
assertEquals("新摘要", savedDocument.getDocumentSummary());
|
||||
assertEquals(RagParseStatusEnum.PARSED.name(), savedDocument.getParseStatus());
|
||||
assertEquals(RagIndexStatusEnum.INDEXED.name(), savedDocument.getIndexStatus());
|
||||
assertEquals(false, savedDocument.getEnabled());
|
||||
assertEquals("已修复", savedDocument.getErrorMessage());
|
||||
assertEquals("备注信息", savedDocument.getRemark());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user