feat(rag-document): 补全文档管理接口与页面

This commit is contained in:
zhiye.sun
2026-05-21 15:34:12 +08:00
parent 67cfbeb572
commit 541c3ff455
12 changed files with 1233 additions and 15 deletions

View File

@@ -6,6 +6,11 @@ public final class RagSystemConstants {
public static final String RAG_DOCUMENT = "RAG_DOCUMENT";
/**
* 用于 sys_attachment.sourceType 标识该附件归属于 RAG 知识库业务。
*/
public static final String SOURCE_TYPE_RAG = "RAG";
private RagSystemConstants() {
}
}

View File

@@ -1,21 +1,27 @@
package com.bruce.rag.controller;
import com.bruce.common.domain.model.RequestResult;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.service.IRagDocumentService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@Tag(name = "RAG知识库文档管理")
@Slf4j
@RestController
@RequestMapping("/api/rag/documents")
public class RagDocumentController {
@@ -24,14 +30,59 @@ public class RagDocumentController {
private IRagDocumentService ragDocumentService;
@Operation(summary = "查询全部知识库文档")
@GetMapping
@PostMapping("/list")
public RequestResult<List<RagDocumentResponse>> list() {
return RequestResult.success(ragDocumentService.listResponses());
log.info("RagDocumentController.list start");
List<RagDocumentResponse> responses = ragDocumentService.listResponses();
log.info("RagDocumentController.list success, count={}", responses.size());
return RequestResult.success(responses);
}
@Operation(summary = "按条件查询知识库文档")
@PostMapping("/query")
public RequestResult<List<RagDocumentResponse>> query(@RequestBody RagDocumentQueryRequest request) {
return RequestResult.success(ragDocumentService.query(request));
public RequestResult<List<RagDocumentResponse>> query(@RequestBody(required = false) RagDocumentQueryRequest request) {
log.info("RagDocumentController.query start, request={}", request);
List<RagDocumentResponse> responses = ragDocumentService.query(request);
log.info("RagDocumentController.query success, count={}", responses.size());
return RequestResult.success(responses);
}
@Operation(summary = "查询知识库文档详情")
@GetMapping("/detail")
public RequestResult<RagDocumentResponse> getById(@RequestParam("id") Long id) {
log.info("RagDocumentController.getById start, id={}", id);
RagDocumentResponse response = ragDocumentService.getResponseById(id);
log.info("RagDocumentController.getById success, id={}, found={}", id, response != null);
return RequestResult.success(response);
}
@Operation(summary = "新增或修改知识库文档")
@PostMapping("/save")
public RequestResult<Boolean> saveOrUpdate(@RequestBody RagDocumentSaveRequest request) {
log.info("RagDocumentController.saveOrUpdate start, request={}", request);
Boolean result = ragDocumentService.saveOrUpdate(request);
log.info("RagDocumentController.saveOrUpdate success, id={}, result={}",
request.getId(), result);
return RequestResult.success(result);
}
@Operation(summary = "删除知识库文档")
@PostMapping("/delete")
public RequestResult<Boolean> deleteById(@RequestParam("id") Long id) {
log.info("RagDocumentController.deleteById start, id={}", id);
Boolean result = ragDocumentService.removeById(id);
log.info("RagDocumentController.deleteById success, id={}, result={}", id, result);
return RequestResult.success(result);
}
@Operation(summary = "批量上传文档到知识库")
@PostMapping("/batchUpload")
public RequestResult<List<RagDocumentResponse>> batchUpload(@ModelAttribute RagDocumentBatchUploadRequest request) {
log.info("RagDocumentController.batchUpload start, storeId={}, fileCount={}",
request.getStoreId(), request.getFiles() != null ? request.getFiles().length : 0);
List<RagDocumentResponse> responses = ragDocumentService.batchUpload(request);
log.info("RagDocumentController.batchUpload success, storeId={}, uploaded={}",
request.getStoreId(), responses.size());
return RequestResult.success(responses);
}
}

View File

@@ -0,0 +1,25 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
@Data
@Schema(description = "RAG知识库文档批量上传请求")
public class RagDocumentBatchUploadRequest {
@Schema(description = "知识库ID")
private Long storeId;
@Schema(description = "附件来源类型RAG 文档上传固定传 RAG")
private String sourceType;
@Schema(description = "上传文件列表")
private MultipartFile[] files;
@Schema(description = "文档摘要(批量设置)")
private String documentSummary;
@Schema(description = "备注(批量设置)")
private String remark;
}

View File

@@ -0,0 +1,39 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
@Schema(description = "RAG知识库文档保存请求")
public class RagDocumentSaveRequest {
@Schema(description = "主键ID更新时必填")
private Long id;
@Schema(description = "知识库ID")
private Long storeId;
@Schema(description = "附件ID")
private Long attachmentId;
@Schema(description = "文档标题")
private String documentTitle;
@Schema(description = "文档摘要")
private String documentSummary;
@Schema(description = "解析状态")
private String parseStatus;
@Schema(description = "索引状态")
private String indexStatus;
@Schema(description = "是否启用")
private Boolean enabled;
@Schema(description = "失败原因")
private String errorMessage;
@Schema(description = "备注")
private String remark;
}

View File

@@ -1,21 +1,30 @@
package com.bruce.rag.dto.response;
import com.bruce.common.constant.CommonConsts;
import com.bruce.rag.entity.RagDocument;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.beans.BeanUtils;
import java.util.Date;
@Data
@Schema(description = "RAG知识库文档响应")
public class RagDocumentResponse {
@Schema(description = "主键ID")
@JsonSerialize(using = ToStringSerializer.class)
private Long id;
@Schema(description = "知识库ID")
@JsonSerialize(using = ToStringSerializer.class)
private Long storeId;
@Schema(description = "附件ID")
@JsonSerialize(using = ToStringSerializer.class)
private Long attachmentId;
@Schema(description = "文档标题")
@@ -39,6 +48,14 @@ public class RagDocumentResponse {
@Schema(description = "备注")
private String remark;
@Schema(description = "创建时间")
@JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
private Date createTime;
@Schema(description = "更新时间")
@JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
private Date updateTime;
public static RagDocumentResponse fromEntity(RagDocument entity) {
if (entity == null) {
return null;

View File

@@ -1,7 +1,9 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.entity.RagDocument;
@@ -12,4 +14,12 @@ public interface IRagDocumentService extends IService<RagDocument> {
List<RagDocumentResponse> listResponses();
List<RagDocumentResponse> query(RagDocumentQueryRequest request);
RagDocumentResponse getResponseById(Long id);
boolean saveOrUpdate(RagDocumentSaveRequest request);
boolean removeById(Long id);
List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request);
}

View File

@@ -1,36 +1,190 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.mapper.RagDocumentMapper;
import com.bruce.rag.service.IRagDocumentService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@Service
public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDocument> implements IRagDocumentService {
@Autowired
private ISysAttachmentService sysAttachmentService;
@Override
public List<RagDocumentResponse> listResponses() {
return toResponses(list());
log.info("RagDocumentServiceImpl.listResponses start");
List<RagDocumentResponse> responses = toResponses(list());
log.info("RagDocumentServiceImpl.listResponses success, count={}", responses.size());
return responses;
}
@Override
public List<RagDocumentResponse> query(RagDocumentQueryRequest request) {
if (request == null) {
throw new IllegalArgumentException("查询请求不能为空");
}
return toResponses(lambdaQuery()
.eq(request.getStoreId() != null, RagDocument::getStoreId, request.getStoreId())
.eq(request.getAttachmentId() != null, RagDocument::getAttachmentId, request.getAttachmentId())
.eq(request.getParseStatus() != null, RagDocument::getParseStatus, request.getParseStatus())
.eq(request.getIndexStatus() != null, RagDocument::getIndexStatus, request.getIndexStatus())
.eq(request.getEnabled() != null, RagDocument::getEnabled, request.getEnabled())
log.info("RagDocumentServiceImpl.query start, request={}", request);
RagDocumentQueryRequest queryRequest = request == null ? new RagDocumentQueryRequest() : request;
String parseStatus = trimToNull(queryRequest.getParseStatus());
String indexStatus = trimToNull(queryRequest.getIndexStatus());
List<RagDocumentResponse> responses = toResponses(lambdaQuery()
.eq(queryRequest.getStoreId() != null, RagDocument::getStoreId, queryRequest.getStoreId())
.eq(queryRequest.getAttachmentId() != null, RagDocument::getAttachmentId, queryRequest.getAttachmentId())
.eq(parseStatus != null, RagDocument::getParseStatus, parseStatus)
.eq(indexStatus != null, RagDocument::getIndexStatus, indexStatus)
.eq(queryRequest.getEnabled() != null, RagDocument::getEnabled, queryRequest.getEnabled())
.orderByDesc(RagDocument::getId)
.list());
log.info("RagDocumentServiceImpl.query success, count={}", responses.size());
return responses;
}
@Override
public RagDocumentResponse getResponseById(Long id) {
log.info("RagDocumentServiceImpl.getResponseById start, id={}", id);
RagDocumentResponse response = RagDocumentResponse.fromEntity(getById(id));
log.info("RagDocumentServiceImpl.getResponseById success, id={}, found={}", id, response != null);
return response;
}
@Override
public boolean saveOrUpdate(RagDocumentSaveRequest request) {
log.info("RagDocumentServiceImpl.saveOrUpdate start, request={}", request);
validateSaveRequest(request);
RagDocument document;
if (request.getId() != null) {
document = getById(request.getId());
if (document == null) {
log.warn("RagDocumentServiceImpl.saveOrUpdate document not found, id={}", request.getId());
throw new IllegalArgumentException("文档不存在ID: " + request.getId());
}
} else {
document = new RagDocument();
document.setEnabled(true);
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
}
document.setStoreId(request.getStoreId());
document.setAttachmentId(request.getAttachmentId());
document.setDocumentTitle(request.getDocumentTitle().trim());
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
if (StringUtils.hasText(request.getParseStatus())) {
document.setParseStatus(request.getParseStatus().trim());
}
if (StringUtils.hasText(request.getIndexStatus())) {
document.setIndexStatus(request.getIndexStatus().trim());
}
if (request.getEnabled() != null) {
document.setEnabled(request.getEnabled());
}
document.setErrorMessage(trimToNull(request.getErrorMessage()));
document.setRemark(trimToNull(request.getRemark()));
boolean result = saveOrUpdate(document);
log.info("RagDocumentServiceImpl.saveOrUpdate success, requestId={}, savedId={}, result={}",
request.getId(), document.getId(), result);
return result;
}
@Override
public boolean removeById(Long id) {
log.info("RagDocumentServiceImpl.removeById start, id={}", id);
if (id == null) {
throw new IllegalArgumentException("文档ID不能为空");
}
RagDocument document = getById(id);
if (document == null) {
log.warn("RagDocumentServiceImpl.removeById document not found, id={}", id);
throw new IllegalArgumentException("文档不存在ID: " + id);
}
boolean result = super.removeById(id);
log.info("RagDocumentServiceImpl.removeById success, id={}, result={}", id, result);
return result;
}
@Override
public List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request) {
log.info("RagDocumentServiceImpl.batchUpload start, request={}", request);
validateBatchUploadRequest(request);
List<RagDocumentResponse> results = new ArrayList<>();
for (var file : request.getFiles()) {
if (file == null || file.isEmpty()) {
continue;
}
SysAttachmentUploadRequest uploadRequest = new SysAttachmentUploadRequest();
uploadRequest.setFile(file);
uploadRequest.setSourceType(resolveSourceType(request.getSourceType()));
uploadRequest.setSourceId(request.getStoreId());
SysAttachment attachment = sysAttachmentService.upload(uploadRequest);
RagDocument document = new RagDocument();
document.setStoreId(request.getStoreId());
document.setAttachmentId(attachment.getId());
document.setDocumentTitle(StringUtils.hasText(file.getOriginalFilename()) ? file.getOriginalFilename().trim() : null);
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
document.setEnabled(true);
document.setErrorMessage(null);
document.setRemark(trimToNull(request.getRemark()));
save(document);
results.add(RagDocumentResponse.fromEntity(document));
}
log.info("RagDocumentServiceImpl.batchUpload success, storeId={}, uploaded={}",
request.getStoreId(), results.size());
return results;
}
void validateSaveRequest(RagDocumentSaveRequest request) {
if (request == null) {
throw new IllegalArgumentException("保存请求不能为空");
}
if (request.getStoreId() == null) {
throw new IllegalArgumentException("知识库ID不能为空");
}
if (!StringUtils.hasText(request.getDocumentTitle())) {
throw new IllegalArgumentException("文档标题不能为空");
}
}
void validateBatchUploadRequest(RagDocumentBatchUploadRequest request) {
if (request == null) {
throw new IllegalArgumentException("上传请求不能为空");
}
if (request.getStoreId() == null) {
throw new IllegalArgumentException("知识库ID不能为空");
}
if (StringUtils.hasText(request.getSourceType())
&& !RagSystemConstants.SOURCE_TYPE_RAG.equals(request.getSourceType().trim())) {
throw new IllegalArgumentException("sourceType必须为RAG");
}
if (request.getFiles() == null || request.getFiles().length == 0) {
throw new IllegalArgumentException("上传文件不能为空");
}
}
private List<RagDocumentResponse> toResponses(List<RagDocument> documents) {
@@ -38,4 +192,18 @@ public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDo
.map(RagDocumentResponse::fromEntity)
.toList();
}
private String resolveSourceType(String sourceType) {
if (!StringUtils.hasText(sourceType)) {
return RagSystemConstants.SOURCE_TYPE_RAG;
}
return RagSystemConstants.SOURCE_TYPE_RAG;
}
private String trimToNull(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
}

View File

@@ -0,0 +1,124 @@
package com.bruce.rag;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockMultipartFile;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class RagDocumentServiceImplTests {
@Spy
@InjectMocks
private RagDocumentServiceImpl ragDocumentService;
@Mock
private ISysAttachmentService sysAttachmentService;
@Test
void batchUploadShouldUseRagSourceTypeAndStoreIdAsSourceId() {
MockMultipartFile file = new MockMultipartFile(
"files",
"knowledge.txt",
"text/plain",
"hello rag".getBytes()
);
RagDocumentBatchUploadRequest request = new RagDocumentBatchUploadRequest();
request.setStoreId(1001L);
request.setSourceType(RagSystemConstants.SOURCE_TYPE_RAG);
request.setFiles(new MockMultipartFile[]{file});
request.setDocumentSummary("批量摘要");
request.setRemark("批量备注");
SysAttachment attachment = new SysAttachment();
attachment.setId(2002L);
when(sysAttachmentService.upload(any(SysAttachmentUploadRequest.class))).thenReturn(attachment);
doAnswer(invocation -> true).when(ragDocumentService).save(any(RagDocument.class));
var responses = ragDocumentService.batchUpload(request);
ArgumentCaptor<SysAttachmentUploadRequest> uploadCaptor = ArgumentCaptor.forClass(SysAttachmentUploadRequest.class);
verify(sysAttachmentService).upload(uploadCaptor.capture());
SysAttachmentUploadRequest uploadRequest = uploadCaptor.getValue();
assertEquals(RagSystemConstants.SOURCE_TYPE_RAG, uploadRequest.getSourceType());
assertEquals(1001L, uploadRequest.getSourceId());
assertEquals(file, uploadRequest.getFile());
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
verify(ragDocumentService).save(documentCaptor.capture());
RagDocument savedDocument = documentCaptor.getValue();
assertEquals(1001L, savedDocument.getStoreId());
assertEquals(2002L, savedDocument.getAttachmentId());
assertEquals("knowledge.txt", savedDocument.getDocumentTitle());
assertEquals("批量摘要", savedDocument.getDocumentSummary());
assertEquals(RagParseStatusEnum.UPLOADED.name(), savedDocument.getParseStatus());
assertEquals(RagIndexStatusEnum.PENDING.name(), savedDocument.getIndexStatus());
assertTrue(savedDocument.getEnabled());
assertNull(savedDocument.getErrorMessage());
assertEquals("批量备注", savedDocument.getRemark());
assertEquals(1, responses.size());
assertEquals(RagParseStatusEnum.UPLOADED.name(), responses.getFirst().getParseStatus());
assertEquals(RagIndexStatusEnum.PENDING.name(), responses.getFirst().getIndexStatus());
}
@Test
void saveOrUpdateShouldWriteAllEditableFields() {
RagDocument existingDocument = new RagDocument();
existingDocument.setId(3003L);
RagDocumentSaveRequest request = new RagDocumentSaveRequest();
request.setId(3003L);
request.setStoreId(1001L);
request.setAttachmentId(2002L);
request.setDocumentTitle(" 新标题 ");
request.setDocumentSummary(" 新摘要 ");
request.setParseStatus(RagParseStatusEnum.PARSED.name());
request.setIndexStatus(RagIndexStatusEnum.INDEXED.name());
request.setEnabled(false);
request.setErrorMessage(" 已修复 ");
request.setRemark(" 备注信息 ");
doReturn(existingDocument).when(ragDocumentService).getById(3003L);
doReturn(true).when(ragDocumentService).saveOrUpdate(any(RagDocument.class));
boolean result = ragDocumentService.saveOrUpdate(request);
assertTrue(result);
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
verify(ragDocumentService).saveOrUpdate(documentCaptor.capture());
RagDocument savedDocument = documentCaptor.getValue();
assertEquals(3003L, savedDocument.getId());
assertEquals(1001L, savedDocument.getStoreId());
assertEquals(2002L, savedDocument.getAttachmentId());
assertEquals("新标题", savedDocument.getDocumentTitle());
assertEquals("新摘要", savedDocument.getDocumentSummary());
assertEquals(RagParseStatusEnum.PARSED.name(), savedDocument.getParseStatus());
assertEquals(RagIndexStatusEnum.INDEXED.name(), savedDocument.getIndexStatus());
assertEquals(false, savedDocument.getEnabled());
assertEquals("已修复", savedDocument.getErrorMessage());
assertEquals("备注信息", savedDocument.getRemark());
}
}