refactor(modules): 拆分多模块工程并收口common基础模块

This commit is contained in:
2026-06-01 03:26:18 +08:00
parent 6fe1209801
commit 07ad8bb36b
231 changed files with 1690 additions and 172 deletions

55
common-agent-rag/pom.xml Normal file
View File

@@ -0,0 +1,55 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Maven descriptor for the common-agent-rag module (RAG knowledge-base code). -->
<modelVersion>4.0.0</modelVersion>
<!-- Inherits from common-agent-parent. Dependencies below that declare no version
     presumably rely on dependency management inherited from it (TODO confirm). -->
<parent>
<groupId>com.bruce</groupId>
<artifactId>common-agent-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
</parent>
<artifactId>common-agent-rag</artifactId>
<name>common-agent-rag</name>
<dependencies>
<!-- Sibling modules of the same build, pinned to this project's own version. -->
<dependency>
<groupId>com.bruce</groupId>
<artifactId>common-agent-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.bruce</groupId>
<artifactId>common-agent-modelprovider</artifactId>
<version>${project.version}</version>
</dependency>
<!-- Spring Boot core and web starters. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- MyBatis-Plus starter used by the mapper interfaces and entity annotations. -->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-spring-boot4-starter</artifactId>
</dependency>
<!-- Lombok is optional, so it is not passed on to modules depending on this one. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- PostgreSQL JDBC driver, needed at runtime only. -->
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
<!-- Test-scope Spring Boot testing starter. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,16 @@
package com.bruce.rag.constant;
/**
 * String constants shared by the RAG module.
 *
 * <p>The class is final with a private constructor, so it only serves as a holder for the
 * static values below.
 */
public final class RagSystemConstants {

    /**
     * Value written to {@code sys_attachment.sourceType} to mark an attachment as belonging to
     * the RAG knowledge-base business.
     */
    public static final String SOURCE_TYPE_RAG = "RAG";

    /** Constant with the literal value {@code RAG_STORE}. */
    public static final String RAG_STORE = "RAG_STORE";

    /** Constant with the literal value {@code RAG_DOCUMENT}. */
    public static final String RAG_DOCUMENT = "RAG_DOCUMENT";

    private RagSystemConstants() {
    }
}

View File

@@ -0,0 +1,126 @@
package com.bruce.rag.controller;
import com.bruce.common.domain.model.RequestResult;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentChunkRequest;
import com.bruce.rag.dto.request.RagDocumentParseRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentParseResponse;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.service.IRagDocumentParseService;
import com.bruce.rag.service.IRagDocumentChunkService;
import com.bruce.rag.service.IRagDocumentService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
 * REST endpoints under {@code /api/rag/documents} for knowledge-base documents: listing,
 * filtered query, detail lookup, save, delete, batch upload, parsing and chunking.
 *
 * <p>Each handler writes a start and a completion log line and wraps the service result with
 * {@code RequestResult.success(...)}.
 */
@Slf4j
@RestController
@RequestMapping("/api/rag/documents")
@Tag(name = "RAG知识库文档管理")
public class RagDocumentController {

    @Autowired
    private IRagDocumentService ragDocumentService;

    @Autowired
    private IRagDocumentParseService ragDocumentParseService;

    @Autowired
    private IRagDocumentChunkService ragDocumentChunkService;

    /**
     * Returns every knowledge-base document known to the document service.
     *
     * @return success result carrying the document list
     */
    @Operation(summary = "查询全部知识库文档")
    @PostMapping("/list")
    public RequestResult<List<RagDocumentResponse>> list() {
        log.info("RagDocumentController.list start");
        List<RagDocumentResponse> documents = ragDocumentService.listResponses();
        int total = documents.size();
        log.info("RagDocumentController.list success, count={}", total);
        return RequestResult.success(documents);
    }

    /**
     * Queries documents by the given filter.
     *
     * @param request filter criteria; the body is optional, so this may be {@code null} and is
     *                handed to the service as-is
     * @return success result carrying the matching documents
     */
    @Operation(summary = "按条件查询知识库文档")
    @PostMapping("/query")
    public RequestResult<List<RagDocumentResponse>> query(@RequestBody(required = false) RagDocumentQueryRequest request) {
        log.info("RagDocumentController.query start, request={}", request);
        List<RagDocumentResponse> matches = ragDocumentService.query(request);
        int total = matches.size();
        log.info("RagDocumentController.query success, count={}", total);
        return RequestResult.success(matches);
    }

    /**
     * Looks up one document by ID.
     *
     * @param id document primary key
     * @return success result carrying the document; the payload is whatever the service
     *         returned, including {@code null}
     */
    @Operation(summary = "查询知识库文档详情")
    @GetMapping("/detail")
    public RequestResult<RagDocumentResponse> getById(@RequestParam("id") Long id) {
        log.info("RagDocumentController.getById start, id={}", id);
        RagDocumentResponse document = ragDocumentService.getResponseById(id);
        boolean found = document != null;
        log.info("RagDocumentController.getById success, id={}, found={}", id, found);
        return RequestResult.success(document);
    }

    /**
     * Creates or updates a document record.
     *
     * @param request document data to persist
     * @return success result carrying the service's outcome flag
     */
    @Operation(summary = "新增或修改知识库文档")
    @PostMapping("/save")
    public RequestResult<Boolean> saveOrUpdate(@RequestBody RagDocumentSaveRequest request) {
        log.info("RagDocumentController.saveOrUpdate start, request={}", request);
        Boolean saved = ragDocumentService.saveOrUpdate(request);
        // The ID is read after the service call, exactly as before.
        log.info("RagDocumentController.saveOrUpdate success, id={}, result={}",
                request.getId(), saved);
        return RequestResult.success(saved);
    }

    /**
     * Deletes a document by ID.
     *
     * @param id document primary key
     * @return success result carrying the service's outcome flag
     */
    @Operation(summary = "删除知识库文档")
    @PostMapping("/delete")
    public RequestResult<Boolean> deleteById(@RequestParam("id") Long id) {
        log.info("RagDocumentController.deleteById start, id={}", id);
        Boolean removed = ragDocumentService.removeById(id);
        log.info("RagDocumentController.deleteById success, id={}, result={}", id, removed);
        return RequestResult.success(removed);
    }

    /**
     * Uploads several files into a knowledge store in one request.
     *
     * @param request form-bound upload request (store ID, files and batch-wide fields)
     * @return success result carrying the documents created by the service
     */
    @Operation(summary = "批量上传文档到知识库")
    @PostMapping("/batchUpload")
    public RequestResult<List<RagDocumentResponse>> batchUpload(@ModelAttribute RagDocumentBatchUploadRequest request) {
        // A missing file array is logged as zero files rather than failing here.
        int fileCount = 0;
        if (request.getFiles() != null) {
            fileCount = request.getFiles().length;
        }
        log.info("RagDocumentController.batchUpload start, storeId={}, fileCount={}",
                request.getStoreId(), fileCount);
        List<RagDocumentResponse> uploaded = ragDocumentService.batchUpload(request);
        log.info("RagDocumentController.batchUpload success, storeId={}, uploaded={}",
                request.getStoreId(), uploaded.size());
        return RequestResult.success(uploaded);
    }

    /**
     * Parses the documents named in the request.
     *
     * @param request IDs of the documents to parse
     * @return success result carrying one parse response per entry returned by the service
     */
    @Operation(summary = "解析知识库文档")
    @PostMapping("/parse")
    public RequestResult<List<RagDocumentParseResponse>> parse(@RequestBody RagDocumentParseRequest request) {
        log.info("RagDocumentController.parse start, request={}", request);
        List<RagDocumentParseResponse> parsed = ragDocumentParseService.parse(request);
        int total = parsed.size();
        log.info("RagDocumentController.parse success, count={}", total);
        return RequestResult.success(parsed);
    }

    /**
     * Re-runs parsing for the documents named in the request.
     *
     * @param request IDs of the documents to parse again
     * @return success result carrying the parse responses returned by the service
     */
    @Operation(summary = "重试解析知识库文档")
    @PostMapping("/retryParse")
    public RequestResult<List<RagDocumentParseResponse>> retryParse(@RequestBody RagDocumentParseRequest request) {
        log.info("RagDocumentController.retryParse start, request={}", request);
        // NOTE(review): this is the same parse(...) call that /parse makes; confirm that no
        // dedicated retry path is expected.
        List<RagDocumentParseResponse> parsed = ragDocumentParseService.parse(request);
        int total = parsed.size();
        log.info("RagDocumentController.retryParse success, count={}", total);
        return RequestResult.success(parsed);
    }

    /**
     * Submits a chunking task for the requested documents and returns immediately after the
     * submission call.
     *
     * @param request document IDs plus chunking strategy and parameters
     * @return success result that always carries {@code Boolean.TRUE}
     */
    @Operation(summary = "按策略异步切片")
    @PostMapping("/chunk")
    public RequestResult<Boolean> chunk(@RequestBody RagDocumentChunkRequest request) {
        log.info("RagDocumentController.chunk start, request={}", request);
        ragDocumentChunkService.submitChunkTask(request);
        log.info("RagDocumentController.chunk submitted");
        return RequestResult.success(Boolean.TRUE);
    }
}

View File

@@ -0,0 +1,97 @@
package com.bruce.rag.controller;
import com.bruce.common.domain.model.RequestResult;
import com.bruce.rag.dto.request.RagStoreQueryRequest;
import com.bruce.rag.dto.request.RagStoreSaveRequest;
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
import com.bruce.rag.dto.response.RagStoreResponse;
import com.bruce.rag.service.IRagStoreService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
 * REST endpoints under {@code /api/rag/store} for knowledge stores: listing, filtered query,
 * detail lookup, overview statistics, save and delete.
 *
 * <p>Each handler writes a start and a completion log line and wraps the service result with
 * {@code RequestResult.success(...)}.
 */
@Slf4j
@RestController
@RequestMapping("/api/rag/store")
@Tag(name = "RAG知识库管理")
public class RagStoreController {

    @Autowired
    private IRagStoreService ragStoreService;

    /**
     * Returns every knowledge store known to the store service.
     *
     * @return success result carrying the store list
     */
    @Operation(summary = "查询全部知识库")
    @PostMapping("/list")
    public RequestResult<List<RagStoreResponse>> list() {
        log.info("RagStoreController.list start");
        List<RagStoreResponse> stores = ragStoreService.listResponses();
        int total = stores.size();
        log.info("RagStoreController.list success, count={}", total);
        return RequestResult.success(stores);
    }

    /**
     * Queries knowledge stores by the given filter.
     *
     * @param request filter criteria; the body is optional, so this may be {@code null} and is
     *                handed to the service as-is
     * @return success result carrying the matching stores
     */
    @Operation(summary = "按条件查询知识库")
    @PostMapping("/query")
    public RequestResult<List<RagStoreResponse>> query(@RequestBody(required = false) RagStoreQueryRequest request) {
        log.info("RagStoreController.query start, request={}", request);
        List<RagStoreResponse> matches = ragStoreService.query(request);
        int total = matches.size();
        log.info("RagStoreController.query success, count={}", total);
        return RequestResult.success(matches);
    }

    /**
     * Looks up one knowledge store by ID.
     *
     * @param id store primary key
     * @return success result carrying the store; the payload is whatever the service returned,
     *         including {@code null}
     */
    @Operation(summary = "查询知识库详情")
    @GetMapping("/detail")
    public RequestResult<RagStoreResponse> getById(@RequestParam("id") Long id) {
        log.info("RagStoreController.getById start, id={}", id);
        RagStoreResponse store = ragStoreService.getResponseById(id);
        boolean found = store != null;
        log.info("RagStoreController.getById success, id={}, found={}", id, found);
        return RequestResult.success(store);
    }

    /**
     * Returns the overall statistics across knowledge stores.
     *
     * @return success result carrying the overview
     */
    @Operation(summary = "查询知识库总览")
    @GetMapping("/overview")
    public RequestResult<RagStoreOverviewResponse> overview() {
        log.info("RagStoreController.overview start");
        RagStoreOverviewResponse summary = ragStoreService.getOverview();
        // NOTE(review): the overview is dereferenced for logging, so a null result from the
        // service would throw here; confirm the service never returns null.
        log.info("RagStoreController.overview success, totalStores={}, totalDocuments={}",
                summary.getTotalStores(), summary.getTotalDocuments());
        return RequestResult.success(summary);
    }

    /**
     * Returns the document statistics of one knowledge store.
     *
     * @param storeId store primary key
     * @return success result carrying the per-store document overview
     */
    @Operation(summary = "查询知识库文档概览")
    @GetMapping("/documentOverview")
    public RequestResult<RagStoreDocumentOverviewResponse> documentOverview(@RequestParam("storeId") Long storeId) {
        log.info("RagStoreController.documentOverview start, storeId={}", storeId);
        RagStoreDocumentOverviewResponse summary = ragStoreService.getDocumentOverview(storeId);
        // NOTE(review): the overview is dereferenced for logging, so a null result (e.g. for an
        // unknown store) would throw here; confirm the service contract.
        log.info("RagStoreController.documentOverview success, storeId={}, documentCount={}",
                storeId, summary.getDocumentCount());
        return RequestResult.success(summary);
    }

    /**
     * Creates or updates a knowledge store.
     *
     * @param request store data to persist
     * @return success result carrying the service's outcome flag
     */
    @Operation(summary = "新增或修改知识库")
    @PostMapping("/save")
    public RequestResult<Boolean> saveOrUpdate(@RequestBody RagStoreSaveRequest request) {
        log.info("RagStoreController.saveOrUpdate start, request={}", request);
        Boolean saved = ragStoreService.saveOrUpdate(request);
        // ID and code are read after the service call, exactly as before.
        log.info("RagStoreController.saveOrUpdate success, id={}, storeCode={}, result={}",
                request.getId(), request.getStoreCode(), saved);
        return RequestResult.success(saved);
    }

    /**
     * Deletes a knowledge store by ID.
     *
     * @param id store primary key
     * @return success result carrying the service's outcome flag
     */
    @Operation(summary = "删除知识库")
    @PostMapping("/delete")
    public RequestResult<Boolean> deleteById(@RequestParam("id") Long id) {
        log.info("RagStoreController.deleteById start, id={}", id);
        Boolean removed = ragStoreService.removeById(id);
        log.info("RagStoreController.deleteById success, id={}, result={}", id, removed);
        return RequestResult.success(removed);
    }
}

View File

@@ -0,0 +1,25 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
/**
 * Request object for uploading several documents into one knowledge store.
 */
@Schema(description = "RAG知识库文档批量上传请求")
@Data
public class RagDocumentBatchUploadRequest {

    /** ID of the knowledge store that receives the documents. */
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Attachment source type; per the schema text, RAG document uploads always pass RAG. */
    @Schema(description = "附件来源类型RAG 文档上传固定传 RAG")
    private String sourceType;

    /** Files to upload. */
    @Schema(description = "上传文件列表")
    private MultipartFile[] files;

    /** Document summary, applied to the whole batch. */
    @Schema(description = "文档摘要（批量设置）")
    private String documentSummary;

    /** Remark, applied to the whole batch. */
    @Schema(description = "备注（批量设置）")
    private String remark;
}

View File

@@ -0,0 +1,26 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
/**
 * Request object describing which documents to chunk and with which settings.
 */
@Schema(description = "RAG知识库文档切片请求")
@Data
public class RagDocumentChunkRequest {

    /** IDs of the documents to chunk. */
    @Schema(description = "文档ID列表")
    private List<Long> documentIds;

    /** Chunking strategy, given as the enum's numeric value. */
    @Schema(description = "切片方式枚举值")
    private Integer chunkStrategy;

    /** Chunk length. */
    @Schema(description = "切片长度")
    private Integer chunkSize;

    /** Overlap length between chunks. */
    @Schema(description = "重叠长度")
    private Integer chunkOverlap;

    /** Delimiter. */
    @Schema(description = "分隔符")
    private String delimiter;
}

View File

@@ -0,0 +1,14 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
/**
 * Request object naming the documents to parse.
 */
@Schema(description = "RAG知识库文档解析请求")
@Data
public class RagDocumentParseRequest {

    /** IDs of the documents to parse. */
    @Schema(description = "文档ID列表")
    private List<Long> documentIds;
}

View File

@@ -0,0 +1,24 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Filter criteria for querying knowledge-base documents.
 */
@Schema(description = "RAG知识库文档查询请求")
@Data
public class RagDocumentQueryRequest {

    /** Knowledge store ID. */
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Attachment ID. */
    @Schema(description = "附件ID")
    private Long attachmentId;

    /** Parse status. */
    @Schema(description = "解析状态")
    private String parseStatus;

    /** Index status. */
    @Schema(description = "索引状态")
    private String indexStatus;

    /** Whether the document is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;
}

View File

@@ -0,0 +1,39 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Request object for creating or updating a knowledge-base document.
 */
@Schema(description = "RAG知识库文档保存请求")
@Data
public class RagDocumentSaveRequest {

    /** Primary key; per the schema text it is required when updating. */
    @Schema(description = "主键ID更新时必填")
    private Long id;

    /** Knowledge store ID. */
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Attachment ID. */
    @Schema(description = "附件ID")
    private Long attachmentId;

    /** Document title. */
    @Schema(description = "文档标题")
    private String documentTitle;

    /** Document summary. */
    @Schema(description = "文档摘要")
    private String documentSummary;

    /** Parse status. */
    @Schema(description = "解析状态")
    private String parseStatus;

    /** Index status. */
    @Schema(description = "索引状态")
    private String indexStatus;

    /** Whether the document is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Failure reason. */
    @Schema(description = "失败原因")
    private String errorMessage;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,18 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Filter criteria for querying knowledge stores.
 */
@Schema(description = "RAG知识库查询请求")
@Data
public class RagStoreQueryRequest {

    /** Store code. */
    @Schema(description = "知识库编码")
    private String storeCode;

    /** Store name. */
    @Schema(description = "知识库名称")
    private String storeName;

    /** Status. */
    @Schema(description = "状态")
    private String status;
}

View File

@@ -0,0 +1,27 @@
package com.bruce.rag.dto.request;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Request object for creating or updating a knowledge store.
 */
@Schema(description = "RAG知识库保存请求")
@Data
public class RagStoreSaveRequest {

    /** Primary key. */
    @Schema(description = "主键ID")
    private Long id;

    /** Store code. */
    @Schema(description = "知识库编码")
    private String storeCode;

    /** Store name. */
    @Schema(description = "知识库名称")
    private String storeName;

    /** Store description. */
    @Schema(description = "知识库描述")
    private String description;

    /** Status. */
    @Schema(description = "状态")
    private String status;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,11 @@
package com.bruce.rag.dto.response;
import lombok.Data;
/**
 * One recalled chunk together with its score.
 */
@Data
public class RagChunkRecallResponse {

    /** ID of the recalled chunk. */
    private Long chunkId;

    /** ID of the document the chunk belongs to. */
    private Long documentId;

    /** Text content of the chunk. */
    private String chunkContent;

    /** Score assigned to this chunk by the recall query. */
    private Double score;
}

View File

@@ -0,0 +1,33 @@
package com.bruce.rag.dto.response;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.LinkedHashMap;
import java.util.Map;
/**
 * Result of parsing one knowledge-base document.
 */
@Schema(description = "RAG知识库文档解析响应")
@Data
public class RagDocumentParseResponse {

    /** Document ID; written to JSON through {@code ToStringSerializer}. */
    @JsonSerialize(using = ToStringSerializer.class)
    @Schema(description = "文档ID")
    private Long documentId;

    /** Parse status. */
    @Schema(description = "解析状态")
    private String parseStatus;

    /** Length of the parsed text. */
    @Schema(description = "文本长度")
    private Integer textLength;

    /** Number of pages. */
    @Schema(description = "页数")
    private Integer pageCount;

    /** Number of worksheets. */
    @Schema(description = "工作表数量")
    private Integer sheetCount;

    /** Parse metadata; starts out as an empty insertion-ordered map. */
    @Schema(description = "解析元数据")
    private Map<String, Object> metadata = new LinkedHashMap<>();
}

View File

@@ -0,0 +1,67 @@
package com.bruce.rag.dto.response;
import com.bruce.common.constant.CommonConsts;
import com.bruce.rag.entity.RagDocument;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.beans.BeanUtils;
import java.util.Date;
/**
 * API view of a knowledge-base document.
 *
 * <p>ID fields are written to JSON through {@code ToStringSerializer}; timestamps use the
 * pattern and time zone named by the {@code CommonConsts} constants on each field.
 */
@Schema(description = "RAG知识库文档响应")
@Data
public class RagDocumentResponse {

    /** Primary key. */
    @JsonSerialize(using = ToStringSerializer.class)
    @Schema(description = "主键ID")
    private Long id;

    /** Knowledge store ID. */
    @JsonSerialize(using = ToStringSerializer.class)
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Attachment ID. */
    @JsonSerialize(using = ToStringSerializer.class)
    @Schema(description = "附件ID")
    private Long attachmentId;

    /** Document title. */
    @Schema(description = "文档标题")
    private String documentTitle;

    /** Document summary. */
    @Schema(description = "文档摘要")
    private String documentSummary;

    /** Parse status. */
    @Schema(description = "解析状态")
    private String parseStatus;

    /** Index status. */
    @Schema(description = "索引状态")
    private String indexStatus;

    /** Whether the document is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Failure reason. */
    @Schema(description = "失败原因")
    private String errorMessage;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;

    /** Creation time. */
    @JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
    @Schema(description = "创建时间")
    private Date createTime;

    /** Last update time. */
    @JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
    @Schema(description = "更新时间")
    private Date updateTime;

    /**
     * Builds a response by copying the entity's properties with {@code BeanUtils}.
     *
     * @param entity source entity, may be {@code null}
     * @return a populated response, or {@code null} when {@code entity} is {@code null}
     */
    public static RagDocumentResponse fromEntity(RagDocument entity) {
        if (entity != null) {
            RagDocumentResponse copy = new RagDocumentResponse();
            BeanUtils.copyProperties(entity, copy);
            return copy;
        }
        return null;
    }
}

View File

@@ -0,0 +1,38 @@
package com.bruce.rag.dto.response;
import com.bruce.common.constant.CommonConsts;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.Date;
/**
 * Document statistics for a single knowledge store.
 */
@Schema(description = "RAG知识库文档概览响应")
@Data
public class RagStoreDocumentOverviewResponse {

    /** Knowledge store ID; written to JSON through {@code ToStringSerializer}. */
    @JsonSerialize(using = ToStringSerializer.class)
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Store name. */
    @Schema(description = "知识库名称")
    private String storeName;

    /** Total number of documents. */
    @Schema(description = "文档总数")
    private Integer documentCount;

    /** Number of enabled documents. */
    @Schema(description = "启用文档数")
    private Integer enabledDocumentCount;

    /** Number of parsed documents. */
    @Schema(description = "已解析文档数")
    private Integer parsedDocumentCount;

    /** Number of indexed documents. */
    @Schema(description = "已索引文档数")
    private Integer indexedDocumentCount;

    /** Time of the most recent upload. */
    @JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
    @Schema(description = "最近上传时间")
    private Date lastUploadTime;
}

View File

@@ -0,0 +1,21 @@
package com.bruce.rag.dto.response;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Overall statistics across knowledge stores.
 */
@Schema(description = "RAG知识库总览响应")
@Data
public class RagStoreOverviewResponse {

    /** Total number of knowledge stores. */
    @Schema(description = "知识库总数")
    private Integer totalStores;

    /** Total number of documents. */
    @Schema(description = "文档总数")
    private Integer totalDocuments;

    /** Total number of chunks. */
    @Schema(description = "切片总数")
    private Integer totalChunks;

    /** Number of knowledge stores that can be searched. */
    @Schema(description = "可检索知识库数")
    private Integer retrievableStores;
}

View File

@@ -0,0 +1,53 @@
package com.bruce.rag.dto.response;
import com.bruce.common.constant.CommonConsts;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import com.bruce.rag.entity.RagStore;
import com.fasterxml.jackson.annotation.JsonFormat;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.springframework.beans.BeanUtils;
import java.util.Date;
/**
 * API view of a knowledge store.
 */
@Schema(description = "RAG知识库响应")
@Data
public class RagStoreResponse {

    /** Primary key; written to JSON through {@code ToStringSerializer}. */
    @Schema(description = "主键ID")
    @JsonSerialize(using = ToStringSerializer.class)
    private Long id;

    /** Store code. */
    @Schema(description = "知识库编码")
    private String storeCode;

    /** Store name. */
    @Schema(description = "知识库名称")
    private String storeName;

    /** Store description. */
    @Schema(description = "知识库描述")
    private String description;

    /** Status. */
    @Schema(description = "状态")
    private String status;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;

    /** Creation time. */
    @JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
    @Schema(description = "创建时间")
    private Date createTime;

    /** Last update time. */
    @JsonFormat(pattern = CommonConsts.DATE_FORMAT_LONG_STR, timezone = CommonConsts.TIME_ZONE_GMT8)
    @Schema(description = "更新时间")
    private Date updateTime;

    /**
     * Builds a response by copying the entity's properties with {@code BeanUtils}.
     *
     * @param entity source entity, may be {@code null}
     * @return a populated response, or {@code null} when {@code entity} is {@code null}
     */
    public static RagStoreResponse fromEntity(RagStore entity) {
        if (entity != null) {
            RagStoreResponse copy = new RagStoreResponse();
            BeanUtils.copyProperties(entity, copy);
            return copy;
        }
        return null;
    }
}

View File

@@ -0,0 +1,68 @@
package com.bruce.rag.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import com.bruce.common.typehandler.PgJsonbStringTypeHandler;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
/**
 * Entity mapped to the {@code rag_chunk} table: one chunk of a knowledge-base document.
 *
 * <p>{@code autoResultMap} is enabled on the table mapping, and {@code metadataJson} declares
 * a type handler.
 */
@Schema(description = "RAG知识切片")
@TableName(value = "rag_chunk", autoResultMap = true)
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagChunk extends BaseEntity {

    /** Knowledge store ID. */
    @TableField("store_id")
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Document ID. */
    @TableField("document_id")
    @Schema(description = "文档ID")
    private Long documentId;

    /** Sequence number of the chunk within its document. */
    @TableField("chunk_index")
    @Schema(description = "文档内切片序号")
    private Integer chunkIndex;

    /** Chunk content. */
    @TableField("chunk_content")
    @Schema(description = "切片内容")
    private String chunkContent;

    /** Chunk summary. */
    @TableField("chunk_summary")
    @Schema(description = "切片摘要")
    private String chunkSummary;

    /** Token count. */
    @TableField("token_count")
    @Schema(description = "Token数量")
    private Integer tokenCount;

    /** Page number. */
    @TableField("page_number")
    @Schema(description = "页码")
    private Integer pageNumber;

    /** Section title. */
    @TableField("section_title")
    @Schema(description = "章节标题")
    private String sectionTitle;

    /** Heading path. */
    @TableField("heading_path")
    @Schema(description = "标题路径")
    private String headingPath;

    /** Vector ID. */
    @TableField("vector_id")
    @Schema(description = "向量ID")
    private String vectorId;

    /** Chunk-level extension metadata as a JSON string, mapped via PgJsonbStringTypeHandler. */
    @TableField(value = "metadata_json", typeHandler = PgJsonbStringTypeHandler.class)
    @Schema(description = "切片级扩展元数据JSON")
    private String metadataJson;

    /** Whether the chunk is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,50 @@
package com.bruce.rag.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
/**
 * Entity mapped to the {@code rag_chunk_embedding} table: the embedding of one chunk.
 */
@Schema(description = "RAG切片向量")
@TableName("rag_chunk_embedding")
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagChunkEmbedding extends BaseEntity {

    /** Knowledge store ID. */
    @TableField("store_id")
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Document ID. */
    @TableField("document_id")
    @Schema(description = "文档ID")
    private Long documentId;

    /** Chunk ID. */
    @TableField("chunk_id")
    @Schema(description = "切片ID")
    private Long chunkId;

    /** Embedding model. */
    @TableField("embedding_model")
    @Schema(description = "向量模型")
    private String embeddingModel;

    /** Embedding dimension. */
    @TableField("embedding_dimension")
    @Schema(description = "向量维度")
    private Integer embeddingDimension;

    /**
     * Embedding content, held as a string with no type handler declared.
     * NOTE(review): confirm how this value is converted to and from the database column type.
     */
    @Schema(description = "向量内容")
    private String embedding;

    /** Hash of the content the embedding was generated from. */
    @TableField("content_hash")
    @Schema(description = "向量生成内容哈希")
    private String contentHash;

    /** Whether the embedding is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,51 @@
package com.bruce.rag.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
/**
 * Entity mapped to the {@code rag_document} table: one document of a knowledge store.
 */
@Schema(description = "RAG知识库文档")
@TableName("rag_document")
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagDocument extends BaseEntity {

    /** Knowledge store ID. */
    @TableField("store_id")
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Attachment ID. */
    @TableField("attachment_id")
    @Schema(description = "附件ID")
    private Long attachmentId;

    /** Document title. */
    @TableField("document_title")
    @Schema(description = "文档标题")
    private String documentTitle;

    /** Document summary. */
    @TableField("document_summary")
    @Schema(description = "文档摘要")
    private String documentSummary;

    /** Parse status. */
    @TableField("parse_status")
    @Schema(description = "解析状态")
    private String parseStatus;

    /** Index status. */
    @TableField("index_status")
    @Schema(description = "索引状态")
    private String indexStatus;

    /** Whether the document is enabled. */
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Failure reason. */
    @TableField("error_message")
    @Schema(description = "失败原因")
    private String errorMessage;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,62 @@
package com.bruce.rag.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import com.bruce.common.typehandler.PgJsonbStringTypeHandler;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
/**
 * Entity mapped to the {@code rag_document_parse_result} table: a snapshot of one document's
 * parse result.
 *
 * <p>{@code autoResultMap} is enabled on the table mapping, and {@code metadataJson} declares
 * a type handler.
 */
@Schema(description = "RAG文档解析结果快照")
@TableName(value = "rag_document_parse_result", autoResultMap = true)
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagDocumentParseResult extends BaseEntity {

    /** Knowledge store ID. */
    @TableField("store_id")
    @Schema(description = "知识库ID")
    private Long storeId;

    /** Document ID. */
    @TableField("document_id")
    @Schema(description = "文档ID")
    private Long documentId;

    /** Parsed text. */
    @TableField("parsed_text")
    @Schema(description = "解析文本")
    private String parsedText;

    /** Length of the parsed text. */
    @TableField("text_length")
    @Schema(description = "文本长度")
    private Integer textLength;

    /** Number of pages. */
    @TableField("page_count")
    @Schema(description = "页数")
    private Integer pageCount;

    /** Number of worksheets. */
    @TableField("sheet_count")
    @Schema(description = "工作表数量")
    private Integer sheetCount;

    /** Parse metadata as a JSON string, mapped via PgJsonbStringTypeHandler. */
    @TableField(value = "metadata_json", typeHandler = PgJsonbStringTypeHandler.class)
    @Schema(description = "解析元数据JSON")
    private String metadataJson;

    /** Hash of the parse result. */
    @TableField("content_hash")
    @Schema(description = "解析结果哈希")
    private String contentHash;

    /** Parse version. */
    @TableField("parse_version")
    @Schema(description = "解析版本")
    private Integer parseVersion;

    /** Whether the snapshot is enabled. */
    @TableField("enabled")
    @Schema(description = "是否启用")
    private Boolean enabled;

    /** Remark. */
    @TableField("remark")
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,31 @@
package com.bruce.rag.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.domain.model.BaseEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
/**
 * Entity mapped to the {@code rag_store} table: one knowledge store.
 *
 * <p>No field declares an explicit column name here.
 */
@Schema(description = "RAG知识库")
@TableName("rag_store")
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class RagStore extends BaseEntity {

    /** Store code. */
    @Schema(description = "知识库编码")
    private String storeCode;

    /** Store name. */
    @Schema(description = "知识库名称")
    private String storeName;

    /** Store description. */
    @Schema(description = "知识库描述")
    private String description;

    /** Status. */
    @Schema(description = "状态")
    private String status;

    /** Remark. */
    @Schema(description = "备注")
    private String remark;
}

View File

@@ -0,0 +1,51 @@
package com.bruce.rag.enums;
import com.bruce.common.enums.PersistableSysEnumDefinition;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
 * Chunking strategies for RAG documents, each with a numeric value and a display label.
 */
@AllArgsConstructor
@Getter
public enum RagChunkStrategyEnum implements PersistableSysEnumDefinition {

    /** Fixed-length chunking. */
    FIXED_LENGTH(1, "固定长度切片"),

    /** Chunking by paragraph. */
    PARAGRAPH(2, "按段落切片"),

    /** Chunking by heading hierarchy. */
    HEADING(3, "按标题层级切片"),

    /** Chunking by table row. */
    TABLE_ROW(4, "按表格行切片"),

    /** Chunking by delimiter. */
    DELIMITER(5, "按分隔符切片"),

    /** Semantic chunking. */
    SEMANTIC(6, "语义切片");

    private static final String CATALOG = "rag";
    private static final String TYPE = "chunk_strategy";
    private static final String REMARK = "RAG文档切片方式";

    /** Numeric value of the strategy. */
    private final Integer value;

    /** Display label of the strategy. */
    private final String label;

    /**
     * Finds the strategy with the given numeric value.
     *
     * @param value numeric value to look up; {@code null} matches no constant
     * @return the first constant, in declaration order, whose value equals {@code value}
     * @throws IllegalArgumentException if no constant has that value
     */
    public static RagChunkStrategyEnum fromValue(Integer value) {
        return Arrays.stream(RagChunkStrategyEnum.values())
                .filter(strategy -> strategy.getValue().equals(value))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("不支持的切片方式: " + value));
    }

    /** Returns the catalog constant shared by all constants of this enum. */
    @Override
    public String getCatalog() {
        return CATALOG;
    }

    /** Returns the type constant shared by all constants of this enum. */
    @Override
    public String getType() {
        return TYPE;
    }

    /** Returns the remark constant shared by all constants of this enum. */
    @Override
    public String getRemark() {
        return REMARK;
    }
}

View File

@@ -0,0 +1,40 @@
package com.bruce.rag.enums;
import com.bruce.common.enums.PersistableSysEnumDefinition;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Index statuses of a RAG document, each with a numeric value and a display label.
 */
@AllArgsConstructor
@Getter
public enum RagIndexStatusEnum implements PersistableSysEnumDefinition {

    /** Waiting to be indexed. */
    PENDING(1, "待索引"),

    /** Indexing in progress. */
    INDEXING(2, "索引中"),

    /** Indexed. */
    INDEXED(3, "已索引"),

    /** Indexing failed. */
    FAILED(4, "索引失败");

    private static final String CATALOG = "rag";
    private static final String TYPE = "index_status";
    private static final String REMARK = "RAG文档索引状态";

    /** Numeric value of the status. */
    private final Integer value;

    /** Display label of the status. */
    private final String label;

    /** Returns the catalog constant shared by all constants of this enum. */
    @Override
    public String getCatalog() {
        return CATALOG;
    }

    /** Returns the type constant shared by all constants of this enum. */
    @Override
    public String getType() {
        return TYPE;
    }

    /** Returns the remark constant shared by all constants of this enum. */
    @Override
    public String getRemark() {
        return REMARK;
    }
}

View File

@@ -0,0 +1,40 @@
package com.bruce.rag.enums;
import com.bruce.common.enums.PersistableSysEnumDefinition;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Parse statuses of a RAG document, each with a numeric value and a display label.
 */
@AllArgsConstructor
@Getter
public enum RagParseStatusEnum implements PersistableSysEnumDefinition {

    /** Uploaded. */
    UPLOADED(1, "已上传"),

    /** Parsing in progress. */
    PARSING(2, "解析中"),

    /** Parsed. */
    PARSED(3, "已解析"),

    /** Parsing failed. */
    FAILED(4, "解析失败");

    private static final String CATALOG = "rag";
    private static final String TYPE = "parse_status";
    private static final String REMARK = "RAG文档解析状态";

    /** Numeric value of the status. */
    private final Integer value;

    /** Display label of the status. */
    private final String label;

    /** Returns the catalog constant shared by all constants of this enum. */
    @Override
    public String getCatalog() {
        return CATALOG;
    }

    /** Returns the type constant shared by all constants of this enum. */
    @Override
    public String getType() {
        return TYPE;
    }

    /** Returns the remark constant shared by all constants of this enum. */
    @Override
    public String getRemark() {
        return REMARK;
    }
}

View File

@@ -0,0 +1,32 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.dto.response.RagChunkRecallResponse;
import com.bruce.rag.entity.RagChunkEmbedding;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import java.util.List;
/**
 * MyBatis mapper for {@link RagChunkEmbedding} rows, plus a top-K recall query that joins
 * embeddings with their chunks.
 */
@Mapper
public interface RagChunkEmbeddingMapper extends BaseMapper<RagChunkEmbedding> {
/**
 * Returns the {@code topK} closest chunks of one store for a query vector.
 *
 * <p>The query joins {@code rag_chunk_embedding} with {@code rag_chunk} on the chunk ID, keeps
 * only rows where both the embedding and the chunk are enabled, orders by the {@code <=>}
 * distance between the stored embedding and the query vector (ascending), and limits the
 * result to {@code topK}. The reported score is {@code 1 - distance}. The column aliases
 * match the property names of {@link RagChunkRecallResponse}.
 *
 * <p>NOTE(review): assumes the pgvector extension, where {@code <=>} is cosine distance
 * (TODO confirm).
 *
 * @param storeId knowledge store whose embeddings are searched
 * @param queryVector query vector as text, cast to {@code vector} inside the SQL; presumably a
 *                    pgvector literal such as {@code [0.1,0.2]} (verify against the caller)
 * @param topK maximum number of rows to return
 * @return recalled chunks in the order produced by the query
 */
@Select("""
SELECT
e.chunk_id AS chunkId,
e.document_id AS documentId,
c.chunk_content AS chunkContent,
1 - (e.embedding <=> CAST(#{queryVector} AS vector)) AS score
FROM rag_chunk_embedding e
INNER JOIN rag_chunk c ON c.id = e.chunk_id
WHERE e.store_id = #{storeId}
AND e.enabled = TRUE
AND c.enabled = TRUE
ORDER BY e.embedding <=> CAST(#{queryVector} AS vector)
LIMIT #{topK}
""")
List<RagChunkRecallResponse> queryTopKByStore(@Param("storeId") Long storeId,
@Param("queryVector") String queryVector,
@Param("topK") int topK);
}

View File

@@ -0,0 +1,9 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.entity.RagChunk;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis mapper for {@link RagChunk}; it declares no methods of its own and only exposes
 * what it inherits from {@code BaseMapper}.
 */
@Mapper
public interface RagChunkMapper extends BaseMapper<RagChunk> {
}

View File

@@ -0,0 +1,9 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.entity.RagDocument;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis mapper for {@link RagDocument}; it declares no methods of its own and only exposes
 * what it inherits from {@code BaseMapper}.
 */
@Mapper
public interface RagDocumentMapper extends BaseMapper<RagDocument> {
}

View File

@@ -0,0 +1,9 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.entity.RagDocumentParseResult;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis mapper for {@link RagDocumentParseResult}; it declares no methods of its own and
 * only exposes what it inherits from {@code BaseMapper}.
 */
@Mapper
public interface RagDocumentParseResultMapper extends BaseMapper<RagDocumentParseResult> {
}

View File

@@ -0,0 +1,9 @@
package com.bruce.rag.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.bruce.rag.entity.RagStore;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis-Plus mapper for {@link RagStore}.
 * <p>
 * Declares no custom statements; only the operations inherited from {@link BaseMapper} are available.
 */
@Mapper
public interface RagStoreMapper extends BaseMapper<RagStore> {
}

View File

@@ -0,0 +1,26 @@
package com.bruce.rag.parse;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import java.util.List;
/**
* Common contract for chunking strategies.
* <p>
* An implementation has exactly two responsibilities:
* 1. tell the factory which chunk strategy it supports
* 2. produce the chunks for a given chunk command
*/
public interface Chunker {
/**
* Reports whether this implementation handles the given chunk strategy.
*
* @param strategy the strategy to test
* @return {@code true} if this chunker should be used for {@code strategy}
*/
boolean supports(RagChunkStrategyEnum strategy);
/**
* Performs the chunking and returns the resulting chunk objects as an in-memory list.
*
* @param command the chunking context (document, parsed text and chunk settings)
* @return the chunks produced for the command
*/
List<RagChunk> chunk(RagChunkCommand command);
}

View File

@@ -0,0 +1,32 @@
package com.bruce.rag.parse;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Factory that picks the {@link Chunker} matching a chunk strategy.
 * <p>
 * Spring injects every bean implementing {@link Chunker}; the factory then selects the
 * implementation for the requested strategy so that business code does not need its own
 * if-else or switch.
 */
@Component
public class ChunkerFactory {

    /** All chunker beans handed in by the container, in injection order. */
    private final List<Chunker> chunkers;

    public ChunkerFactory(List<Chunker> chunkers) {
        this.chunkers = chunkers;
    }

    /**
     * Resolves the concrete chunker for a strategy.
     *
     * @param strategy the requested chunk strategy
     * @return the first registered chunker whose {@link Chunker#supports} accepts {@code strategy}
     * @throws IllegalArgumentException if no registered chunker supports {@code strategy}
     */
    public Chunker resolve(RagChunkStrategyEnum strategy) {
        for (Chunker candidate : chunkers) {
            if (candidate.supports(strategy)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("不支持的切片方式: " + strategy);
    }
}

View File

@@ -0,0 +1,45 @@
package com.bruce.rag.parse;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagDocument;
import lombok.Data;
/**
* Context parameters for a single chunking request.
* <p>
* The document, its parse result and the chunk settings are gathered into one object
* so that the chunker interface does not become hard to maintain as parameters grow.
*/
@Data
public class RagChunkCommand {
/**
* The document entity that is being chunked.
*/
private RagDocument document;
/**
* The text result produced by parsing the document.
*/
private DocumentParseResult parseResult;
/**
* Chunk strategy enum value, usually taken from the front-end request.
*/
private Integer chunkStrategy;
/**
* Target length of a single chunk; mainly used by fixed-length chunking.
*/
private Integer chunkSize;
/**
* Overlap length between adjacent chunks; mainly used by fixed-length chunking.
*/
private Integer chunkOverlap;
/**
* Custom delimiter; mainly used by delimiter chunking.
*/
private String delimiter;
}

View File

@@ -0,0 +1,68 @@
package com.bruce.rag.parse.impl;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import com.bruce.rag.parse.Chunker;
import com.bruce.rag.parse.RagChunkCommand;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
/**
 * Delimiter-based {@link Chunker}.
 * <p>
 * Splits the parsed text on the caller-supplied delimiter (treated literally), drops
 * blank fragments and emits the remaining fragments as sequentially indexed chunks.
 */
@Component
public class DelimiterChunker implements Chunker {

    @Override
    public boolean supports(RagChunkStrategyEnum strategy) {
        return RagChunkStrategyEnum.DELIMITER == strategy;
    }

    /**
     * Splits the command's parsed text on the command's delimiter.
     *
     * @param command chunking context; may be {@code null}
     * @return trimmed, non-blank fragments as chunks in source order; an empty list when the
     *         command, the text or the delimiter is missing
     */
    @Override
    public List<RagChunk> chunk(RagChunkCommand command) {
        String text = extractText(command);
        String delimiter = command == null ? null : command.getDelimiter();
        // The delimiter is checked with hasLength, not hasText: whitespace-only delimiters
        // such as "\n" or "\n\n" are legitimate separators and must not be rejected.
        if (!StringUtils.hasText(text) || !StringUtils.hasLength(delimiter)) {
            return List.of();
        }
        // Pattern.quote escapes regex metacharacters so the delimiter is matched literally.
        String[] parts = text.split(Pattern.quote(delimiter));
        List<RagChunk> chunks = new ArrayList<>();
        for (String part : parts) {
            if (!StringUtils.hasText(part)) {
                continue;
            }
            // chunks.size() is the next index, so skipped blank fragments leave no gaps.
            chunks.add(buildChunk(command.getDocument(), chunks.size(), part.trim()));
        }
        return chunks;
    }

    /** Returns the parsed text carried by the command, or {@code null} if any link is missing. */
    private String extractText(RagChunkCommand command) {
        DocumentParseResult parseResult = command == null ? null : command.getParseResult();
        return parseResult == null ? null : parseResult.getText();
    }

    /**
     * Builds the basic chunk structure only; persistence and vectorization are not handled here.
     */
    private RagChunk buildChunk(RagDocument document, int index, String content) {
        RagChunk chunk = new RagChunk();
        if (document != null) {
            chunk.setStoreId(document.getStoreId());
            chunk.setDocumentId(document.getId());
        }
        chunk.setChunkIndex(index);
        chunk.setChunkContent(content);
        chunk.setEnabled(Boolean.TRUE);
        return chunk;
    }
}

View File

@@ -0,0 +1,91 @@
package com.bruce.rag.parse.impl;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import com.bruce.rag.parse.Chunker;
import com.bruce.rag.parse.RagChunkCommand;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
/**
 * Fixed-length {@link Chunker}.
 * <p>
 * Cuts the text into consecutive windows of {@code chunkSize} characters and uses
 * {@code chunkOverlap} to control how much adjacent windows share.
 */
@Component
public class FixedLengthChunker implements Chunker {

    @Override
    public boolean supports(RagChunkStrategyEnum strategy) {
        return strategy == RagChunkStrategyEnum.FIXED_LENGTH;
    }

    /**
     * Slices the command's parsed text into fixed-length windows.
     *
     * @param command chunking context; may be {@code null}
     * @return the windows as chunks in source order; an empty list when there is no text
     */
    @Override
    public List<RagChunk> chunk(RagChunkCommand command) {
        String text = extractText(command);
        if (!StringUtils.hasText(text)) {
            return List.of();
        }
        final int length = text.length();
        final int size = resolveChunkSize(command, length);
        // The stride is the window size minus the overlap, floored at 1 so the loop always advances.
        final int stride = Math.max(1, size - resolveChunkOverlap(command, size));

        List<RagChunk> chunks = new ArrayList<>();
        int start = 0;
        while (start < length) {
            int end = Math.min(length, start + size);
            chunks.add(buildChunk(command.getDocument(), chunks.size(), text.substring(start, end)));
            if (end >= length) {
                // The window reached the end of the text; no further (fully overlapping) tail chunk.
                break;
            }
            start += stride;
        }
        return chunks;
    }

    /** Returns the parsed text carried by the command, or {@code null} if any link is missing. */
    private String extractText(RagChunkCommand command) {
        if (command == null) {
            return null;
        }
        DocumentParseResult parsed = command.getParseResult();
        return parsed == null ? null : parsed.getText();
    }

    /**
     * Falls back to one chunk spanning the whole text when chunkSize is absent or not positive.
     */
    private int resolveChunkSize(RagChunkCommand command, int textLength) {
        if (command == null) {
            return textLength;
        }
        Integer requested = command.getChunkSize();
        if (requested != null && requested > 0) {
            return requested;
        }
        return textLength;
    }

    /**
     * The overlap is never negative and always smaller than chunkSize; otherwise the
     * stride would become zero or negative.
     */
    private int resolveChunkOverlap(RagChunkCommand command, int chunkSize) {
        if (command == null) {
            return 0;
        }
        Integer requested = command.getChunkOverlap();
        if (requested == null || requested < 0) {
            return 0;
        }
        int maxOverlap = Math.max(0, chunkSize - 1);
        return Math.min(requested, maxOverlap);
    }

    /**
     * Builds only the most basic chunk object; summary, vectors and other extension
     * fields are filled in later when the chunk is persisted.
     */
    private RagChunk buildChunk(RagDocument document, int index, String content) {
        RagChunk chunk = new RagChunk();
        chunk.setChunkIndex(index);
        chunk.setChunkContent(content);
        chunk.setEnabled(Boolean.TRUE);
        if (document != null) {
            chunk.setStoreId(document.getStoreId());
            chunk.setDocumentId(document.getId());
        }
        return chunk;
    }
}

View File

@@ -0,0 +1,7 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.rag.entity.RagChunkEmbedding;
/**
 * Service contract for {@link RagChunkEmbedding} persistence.
 * <p>
 * Adds no operations of its own; only the methods inherited from {@link IService} are exposed.
 */
public interface IRagChunkEmbeddingService extends IService<RagChunkEmbedding> {
}

View File

@@ -0,0 +1,7 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.rag.entity.RagChunk;
/**
 * Service contract for {@link RagChunk} persistence.
 * <p>
 * Adds no operations of its own; only the methods inherited from {@link IService} are exposed.
 */
public interface IRagChunkService extends IService<RagChunk> {
}

View File

@@ -0,0 +1,8 @@
package com.bruce.rag.service;
import java.util.List;
/**
 * Contract for automatically parsing documents after they have been uploaded.
 */
public interface IRagDocumentAutoParseService {
/**
 * Parses each of the given documents.
 *
 * @param documentIds ids of the documents to parse
 */
void parseUploadedDocuments(List<Long> documentIds);
}

View File

@@ -0,0 +1,8 @@
package com.bruce.rag.service;
import com.bruce.rag.dto.request.RagDocumentChunkRequest;
/**
 * Contract for submitting document chunking work.
 */
public interface IRagDocumentChunkService {
/**
 * Submits a chunking task described by the request.
 *
 * @param request the chunk request (document ids plus chunk settings)
 */
void submitChunkTask(RagDocumentChunkRequest request);
}

View File

@@ -0,0 +1,14 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagDocumentParseResult;
/**
 * Service contract for stored parse snapshots ({@link RagDocumentParseResult}) and their
 * conversion to and from {@link DocumentParseResult}.
 */
public interface IRagDocumentParseResultService extends IService<RagDocumentParseResult> {
/**
 * Looks up the parse snapshot stored for a document.
 *
 * @param documentId id of the document
 * @return the stored snapshot for that document
 */
RagDocumentParseResult getByDocumentId(Long documentId);
/**
 * Stores the parse result of a document as its snapshot.
 *
 * @param storeId     id of the store the document belongs to
 * @param documentId  id of the parsed document
 * @param parseResult the parse result to persist
 */
void saveSnapshot(Long storeId, Long documentId, DocumentParseResult parseResult);
/**
 * Converts a stored snapshot back into a {@link DocumentParseResult}.
 *
 * @param snapshot the stored snapshot
 * @return the equivalent parse result
 */
DocumentParseResult toParseResult(RagDocumentParseResult snapshot);
}

View File

@@ -0,0 +1,16 @@
package com.bruce.rag.service;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.dto.response.RagDocumentParseResponse;
import com.bruce.rag.dto.request.RagDocumentParseRequest;
import java.util.List;
/**
 * Contract for parsing RAG documents into text.
 */
public interface IRagDocumentParseService {
/**
 * Returns the parse result of a document as a {@link DocumentParseResult}.
 *
 * @param documentId id of the document
 * @return the parse result
 */
DocumentParseResult parseDocumentResult(Long documentId);
/**
 * Parses a single document and returns a response DTO describing the outcome.
 *
 * @param documentId id of the document to parse
 * @return the parse response for that document
 */
RagDocumentParseResponse parse(Long documentId);
/**
 * Parses the documents named in the request.
 *
 * @param request the batch parse request
 * @return one parse response per requested document
 */
List<RagDocumentParseResponse> parse(RagDocumentParseRequest request);
}

View File

@@ -0,0 +1,25 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.entity.RagDocument;
import java.util.List;
/**
 * Service contract for {@link RagDocument} management, exposing response-DTO based
 * operations on top of the inherited {@link IService} methods.
 */
public interface IRagDocumentService extends IService<RagDocument> {
/**
 * Lists documents as response DTOs.
 *
 * @return the documents as responses
 */
List<RagDocumentResponse> listResponses();
/**
 * Queries documents using the criteria carried by the request.
 *
 * @param request the query criteria
 * @return the matching documents as responses
 */
List<RagDocumentResponse> query(RagDocumentQueryRequest request);
/**
 * Loads one document as a response DTO.
 *
 * @param id id of the document
 * @return the response for that document
 */
RagDocumentResponse getResponseById(Long id);
/**
 * Creates or updates a document from a save request.
 *
 * @param request the values to save
 * @return {@code true} if the underlying save or update reported success
 */
boolean saveOrUpdate(RagDocumentSaveRequest request);
/**
 * Removes a document by id.
 *
 * @param id id of the document to remove
 * @return whether the removal succeeded
 */
boolean removeById(Long id);
/**
 * Uploads several documents in one call.
 * NOTE(review): exact semantics are defined by the implementation, which is not visible here.
 *
 * @param request the batch upload request
 * @return responses for the uploaded documents
 */
List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request);
}

View File

@@ -0,0 +1,26 @@
package com.bruce.rag.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.bruce.rag.dto.request.RagStoreQueryRequest;
import com.bruce.rag.dto.request.RagStoreSaveRequest;
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
import com.bruce.rag.dto.response.RagStoreResponse;
import com.bruce.rag.entity.RagStore;
import java.util.List;
/**
 * Service contract for {@link RagStore} management, exposing response-DTO based
 * operations on top of the inherited {@link IService} methods.
 * NOTE(review): the implementation is not visible here; per-method semantics below are
 * limited to what the signatures state.
 */
public interface IRagStoreService extends IService<RagStore> {
/**
 * Lists stores as response DTOs.
 *
 * @return the stores as responses
 */
List<RagStoreResponse> listResponses();
/**
 * Queries stores using the criteria carried by the request.
 *
 * @param request the query criteria
 * @return the matching stores as responses
 */
List<RagStoreResponse> query(RagStoreQueryRequest request);
/**
 * Loads one store as a response DTO.
 *
 * @param id id of the store
 * @return the response for that store
 */
RagStoreResponse getResponseById(Long id);
/**
 * Returns the store overview.
 *
 * @return the overview response
 */
RagStoreOverviewResponse getOverview();
/**
 * Returns the document overview of one store.
 *
 * @param storeId id of the store
 * @return the document overview response for that store
 */
RagStoreDocumentOverviewResponse getDocumentOverview(Long storeId);
/**
 * Creates or updates a store from a save request.
 *
 * @param request the values to save
 * @return whether the save or update succeeded
 */
boolean saveOrUpdate(RagStoreSaveRequest request);
}

View File

@@ -0,0 +1,12 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.rag.entity.RagChunkEmbedding;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.service.IRagChunkEmbeddingService;
import org.springframework.stereotype.Service;
/**
 * Default {@link IRagChunkEmbeddingService} backed by {@link RagChunkEmbeddingMapper}.
 * <p>
 * Adds no behavior of its own; everything comes from {@link ServiceImpl}.
 */
@Service
public class RagChunkEmbeddingServiceImpl extends ServiceImpl<RagChunkEmbeddingMapper, RagChunkEmbedding>
implements IRagChunkEmbeddingService {
}

View File

@@ -0,0 +1,11 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.mapper.RagChunkMapper;
import com.bruce.rag.service.IRagChunkService;
import org.springframework.stereotype.Service;
/**
 * Default {@link IRagChunkService} backed by {@link RagChunkMapper}.
 * <p>
 * Adds no behavior of its own; everything comes from {@link ServiceImpl}.
 */
@Service
public class RagChunkServiceImpl extends ServiceImpl<RagChunkMapper, RagChunk> implements IRagChunkService {
}

View File

@@ -0,0 +1,37 @@
package com.bruce.rag.service.impl;
import com.bruce.rag.service.IRagDocumentAutoParseService;
import com.bruce.rag.service.IRagDocumentParseService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * Parses freshly uploaded documents in the background.
 * <p>
 * Parsing is best effort: a failure on one document is logged and the remaining
 * documents are still processed.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RagDocumentAutoParseServiceImpl implements IRagDocumentAutoParseService {

    private final IRagDocumentParseService ragDocumentParseService;

    /**
     * Parses every non-null id in {@code documentIds}, one after another.
     * <p>
     * The method is annotated {@code @Async}; a {@code null} or empty list is a no-op.
     *
     * @param documentIds ids of the documents to parse; {@code null} entries are skipped
     */
    @Override
    @Async
    public void parseUploadedDocuments(List<Long> documentIds) {
        if (documentIds == null || documentIds.isEmpty()) {
            return;
        }
        for (Long documentId : documentIds) {
            if (documentId == null) {
                continue;
            }
            try {
                ragDocumentParseService.parse(documentId);
            } catch (RuntimeException e) {
                // Nobody awaits this async call, so this log is the only trace of the failure:
                // pass the throwable as the trailing argument so the stack trace is kept.
                log.warn("RagDocumentAutoParseServiceImpl.parseUploadedDocuments failed, documentId={}, message={}",
                        documentId, e.getMessage(), e);
            }
        }
    }
}

View File

@@ -0,0 +1,176 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.modelprovider.entity.RagStoreModelConfig;
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
import com.bruce.modelprovider.gateway.EmbeddingRequest;
import com.bruce.modelprovider.gateway.EmbeddingResult;
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
import com.bruce.rag.dto.request.RagDocumentChunkRequest;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagChunkEmbedding;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.entity.RagDocumentParseResult;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.parse.Chunker;
import com.bruce.rag.parse.ChunkerFactory;
import com.bruce.rag.parse.RagChunkCommand;
import com.bruce.rag.service.IRagChunkEmbeddingService;
import com.bruce.rag.service.IRagChunkService;
import com.bruce.rag.service.IRagDocumentChunkService;
import com.bruce.rag.service.IRagDocumentParseResultService;
import com.bruce.rag.service.IRagDocumentService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.DigestUtils;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
/**
 * Asynchronous chunk-and-embed pipeline for RAG documents.
 * <p>
 * For every requested document the service loads the stored parse snapshot, splits it with
 * the {@link Chunker} matching the requested strategy, replaces the document's existing
 * chunks and embeddings, requests embeddings for the new chunks and records the outcome in
 * the document's index status.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RagDocumentChunkServiceImpl implements IRagDocumentChunkService {

    /** Upper bound applied to error messages before they are written to the document row. */
    private static final int MAX_ERROR_MESSAGE_LENGTH = 1000;

    private final IRagDocumentService ragDocumentService;
    private final IRagDocumentParseResultService ragDocumentParseResultService;
    private final ChunkerFactory chunkerFactory;
    private final IRagChunkService ragChunkService;
    private final IRagChunkEmbeddingService ragChunkEmbeddingService;
    private final IRagStoreModelConfigService ragStoreModelConfigService;
    private final EmbeddingModelGateway embeddingModelGateway;

    /**
     * Chunks and indexes every document named in the request.
     * <p>
     * Documents are processed independently: a failure on one document marks that document
     * as FAILED, is logged, and does not stop the remaining documents.
     *
     * @param request document ids plus chunk settings
     * @throws IllegalArgumentException if the request is null, has no document ids, or names
     *                                  a strategy no chunker supports
     */
    @Override
    @Async
    public void submitChunkTask(RagDocumentChunkRequest request) {
        validateRequest(request);
        RagChunkStrategyEnum strategy = RagChunkStrategyEnum.fromValue(request.getChunkStrategy());
        Chunker chunker = chunkerFactory.resolve(strategy);
        for (Long documentId : request.getDocumentIds()) {
            try {
                indexDocument(documentId, request, chunker);
            } catch (RuntimeException e) {
                markFailed(documentId, e);
                // Pass the throwable as the trailing argument so the stack trace is kept.
                log.warn("RagDocumentChunkServiceImpl.chunkAsync failed, documentId={}, message={}",
                        documentId, e.getMessage(), e);
            }
        }
    }

    /** Runs the chunk, replace and embed sequence for one document; a missing document is skipped. */
    private void indexDocument(Long documentId, RagDocumentChunkRequest request, Chunker chunker) {
        RagDocument document = ragDocumentService.getById(documentId);
        if (document == null) {
            log.warn("RagDocumentChunkServiceImpl.chunkAsync document not found, documentId={}", documentId);
            return;
        }
        updateIndexStatus(document, RagIndexStatusEnum.INDEXING.name(), null);
        RagDocumentParseResult snapshot = ragDocumentParseResultService.getByDocumentId(documentId);
        if (snapshot == null) {
            throw new IllegalStateException("文档尚未生成解析快照documentId=" + documentId);
        }
        RagStoreModelConfig storeModelConfig = ragStoreModelConfigService.getActiveEntity(document.getStoreId());
        if (storeModelConfig == null || storeModelConfig.getEmbeddingModelId() == null) {
            throw new IllegalStateException("请先配置知识库 Embedding 模型");
        }
        DocumentParseResult parseResult = ragDocumentParseResultService.toParseResult(snapshot);
        RagChunkCommand command = new RagChunkCommand();
        command.setDocument(document);
        command.setParseResult(parseResult);
        command.setChunkStrategy(request.getChunkStrategy());
        command.setChunkSize(request.getChunkSize());
        command.setChunkOverlap(request.getChunkOverlap());
        command.setDelimiter(request.getDelimiter());
        List<RagChunk> chunks = chunker.chunk(command);
        // Replace the previous index data: embeddings first, then the chunks they point to.
        ragChunkEmbeddingService.remove(Wrappers.<RagChunkEmbedding>lambdaQuery()
                .eq(RagChunkEmbedding::getDocumentId, documentId));
        ragChunkService.remove(Wrappers.<RagChunk>lambdaQuery()
                .eq(RagChunk::getDocumentId, documentId));
        if (!chunks.isEmpty()) {
            ragChunkService.saveBatch(chunks);
        }
        writeEmbeddings(document, chunks, storeModelConfig);
        updateIndexStatus(document, RagIndexStatusEnum.INDEXED.name(), null);
        log.info("RagDocumentChunkServiceImpl.chunkAsync success, documentId={}, chunkCount={}",
                documentId, chunks.size());
    }

    /**
     * Records the FAILED status for a document without letting a secondary failure
     * (for example a database error while saving the status) abort the rest of the batch.
     */
    private void markFailed(Long documentId, RuntimeException cause) {
        try {
            RagDocument failedDoc = ragDocumentService.getById(documentId);
            if (failedDoc != null) {
                updateIndexStatus(failedDoc, RagIndexStatusEnum.FAILED.name(), cause.getMessage());
            }
        } catch (RuntimeException statusError) {
            log.warn("RagDocumentChunkServiceImpl.chunkAsync could not record FAILED status, documentId={}, message={}",
                    documentId, statusError.getMessage(), statusError);
        }
    }

    /**
     * Requests embeddings for the given chunks and stores one embedding row per chunk.
     * <p>
     * NOTE(review): this method is invoked from within this class via {@code this}; with
     * Spring's default proxy-based transactions such a self-invocation does not pass through
     * the proxy, so {@code @Transactional} presumably has no effect on that call path - confirm
     * and, if atomicity is required, move the call behind a separate bean or use a
     * programmatic transaction.
     *
     * @param document         the document the chunks belong to
     * @param chunks           the persisted chunks to embed; {@code null} or empty is a no-op
     * @param storeModelConfig store configuration supplying the expected embedding dimension
     * @throws IllegalStateException if the gateway does not return exactly one vector per chunk
     */
    @Transactional
    public void writeEmbeddings(RagDocument document, List<RagChunk> chunks, RagStoreModelConfig storeModelConfig) {
        if (chunks == null || chunks.isEmpty()) {
            return;
        }
        EmbeddingRequest embeddingRequest = new EmbeddingRequest();
        embeddingRequest.setTexts(chunks.stream().map(RagChunk::getChunkContent).toList());
        embeddingRequest.setTaskType("RAG_EMBEDDING");
        embeddingRequest.setMatchScope("RAG_STORE");
        embeddingRequest.setScopeId(document.getStoreId());
        embeddingRequest.setBizType("RAG_DOCUMENT_INDEX");
        embeddingRequest.setBizId(String.valueOf(document.getId()));
        embeddingRequest.setExpectedDimension(storeModelConfig.getEmbeddingDimension());
        EmbeddingResult result = embeddingModelGateway.embed(embeddingRequest);
        // A missing result or vector list is reported as a count mismatch instead of
        // surfacing as a NullPointerException with no message.
        if (result == null || result.getVectors() == null || result.getVectors().size() != chunks.size()) {
            throw new IllegalStateException("向量数量与切片数量不一致");
        }
        List<RagChunkEmbedding> embeddingRows = new ArrayList<>(chunks.size());
        for (int i = 0; i < chunks.size(); i++) {
            RagChunk chunk = chunks.get(i);
            List<Double> vector = result.getVectors().get(i);
            RagChunkEmbedding row = new RagChunkEmbedding();
            row.setStoreId(document.getStoreId());
            row.setDocumentId(document.getId());
            row.setChunkId(chunk.getId());
            row.setEmbeddingModel(result.getModelName());
            row.setEmbeddingDimension(result.getDimension());
            // List#toString renders "[v1, v2, ...]", which is stored as the embedding text.
            row.setEmbedding(vector.toString());
            row.setContentHash(DigestUtils.md5DigestAsHex(chunk.getChunkContent().getBytes(StandardCharsets.UTF_8)));
            row.setEnabled(true);
            embeddingRows.add(row);
        }
        ragChunkEmbeddingService.saveBatch(embeddingRows);
    }

    /**
     * Writes the index status and (truncated) error message onto the document and persists it.
     */
    private void updateIndexStatus(RagDocument document, String status, String errorMessage) {
        document.setIndexStatus(status);
        document.setErrorMessage(errorMessage == null
                ? null
                : errorMessage.substring(0, Math.min(errorMessage.length(), MAX_ERROR_MESSAGE_LENGTH)));
        ragDocumentService.updateById(document);
    }

    /**
     * Rejects requests that are null, carry no document ids, or name an unknown strategy.
     */
    private void validateRequest(RagDocumentChunkRequest request) {
        if (request == null) {
            throw new IllegalArgumentException("切片请求不能为空");
        }
        if (request.getDocumentIds() == null || request.getDocumentIds().isEmpty()) {
            throw new IllegalArgumentException("文档ID列表不能为空");
        }
        RagChunkStrategyEnum.fromValue(request.getChunkStrategy());
    }
}

View File

@@ -0,0 +1,109 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagDocumentParseResult;
import com.bruce.rag.mapper.RagDocumentParseResultMapper;
import com.bruce.rag.service.IRagDocumentParseResultService;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.util.DigestUtils;
import org.springframework.util.StringUtils;
import tools.jackson.core.type.TypeReference;
import tools.jackson.databind.ObjectMapper;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
/**
 * Stores and restores parse snapshots of RAG documents.
 * <p>
 * A snapshot keeps the parsed text, its size figures, the metadata serialized as JSON,
 * an MD5 hash of the text and a version counter that grows with every re-save.
 */
@Service
@RequiredArgsConstructor
public class RagDocumentParseResultServiceImpl extends ServiceImpl<RagDocumentParseResultMapper, RagDocumentParseResult>
        implements IRagDocumentParseResultService {

    private final ObjectMapper objectMapper;

    /**
     * Fetches a single snapshot row for the document.
     *
     * @param documentId id of the document; {@code null} yields {@code null}
     * @return the snapshot returned by the lookup, or {@code null} if none is found
     */
    @Override
    public RagDocumentParseResult getByDocumentId(Long documentId) {
        return documentId == null
                ? null
                : getOne(Wrappers.<RagDocumentParseResult>lambdaQuery()
                        .eq(RagDocumentParseResult::getDocumentId, documentId)
                        .last("limit 1"));
    }

    /**
     * Inserts the document's snapshot, or overwrites the existing one and bumps its version.
     *
     * @throws IllegalArgumentException if any argument is {@code null}
     * @throws IllegalStateException    if the metadata cannot be serialized
     */
    @Override
    public void saveSnapshot(Long storeId, Long documentId, DocumentParseResult parseResult) {
        boolean incomplete = storeId == null || documentId == null || parseResult == null;
        if (incomplete) {
            throw new IllegalArgumentException("保存解析快照参数不完整");
        }
        RagDocumentParseResult existing = getByDocumentId(documentId);
        RagDocumentParseResult snapshot = existing != null ? existing : new RagDocumentParseResult();
        snapshot.setStoreId(storeId);
        snapshot.setDocumentId(documentId);
        snapshot.setParsedText(parseResult.getText());
        snapshot.setTextLength(parseResult.getTextLength());
        snapshot.setPageCount(parseResult.getPageCount());
        snapshot.setSheetCount(parseResult.getSheetCount());
        snapshot.setMetadataJson(toJson(parseResult.getMetadata()));
        snapshot.setContentHash(buildHash(parseResult.getText()));
        snapshot.setParseVersion(resolveNextVersion(existing));
        snapshot.setEnabled(Boolean.TRUE);
        if (snapshot.getId() != null) {
            updateById(snapshot);
            return;
        }
        save(snapshot);
    }

    /**
     * Rebuilds a {@link DocumentParseResult} from a stored snapshot.
     *
     * @param snapshot the stored snapshot; {@code null} yields {@code null}
     * @throws IllegalStateException if the stored metadata JSON cannot be read
     */
    @Override
    public DocumentParseResult toParseResult(RagDocumentParseResult snapshot) {
        if (snapshot == null) {
            return null;
        }
        DocumentParseResult restored = new DocumentParseResult();
        restored.setText(snapshot.getParsedText());
        restored.setTextLength(snapshot.getTextLength());
        restored.setPageCount(snapshot.getPageCount());
        restored.setSheetCount(snapshot.getSheetCount());
        restored.setMetadata(fromJson(snapshot.getMetadataJson()));
        return restored;
    }

    /** Returns 1 for a first snapshot or an invalid stored version, otherwise stored version + 1. */
    private Integer resolveNextVersion(RagDocumentParseResult existing) {
        Integer current = existing == null ? null : existing.getParseVersion();
        if (current == null || current < 1) {
            return 1;
        }
        return current + 1;
    }

    /** Returns the hex MD5 of the UTF-8 bytes of the text, or {@code null} for blank text. */
    private String buildHash(String text) {
        return StringUtils.hasText(text)
                ? DigestUtils.md5DigestAsHex(text.getBytes(StandardCharsets.UTF_8))
                : null;
    }

    /** Serializes the metadata; a {@code null} map is written as an empty JSON object. */
    private String toJson(Map<String, Object> metadata) {
        Map<String, Object> payload = metadata != null ? metadata : new LinkedHashMap<>();
        try {
            return objectMapper.writeValueAsString(payload);
        } catch (Exception e) {
            throw new IllegalStateException("解析元数据序列化失败", e);
        }
    }

    /** Deserializes the metadata JSON; blank input yields an empty mutable map. */
    private Map<String, Object> fromJson(String metadataJson) {
        if (!StringUtils.hasText(metadataJson)) {
            return new LinkedHashMap<>();
        }
        try {
            Map<String, Object> parsed = objectMapper.readValue(metadataJson,
                    new TypeReference<Map<String, Object>>() {
                    });
            return parsed;
        } catch (Exception e) {
            throw new IllegalStateException("解析元数据反序列化失败", e);
        }
    }
}

View File

@@ -0,0 +1,168 @@
package com.bruce.rag.service.impl;
import com.bruce.common.config.AttachmentProperties;
import com.bruce.common.document.parse.DocumentParseContext;
import com.bruce.common.document.parse.DocumentParseException;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.common.document.parse.DocumentParser;
import com.bruce.common.document.parse.DocumentParserFactory;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.dto.request.RagDocumentParseRequest;
import com.bruce.rag.dto.response.RagDocumentParseResponse;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.entity.RagDocumentParseResult;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.mapper.RagDocumentMapper;
import com.bruce.rag.service.IRagDocumentParseService;
import com.bruce.rag.service.IRagDocumentParseResultService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
/**
 * Parses the attachment behind a RAG document into text.
 * <p>
 * A parse resolves the attachment file, runs the {@link DocumentParser} selected by
 * {@link DocumentParserFactory}, stores the result as a snapshot and tracks progress in the
 * document's parse status (PARSING, then PARSED or FAILED).
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RagDocumentParseServiceImpl implements IRagDocumentParseService {

    /**
     * Upper bound for stored error messages; matches the 1000-character cap that
     * {@code RagDocumentChunkServiceImpl} applies when writing the same document field.
     */
    private static final int MAX_ERROR_MESSAGE_LENGTH = 1000;

    private final RagDocumentMapper ragDocumentMapper;
    private final ISysAttachmentService sysAttachmentService;
    private final AttachmentProperties attachmentProperties;
    private final DocumentParserFactory documentParserFactory;
    private final IRagDocumentParseResultService ragDocumentParseResultService;

    /**
     * Parses every document named in the request, in order.
     *
     * @param request the batch parse request
     * @return one response per requested document
     * @throws IllegalArgumentException if the request is null or has no document ids
     */
    @Override
    public List<RagDocumentParseResponse> parse(RagDocumentParseRequest request) {
        log.info("RagDocumentParseServiceImpl.parse batch start, request={}", request);
        validateParseRequest(request);
        List<RagDocumentParseResponse> responses = request.getDocumentIds().stream()
                .map(this::parse)
                .toList();
        log.info("RagDocumentParseServiceImpl.parse batch success, count={}", responses.size());
        return responses;
    }

    /**
     * Returns the stored snapshot of the document if one exists, otherwise parses it now.
     *
     * @param documentId id of the document
     * @return the parse result
     */
    @Override
    public DocumentParseResult parseDocumentResult(Long documentId) {
        RagDocumentParseResult snapshot = ragDocumentParseResultService.getByDocumentId(documentId);
        if (snapshot != null) {
            return ragDocumentParseResultService.toParseResult(snapshot);
        }
        return doParse(documentId);
    }

    /**
     * Parses the document (regardless of any stored snapshot) and returns a response DTO.
     *
     * @param documentId id of the document
     * @return the parse response
     */
    @Override
    public RagDocumentParseResponse parse(Long documentId) {
        DocumentParseResult result = doParse(documentId);
        RagDocumentParseResponse response = toResponse(documentId, result);
        log.info("RagDocumentParseServiceImpl.parse success, documentId={}, textLength={}",
                documentId, response.getTextLength());
        return response;
    }

    /**
     * Validates the document and attachment, runs the parser and persists the snapshot.
     *
     * @throws IllegalArgumentException if the id is null or the document/attachment is missing
     * @throws RuntimeException         any parser or persistence failure, rethrown after the
     *                                  document has been marked FAILED
     */
    private DocumentParseResult doParse(Long documentId) {
        log.info("RagDocumentParseServiceImpl.parse start, documentId={}", documentId);
        if (documentId == null) {
            throw new IllegalArgumentException("文档ID不能为空");
        }
        RagDocument document = ragDocumentMapper.selectById(documentId);
        if (document == null) {
            throw new IllegalArgumentException("文档不存在ID: " + documentId);
        }
        if (document.getAttachmentId() == null) {
            throw new IllegalArgumentException("文档附件ID不能为空");
        }
        SysAttachment attachment = sysAttachmentService.getById(document.getAttachmentId());
        if (attachment == null) {
            throw new IllegalArgumentException("附件不存在ID: " + document.getAttachmentId());
        }
        updateParseStatus(documentId, RagParseStatusEnum.PARSING, null);
        try {
            DocumentParseContext context = buildParseContext(document, attachment);
            DocumentParser parser = documentParserFactory.resolve(context);
            DocumentParseResult result = parser.parse(context);
            ragDocumentParseResultService.saveSnapshot(document.getStoreId(), documentId, result);
            updateParseStatus(documentId, RagParseStatusEnum.PARSED, null);
            return result;
        } catch (RuntimeException e) {
            try {
                updateParseStatus(documentId, RagParseStatusEnum.FAILED, e.getMessage());
            } catch (RuntimeException statusError) {
                // Recording the failure must not replace the real cause seen by the caller;
                // keep the status error attached to the original exception instead.
                e.addSuppressed(statusError);
            }
            log.warn("RagDocumentParseServiceImpl.parse failed, documentId={}, message={}", documentId, e.getMessage());
            throw e;
        }
    }

    /** Rejects a null request or one without document ids. */
    private void validateParseRequest(RagDocumentParseRequest request) {
        if (request == null) {
            throw new IllegalArgumentException("解析请求不能为空");
        }
        if (request.getDocumentIds() == null || request.getDocumentIds().isEmpty()) {
            throw new IllegalArgumentException("文档ID列表不能为空");
        }
    }

    /**
     * Builds the parser input from the document and its attachment.
     *
     * @throws DocumentParseException if the resolved path is not a regular file
     */
    private DocumentParseContext buildParseContext(RagDocument document, SysAttachment attachment) {
        Path filePath = resolveFilePath(attachment);
        if (!Files.isRegularFile(filePath)) {
            throw new DocumentParseException("解析文件不存在: " + filePath);
        }
        DocumentParseContext context = new DocumentParseContext();
        context.setDocumentId(document.getId());
        context.setAttachmentId(attachment.getId());
        context.setOriginalName(attachment.getOriginalName());
        context.setSuffix(attachment.getFileSuffix());
        context.setContentType(attachment.getContentType());
        context.setFilePath(filePath);
        return context;
    }

    /**
     * Resolves the attachment's stored path: absolute paths are used as they are, relative
     * paths are resolved against the configured attachment base path.
     * <p>
     * NOTE(review): the resolved path is not checked to stay inside the base path; this
     * assumes stored attachment paths are trusted - confirm where they originate.
     *
     * @throws DocumentParseException if the attachment has no file path
     */
    private Path resolveFilePath(SysAttachment attachment) {
        if (!StringUtils.hasText(attachment.getFilePath())) {
            throw new DocumentParseException("附件文件路径不能为空");
        }
        Path filePath = Path.of(attachment.getFilePath());
        if (filePath.isAbsolute()) {
            return filePath.normalize();
        }
        return Path.of(attachmentProperties.getBasePath()).resolve(filePath).normalize();
    }

    /**
     * Writes the parse status and error message of a document, carrying over the current
     * {@code version} value of the row into the update.
     *
     * @throws IllegalArgumentException if the document no longer exists
     * @throws IllegalStateException    if the update affects no row
     */
    private void updateParseStatus(Long documentId, RagParseStatusEnum status, String errorMessage) {
        RagDocument current = ragDocumentMapper.selectById(documentId);
        if (current == null) {
            throw new IllegalArgumentException("文档不存在ID: " + documentId);
        }
        RagDocument update = new RagDocument();
        update.setId(documentId);
        update.setVersion(current.getVersion());
        update.setParseStatus(status.name());
        update.setErrorMessage(StringUtils.hasText(errorMessage) ? truncateErrorMessage(errorMessage) : null);
        boolean updated = ragDocumentMapper.updateById(update) > 0;
        if (!updated) {
            throw new IllegalStateException("更新解析状态失败文档ID: " + documentId + ", 状态: " + status.name());
        }
    }

    /** Cuts an error message down to {@link #MAX_ERROR_MESSAGE_LENGTH} characters. */
    private static String truncateErrorMessage(String errorMessage) {
        return errorMessage.length() <= MAX_ERROR_MESSAGE_LENGTH
                ? errorMessage
                : errorMessage.substring(0, MAX_ERROR_MESSAGE_LENGTH);
    }

    /** Maps a successful parse result to its response DTO (status is always PARSED). */
    private RagDocumentParseResponse toResponse(Long documentId, DocumentParseResult result) {
        RagDocumentParseResponse response = new RagDocumentParseResponse();
        response.setDocumentId(documentId);
        response.setParseStatus(RagParseStatusEnum.PARSED.name());
        response.setTextLength(result.getTextLength());
        response.setPageCount(result.getPageCount());
        response.setSheetCount(result.getSheetCount());
        response.setMetadata(result.getMetadata());
        return response;
    }
}

View File

@@ -0,0 +1,233 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.mapper.RagDocumentMapper;
import com.bruce.rag.service.IRagDocumentAutoParseService;
import com.bruce.rag.service.IRagDocumentService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@Service
public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDocument> implements IRagDocumentService {
@Autowired
private ISysAttachmentService sysAttachmentService;
@Autowired
private IRagDocumentAutoParseService ragDocumentAutoParseService;
/**
 * Loads every document and converts the rows to response DTOs.
 *
 * @return all documents as responses
 */
@Override
public List<RagDocumentResponse> listResponses() {
    log.info("RagDocumentServiceImpl.listResponses start");
    List<RagDocument> documents = list();
    List<RagDocumentResponse> responses = toResponses(documents);
    log.info("RagDocumentServiceImpl.listResponses success, count={}", responses.size());
    return responses;
}
/**
 * Queries documents by the criteria that are present on the request, newest id first.
 * <p>
 * Each criterion is applied only when its value is non-null (status values are trimmed and
 * blank ones ignored via {@code trimToNull}); a {@code null} request means "no filters".
 *
 * @param request the optional query criteria
 * @return the matching documents as responses
 */
@Override
public List<RagDocumentResponse> query(RagDocumentQueryRequest request) {
    log.info("RagDocumentServiceImpl.query start, request={}", request);
    RagDocumentQueryRequest queryRequest = request != null ? request : new RagDocumentQueryRequest();
    String parseStatus = trimToNull(queryRequest.getParseStatus());
    String indexStatus = trimToNull(queryRequest.getIndexStatus());
    List<RagDocument> documents = lambdaQuery()
            .eq(queryRequest.getStoreId() != null, RagDocument::getStoreId, queryRequest.getStoreId())
            .eq(queryRequest.getAttachmentId() != null, RagDocument::getAttachmentId, queryRequest.getAttachmentId())
            .eq(parseStatus != null, RagDocument::getParseStatus, parseStatus)
            .eq(indexStatus != null, RagDocument::getIndexStatus, indexStatus)
            .eq(queryRequest.getEnabled() != null, RagDocument::getEnabled, queryRequest.getEnabled())
            .orderByDesc(RagDocument::getId)
            .list();
    List<RagDocumentResponse> responses = toResponses(documents);
    log.info("RagDocumentServiceImpl.query success, count={}", responses.size());
    return responses;
}
/**
 * Loads one document and converts it with {@code RagDocumentResponse.fromEntity}.
 *
 * @param id id of the document
 * @return the converted response (the log line reports whether it is non-null)
 */
@Override
public RagDocumentResponse getResponseById(Long id) {
    log.info("RagDocumentServiceImpl.getResponseById start, id={}", id);
    RagDocument entity = getById(id);
    RagDocumentResponse response = RagDocumentResponse.fromEntity(entity);
    boolean found = response != null;
    log.info("RagDocumentServiceImpl.getResponseById success, id={}, found={}", id, found);
    return response;
}
/**
 * Creates a document (request without id) or patches an existing one (request with id).
 * <p>
 * A new document starts enabled with parse status UPLOADED and index status PENDING.
 * After that, only the request fields that are present overwrite the entity: ids and the
 * enabled flag when non-null, title and statuses when they contain text, and summary,
 * error message and remark when non-null (blank values are stored via {@code trimToNull}).
 *
 * @param request the values to save
 * @return the result of the underlying {@code save} or {@code updateById}
 * @throws IllegalArgumentException if the request names an id that does not exist
 */
@Override
public boolean saveOrUpdate(RagDocumentSaveRequest request) {
    log.info("RagDocumentServiceImpl.saveOrUpdate start, request={}", request);
    validateSaveRequest(request);

    final boolean creating = request.getId() == null;
    RagDocument document;
    if (creating) {
        document = new RagDocument();
        document.setEnabled(true);
        document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
        document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
    } else {
        document = getById(request.getId());
        if (document == null) {
            log.warn("RagDocumentServiceImpl.saveOrUpdate document not found, id={}", request.getId());
            throw new IllegalArgumentException("文档不存在ID: " + request.getId());
        }
    }

    // Copy over only the fields the caller actually supplied.
    if (request.getStoreId() != null) {
        document.setStoreId(request.getStoreId());
    }
    if (request.getAttachmentId() != null) {
        document.setAttachmentId(request.getAttachmentId());
    }
    if (StringUtils.hasText(request.getDocumentTitle())) {
        document.setDocumentTitle(request.getDocumentTitle().trim());
    }
    if (request.getDocumentSummary() != null) {
        document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
    }
    if (StringUtils.hasText(request.getParseStatus())) {
        document.setParseStatus(request.getParseStatus().trim());
    }
    if (StringUtils.hasText(request.getIndexStatus())) {
        document.setIndexStatus(request.getIndexStatus().trim());
    }
    if (request.getEnabled() != null) {
        document.setEnabled(request.getEnabled());
    }
    if (request.getErrorMessage() != null) {
        document.setErrorMessage(trimToNull(request.getErrorMessage()));
    }
    if (request.getRemark() != null) {
        document.setRemark(trimToNull(request.getRemark()));
    }

    boolean result;
    if (creating) {
        result = save(document);
    } else {
        result = updateById(document);
    }
    log.info("RagDocumentServiceImpl.saveOrUpdate success, requestId={}, savedId={}, result={}",
            request.getId(), document.getId(), result);
    return result;
}
/**
 * Deletes a document after checking that it exists.
 *
 * @param id the document ID, must not be {@code null}
 * @return the result of the underlying {@code removeById}
 * @throws IllegalArgumentException if the ID is {@code null} or no such document exists
 */
@Override
public boolean removeById(Long id) {
    log.info("RagDocumentServiceImpl.removeById start, id={}", id);
    if (id == null) {
        throw new IllegalArgumentException("文档ID不能为空");
    }
    final RagDocument existing = getById(id);
    if (existing == null) {
        log.warn("RagDocumentServiceImpl.removeById document not found, id={}", id);
        throw new IllegalArgumentException("文档不存在ID: " + id);
    }
    final boolean removed = super.removeById(id);
    log.info("RagDocumentServiceImpl.removeById success, id={}, result={}", id, removed);
    return removed;
}
/**
 * Uploads a batch of files into a knowledge store and registers one
 * document per file.
 *
 * <p>For every non-null, non-empty file the attachment is uploaded through
 * {@code sysAttachmentService}, then a document row is saved with parse
 * status {@code UPLOADED} and index status {@code PENDING}. When at least
 * one document was created, their IDs are passed to
 * {@code ragDocumentAutoParseService.parseUploadedDocuments}.
 *
 * @param request the batch upload request; validated by {@code validateBatchUploadRequest}
 * @return one response per stored document, in upload order (possibly empty)
 * @throws IllegalArgumentException if the request fails validation
 */
@Override
public List<RagDocumentResponse> batchUpload(RagDocumentBatchUploadRequest request) {
    log.info("RagDocumentServiceImpl.batchUpload start, request={}", request);
    validateBatchUploadRequest(request);

    final List<RagDocumentResponse> uploaded = new ArrayList<>();
    for (var file : request.getFiles()) {
        final boolean nothingToUpload = file == null || file.isEmpty();
        if (nothingToUpload) {
            continue;
        }

        // Store the raw file as an attachment owned by the knowledge store.
        final SysAttachmentUploadRequest attachmentRequest = new SysAttachmentUploadRequest();
        attachmentRequest.setFile(file);
        attachmentRequest.setSourceType(resolveSourceType(request.getSourceType()));
        attachmentRequest.setSourceId(request.getStoreId());
        final SysAttachment attachment = sysAttachmentService.upload(attachmentRequest);

        // Register the document that points at the stored attachment.
        final String originalName = file.getOriginalFilename();
        final RagDocument document = new RagDocument();
        document.setStoreId(request.getStoreId());
        document.setAttachmentId(attachment.getId());
        document.setDocumentTitle(StringUtils.hasText(originalName) ? originalName.trim() : null);
        document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
        document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
        document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
        document.setEnabled(true);
        document.setErrorMessage(null);
        document.setRemark(trimToNull(request.getRemark()));
        save(document);

        uploaded.add(RagDocumentResponse.fromEntity(document));
    }

    if (!uploaded.isEmpty()) {
        final List<Long> documentIds = uploaded.stream()
                .map(RagDocumentResponse::getId)
                .filter(documentId -> documentId != null)
                .toList();
        ragDocumentAutoParseService.parseUploadedDocuments(documentIds);
    }

    log.info("RagDocumentServiceImpl.batchUpload success, storeId={}, uploaded={}",
            request.getStoreId(), uploaded.size());
    return uploaded;
}
/**
 * Checks the mandatory parts of a document save request.
 *
 * @param request the request to validate
 * @throws IllegalArgumentException if the request is {@code null}, has no
 *         store ID, or has a blank document title
 */
void validateSaveRequest(RagDocumentSaveRequest request) {
    if (request == null) {
        throw new IllegalArgumentException("保存请求不能为空");
    }
    final boolean storeMissing = request.getStoreId() == null;
    if (storeMissing) {
        throw new IllegalArgumentException("知识库ID不能为空");
    }
    final boolean titlePresent = StringUtils.hasText(request.getDocumentTitle());
    if (!titlePresent) {
        throw new IllegalArgumentException("文档标题不能为空");
    }
}
/**
 * Checks a batch upload request before any file is stored.
 *
 * <p>The source type is optional, but when given it must equal
 * {@code RagSystemConstants.SOURCE_TYPE_RAG} after trimming.
 *
 * @param request the request to validate
 * @throws IllegalArgumentException if the request is {@code null}, has no
 *         store ID, carries a foreign source type, or contains no files
 */
void validateBatchUploadRequest(RagDocumentBatchUploadRequest request) {
    if (request == null) {
        throw new IllegalArgumentException("上传请求不能为空");
    }
    if (request.getStoreId() == null) {
        throw new IllegalArgumentException("知识库ID不能为空");
    }
    final String sourceType = request.getSourceType();
    if (StringUtils.hasText(sourceType)) {
        final boolean isRagSource = RagSystemConstants.SOURCE_TYPE_RAG.equals(sourceType.trim());
        if (!isRagSource) {
            throw new IllegalArgumentException("sourceType必须为RAG");
        }
    }
    final var files = request.getFiles();
    if (files == null || files.length == 0) {
        throw new IllegalArgumentException("上传文件不能为空");
    }
}
/**
 * Converts document entities to response DTOs, preserving order.
 *
 * @param documents the entities to convert, must not be {@code null}
 * @return the list produced by {@code Stream.toList()}
 */
private List<RagDocumentResponse> toResponses(List<RagDocument> documents) {
    return documents.stream()
            .map(document -> RagDocumentResponse.fromEntity(document))
            .toList();
}
/**
 * Resolves the attachment source type used for RAG uploads.
 *
 * <p>The requested value is deliberately ignored: RAG uploads are always
 * stored under {@code RagSystemConstants.SOURCE_TYPE_RAG}, and
 * {@code validateBatchUploadRequest} only lets a blank value or that same
 * constant through, so there is nothing else to resolve to.
 *
 * @param sourceType the source type supplied by the caller (unused)
 * @return {@code RagSystemConstants.SOURCE_TYPE_RAG}
 */
private String resolveSourceType(String sourceType) {
    return RagSystemConstants.SOURCE_TYPE_RAG;
}
/**
 * Trims a string, mapping {@code null} and blank input to {@code null}.
 *
 * @param value the raw value, may be {@code null}
 * @return the trimmed value, or {@code null} when there is no text
 */
private String trimToNull(String value) {
    return StringUtils.hasText(value) ? value.trim() : null;
}
}

View File

@@ -0,0 +1,182 @@
package com.bruce.rag.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.bruce.common.enums.EnableStatusEnum;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagStoreQueryRequest;
import com.bruce.rag.dto.request.RagStoreSaveRequest;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
import com.bruce.rag.dto.response.RagStoreResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.mapper.RagStoreMapper;
import com.bruce.rag.service.IRagDocumentService;
import com.bruce.rag.service.IRagStoreService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@Slf4j
@Service
public class RagStoreServiceImpl extends ServiceImpl<RagStoreMapper, RagStore> implements IRagStoreService {
// Document service; getOverview and getDocumentOverview read document data through it.
@Autowired
private IRagDocumentService ragDocumentService;
/**
 * Lists every knowledge store as a response DTO.
 *
 * @return the rows returned by {@code list()}, converted to responses
 */
@Override
public List<RagStoreResponse> listResponses() {
    log.info("RagStoreServiceImpl.listResponses start");
    final List<RagStore> allStores = list();
    final List<RagStoreResponse> converted = toResponses(allStores);
    log.info("RagStoreServiceImpl.listResponses success, count={}", converted.size());
    return converted;
}
/**
 * Queries knowledge stores by the optional filters carried in the request.
 *
 * <p>A {@code null} request is treated as an empty request (no filters).
 * Filter values are trimmed before use and ignored when blank:
 * {@code buildEntity} stores trimmed codes, names and statuses, so a
 * filter with surrounding whitespace could otherwise never match an
 * exact-equality condition. Store code and status are matched exactly,
 * store name with {@code LIKE}. Results are ordered by store code, ascending.
 *
 * @param request the filter criteria, may be {@code null}
 * @return the matching stores converted to responses
 */
@Override
public List<RagStoreResponse> query(RagStoreQueryRequest request) {
    log.info("RagStoreServiceImpl.query start, request={}", request);
    RagStoreQueryRequest queryRequest = request == null ? new RagStoreQueryRequest() : request;
    // Normalize once; a null result means "filter not supplied".
    String storeCode = trimToNull(queryRequest.getStoreCode());
    String storeName = trimToNull(queryRequest.getStoreName());
    String status = trimToNull(queryRequest.getStatus());
    List<RagStoreResponse> responses = toResponses(lambdaQuery()
            .eq(storeCode != null, RagStore::getStoreCode, storeCode)
            .like(storeName != null, RagStore::getStoreName, storeName)
            .eq(status != null, RagStore::getStatus, status)
            .orderByAsc(RagStore::getStoreCode)
            .list());
    log.info("RagStoreServiceImpl.query success, count={}", responses.size());
    return responses;
}
/**
 * Loads a single knowledge store and converts it to a response DTO.
 *
 * @param id the store ID
 * @return whatever {@code RagStoreResponse.fromEntity} yields for the loaded row
 *         (the log line treats {@code null} as "not found")
 */
@Override
public RagStoreResponse getResponseById(Long id) {
    log.info("RagStoreServiceImpl.getResponseById start, id={}", id);
    final RagStore entity = getById(id);
    final RagStoreResponse dto = RagStoreResponse.fromEntity(entity);
    final boolean found = dto != null;
    log.info("RagStoreServiceImpl.getResponseById success, id={}, found={}", id, found);
    return dto;
}
/**
 * Builds the global overview: number of stores, number of documents and
 * number of stores whose status equals the {@code ENABLED} label.
 *
 * <p>The chunk total is explicitly set to {@code null} here.
 * NOTE(review): all stores and all documents are loaded just to be
 * counted; consider count queries if the tables grow large.
 *
 * @return the populated overview response
 */
@Override
public RagStoreOverviewResponse getOverview() {
    log.info("RagStoreServiceImpl.getOverview start");
    final List<RagStore> stores = list();
    final List<RagDocumentResponse> documents = ragDocumentService.listResponses();

    // Count the stores currently marked as enabled.
    int retrievableStores = 0;
    for (RagStore store : stores) {
        if (EnableStatusEnum.ENABLED.getLabel().equals(store.getStatus())) {
            retrievableStores++;
        }
    }

    final RagStoreOverviewResponse overview = new RagStoreOverviewResponse();
    overview.setTotalStores(stores.size());
    overview.setTotalDocuments(documents.size());
    overview.setTotalChunks(null);
    overview.setRetrievableStores(retrievableStores);

    log.info("RagStoreServiceImpl.getOverview success, totalStores={}, totalDocuments={}, retrievableStores={}",
            overview.getTotalStores(), overview.getTotalDocuments(), overview.getRetrievableStores());
    return overview;
}
/**
 * Builds the document statistics for one knowledge store.
 *
 * <p>Counts all documents of the store, how many are enabled, how many
 * have parse status {@code PARSED}, how many have index status
 * {@code INDEXED}, and reports the latest non-null creation time as the
 * last upload time ({@code null} when there is none).
 *
 * @param storeId the store ID, must not be {@code null}
 * @return the populated overview response
 * @throws IllegalArgumentException if the ID is {@code null} or the store does not exist
 */
@Override
public RagStoreDocumentOverviewResponse getDocumentOverview(Long storeId) {
    log.info("RagStoreServiceImpl.getDocumentOverview start, storeId={}", storeId);
    if (storeId == null) {
        throw new IllegalArgumentException("知识库ID不能为空");
    }
    final RagStore store = getById(storeId);
    if (store == null) {
        throw new IllegalArgumentException("知识库不存在ID: " + storeId);
    }

    final RagDocumentQueryRequest storeFilter = new RagDocumentQueryRequest();
    storeFilter.setStoreId(storeId);
    final List<RagDocumentResponse> documents = ragDocumentService.query(storeFilter);

    // Tally the three status counters in a single pass.
    int enabledCount = 0;
    int parsedCount = 0;
    int indexedCount = 0;
    for (RagDocumentResponse document : documents) {
        if (Boolean.TRUE.equals(document.getEnabled())) {
            enabledCount++;
        }
        if (RagParseStatusEnum.PARSED.name().equals(document.getParseStatus())) {
            parsedCount++;
        }
        if (RagIndexStatusEnum.INDEXED.name().equals(document.getIndexStatus())) {
            indexedCount++;
        }
    }

    final RagStoreDocumentOverviewResponse overview = new RagStoreDocumentOverviewResponse();
    overview.setStoreId(storeId);
    overview.setStoreName(store.getStoreName());
    overview.setDocumentCount(documents.size());
    overview.setEnabledDocumentCount(enabledCount);
    overview.setParsedDocumentCount(parsedCount);
    overview.setIndexedDocumentCount(indexedCount);
    overview.setLastUploadTime(documents.stream()
            .map(RagDocumentResponse::getCreateTime)
            .filter(Objects::nonNull)
            .max(Comparator.naturalOrder())
            .orElse(null));

    log.info("RagStoreServiceImpl.getDocumentOverview success, storeId={}, documentCount={}",
            storeId, overview.getDocumentCount());
    return overview;
}
/**
 * Creates or updates a knowledge store after validating the request and
 * rejecting a store code that is already used by another store.
 *
 * <p>When the request carries an ID, the store with that ID is excluded
 * from the duplicate check so that it can keep its own code.
 *
 * @param request the save request; validated by {@code validateSaveRequest}
 * @return the result of the underlying {@code saveOrUpdate}
 * @throws IllegalArgumentException if validation fails or the code is already taken
 */
@Override
public boolean saveOrUpdate(RagStoreSaveRequest request) {
    log.info("RagStoreServiceImpl.saveOrUpdate start, request={}", request);
    validateSaveRequest(request);

    final String normalizedCode = request.getStoreCode().trim();
    final RagStore duplicate = lambdaQuery()
            .eq(RagStore::getStoreCode, normalizedCode)
            .ne(request.getId() != null, RagStore::getId, request.getId())
            .one();
    if (duplicate != null) {
        log.warn("RagStoreServiceImpl.saveOrUpdate duplicate storeCode detected, requestId={}, existingId={}, storeCode={}",
                request.getId(), duplicate.getId(), normalizedCode);
        throw new IllegalArgumentException("知识库编码已存在: " + normalizedCode);
    }

    final RagStore entity = buildEntity(request);
    final boolean saved = super.saveOrUpdate(entity);
    log.info("RagStoreServiceImpl.saveOrUpdate success, requestId={}, savedId={}, storeCode={}, result={}",
            request.getId(), entity.getId(), entity.getStoreCode(), saved);
    return saved;
}
/**
 * Checks the mandatory parts of a store save request.
 *
 * @param request the request to validate
 * @throws IllegalArgumentException if the request is {@code null} or its
 *         store code or store name is blank
 */
public void validateSaveRequest(RagStoreSaveRequest request) {
    log.info("RagStoreServiceImpl.validateSaveRequest start");
    if (request == null) {
        throw new IllegalArgumentException("保存请求不能为空");
    }
    final boolean codePresent = StringUtils.hasText(request.getStoreCode());
    if (!codePresent) {
        throw new IllegalArgumentException("知识库编码不能为空");
    }
    final boolean namePresent = StringUtils.hasText(request.getStoreName());
    if (!namePresent) {
        throw new IllegalArgumentException("知识库名称不能为空");
    }
    log.info("RagStoreServiceImpl.validateSaveRequest success, id={}, storeCode={}, storeName={}",
            request.getId(), request.getStoreCode(), request.getStoreName());
}
/**
 * Maps a validated save request onto a fresh {@code RagStore} entity.
 *
 * <p>Code and name are trimmed; description and remark are trimmed and
 * become {@code null} when blank; a blank status falls back to the
 * {@code ENABLED} label.
 *
 * @param request the save request; store code and name must be non-null
 * @return the populated entity (ID copied from the request, may be {@code null})
 */
public RagStore buildEntity(RagStoreSaveRequest request) {
    final RagStore entity = new RagStore();
    entity.setId(request.getId());
    entity.setStoreCode(request.getStoreCode().trim());
    entity.setStoreName(request.getStoreName().trim());
    entity.setDescription(trimToNull(request.getDescription()));

    final String requestedStatus = request.getStatus();
    if (StringUtils.hasText(requestedStatus)) {
        entity.setStatus(requestedStatus.trim());
    } else {
        entity.setStatus(EnableStatusEnum.ENABLED.getLabel());
    }

    entity.setRemark(trimToNull(request.getRemark()));
    return entity;
}
/**
 * Converts store entities to response DTOs, preserving order.
 *
 * @param stores the entities to convert, must not be {@code null}
 * @return the list produced by {@code Stream.toList()}
 */
private List<RagStoreResponse> toResponses(List<RagStore> stores) {
    return stores.stream()
            .map(store -> RagStoreResponse.fromEntity(store))
            .toList();
}
/**
 * Trims a string, mapping {@code null} and blank input to {@code null}.
 *
 * @param value the raw value, may be {@code null}
 * @return the trimmed value, or {@code null} when there is no text
 */
private String trimToNull(String value) {
    return StringUtils.hasText(value) ? value.trim() : null;
}
}

View File

@@ -0,0 +1,217 @@
package com.bruce.rag;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.service.IService;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.bruce.common.typehandler.PgJsonbStringTypeHandler;
import com.bruce.common.domain.model.RequestResult;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.controller.RagDocumentController;
import com.bruce.rag.controller.RagStoreController;
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
import com.bruce.rag.dto.request.RagDocumentParseRequest;
import com.bruce.rag.dto.request.RagStoreQueryRequest;
import com.bruce.rag.dto.request.RagStoreSaveRequest;
import com.bruce.rag.dto.response.RagDocumentParseResponse;
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.dto.response.RagStoreResponse;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagChunkEmbedding;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.entity.RagDocumentParseResult;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
import com.bruce.rag.mapper.RagChunkMapper;
import com.bruce.rag.mapper.RagDocumentMapper;
import com.bruce.rag.mapper.RagDocumentParseResultMapper;
import com.bruce.rag.mapper.RagStoreMapper;
import com.bruce.rag.service.IRagChunkEmbeddingService;
import com.bruce.rag.service.IRagChunkService;
import com.bruce.rag.service.IRagDocumentParseService;
import com.bruce.rag.service.IRagDocumentParseResultService;
import com.bruce.rag.service.IRagDocumentService;
import com.bruce.rag.service.IRagStoreService;
import com.bruce.rag.service.impl.RagChunkEmbeddingServiceImpl;
import com.bruce.rag.service.impl.RagChunkServiceImpl;
import com.bruce.rag.service.impl.RagDocumentParseResultServiceImpl;
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
import com.bruce.rag.service.impl.RagStoreServiceImpl;
import org.junit.jupiter.api.Test;
import org.springframework.web.bind.annotation.PostMapping;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
class RagComponentStructureTests {
/**
 * Every RAG mapper, service interface and service implementation must
 * build on the corresponding MyBatis-Plus base type.
 */
@Test
void ragComponentsShouldReuseMybatisPlusBaseTypes() {
    final List<Class<?>> mapperTypes = List.of(
            RagStoreMapper.class,
            RagDocumentMapper.class,
            RagDocumentParseResultMapper.class,
            RagChunkMapper.class,
            RagChunkEmbeddingMapper.class);
    for (Class<?> mapperType : mapperTypes) {
        assertTrue(BaseMapper.class.isAssignableFrom(mapperType));
    }

    final List<Class<?>> serviceInterfaces = List.of(
            IRagStoreService.class,
            IRagDocumentService.class,
            IRagDocumentParseResultService.class,
            IRagChunkService.class,
            IRagChunkEmbeddingService.class);
    for (Class<?> serviceInterface : serviceInterfaces) {
        assertTrue(IService.class.isAssignableFrom(serviceInterface));
    }

    final List<Class<?>> serviceImplementations = List.of(
            RagStoreServiceImpl.class,
            RagDocumentServiceImpl.class,
            RagDocumentParseResultServiceImpl.class,
            RagChunkServiceImpl.class,
            RagChunkEmbeddingServiceImpl.class);
    for (Class<?> serviceImplementation : serviceImplementations) {
        assertTrue(ServiceImpl.class.isAssignableFrom(serviceImplementation));
    }
}
/**
 * Verifies by reflection that the store and document controllers and
 * services expose the expected public methods, and that their raw and
 * generic return types are the expected wrapper / DTO types.
 * A missing method surfaces as NoSuchMethodException from getMethod.
 */
@Test
void ragControllersShouldExposeRequestResultAndQueryDtoMethods() throws NoSuchMethodException {
// Store controller endpoints.
Method storeListMethod = RagStoreController.class.getMethod("list");
Method storeQueryMethod = RagStoreController.class.getMethod("query", RagStoreQueryRequest.class);
Method storeDetailMethod = RagStoreController.class.getMethod("getById", Long.class);
Method storeOverviewMethod = RagStoreController.class.getMethod("overview");
Method storeDocumentOverviewMethod = RagStoreController.class.getMethod("documentOverview", Long.class);
Method storeSaveMethod = RagStoreController.class.getMethod("saveOrUpdate", RagStoreSaveRequest.class);
Method storeDeleteMethod = RagStoreController.class.getMethod("deleteById", Long.class);
// Store service methods.
Method storeResponseListMethod = IRagStoreService.class.getMethod("listResponses");
Method storeServiceQueryMethod = IRagStoreService.class.getMethod("query", RagStoreQueryRequest.class);
Method storeServiceDetailMethod = IRagStoreService.class.getMethod("getResponseById", Long.class);
Method storeServiceOverviewMethod = IRagStoreService.class.getMethod("getOverview");
Method storeServiceDocumentOverviewMethod = IRagStoreService.class.getMethod("getDocumentOverview", Long.class);
Method storeServiceSaveMethod = IRagStoreService.class.getMethod("saveOrUpdate", RagStoreSaveRequest.class);
// Document controller endpoints and document service methods.
Method documentListMethod = RagDocumentController.class.getMethod("list");
Method documentQueryMethod = RagDocumentController.class.getMethod("query", RagDocumentQueryRequest.class);
Method documentParseMethod = RagDocumentController.class.getMethod("parse", RagDocumentParseRequest.class);
Method documentResponseListMethod = IRagDocumentService.class.getMethod("listResponses");
Method documentServiceQueryMethod = IRagDocumentService.class.getMethod("query", RagDocumentQueryRequest.class);
Method documentParseServiceMethod = IRagDocumentParseService.class.getMethod("parse", RagDocumentParseRequest.class);
// Store controller: every endpoint returns the RequestResult wrapper.
assertEquals(RequestResult.class, storeListMethod.getReturnType());
assertEquals(RequestResult.class, storeQueryMethod.getReturnType());
assertEquals(RequestResult.class, storeDetailMethod.getReturnType());
assertEquals(RequestResult.class, storeOverviewMethod.getReturnType());
assertEquals(RequestResult.class, storeDocumentOverviewMethod.getReturnType());
assertEquals(RequestResult.class, storeSaveMethod.getReturnType());
assertEquals(RequestResult.class, storeDeleteMethod.getReturnType());
// Store service: raw return types.
assertEquals(List.class, storeServiceQueryMethod.getReturnType());
assertEquals(RagStoreResponse.class, storeServiceDetailMethod.getReturnType());
assertEquals(RagStoreOverviewResponse.class, storeServiceOverviewMethod.getReturnType());
assertEquals(RagStoreDocumentOverviewResponse.class, storeServiceDocumentOverviewMethod.getReturnType());
assertEquals(boolean.class, storeServiceSaveMethod.getReturnType());
// Store: the generic return type names must mention the response DTOs.
assertTrue(storeResponseListMethod.getGenericReturnType().getTypeName().contains("RagStoreResponse"));
assertTrue(storeServiceQueryMethod.getGenericReturnType().getTypeName().contains("RagStoreResponse"));
assertTrue(storeListMethod.getGenericReturnType().getTypeName().contains("RagStoreResponse"));
assertTrue(storeQueryMethod.getGenericReturnType().getTypeName().contains("RagStoreResponse"));
assertTrue(storeDetailMethod.getGenericReturnType().getTypeName().contains("RagStoreResponse"));
assertTrue(storeOverviewMethod.getGenericReturnType().getTypeName().contains("RagStoreOverviewResponse"));
assertTrue(storeDocumentOverviewMethod.getGenericReturnType().getTypeName().contains("RagStoreDocumentOverviewResponse"));
assertEquals(RagStoreResponse.class, RagStoreResponse.class.getMethod("fromEntity", RagStore.class).getReturnType());
// Document controller and services: raw return types.
assertEquals(RequestResult.class, documentListMethod.getReturnType());
assertEquals(RequestResult.class, documentQueryMethod.getReturnType());
assertEquals(RequestResult.class, documentParseMethod.getReturnType());
assertEquals(List.class, documentServiceQueryMethod.getReturnType());
assertEquals(List.class, documentParseServiceMethod.getReturnType());
// Document: the generic return type names must mention the response DTOs.
assertTrue(documentResponseListMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
assertTrue(documentServiceQueryMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
assertTrue(documentListMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
assertTrue(documentQueryMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
assertTrue(documentParseMethod.getGenericReturnType().getTypeName().contains("RagDocumentParseResponse"));
assertEquals(RagDocumentResponse.class, RagDocumentResponse.class.getMethod("fromEntity", RagDocument.class).getReturnType());
}
/**
 * The document list endpoint must be a POST mapping whose first path is
 * {@code /list}.
 */
@Test
void ragDocumentListUrlShouldUseExplicitListAction() throws NoSuchMethodException {
    final PostMapping mapping = RagDocumentController.class
            .getMethod("list")
            .getAnnotation(PostMapping.class);
    assertNotNull(mapping);
    final String[] mappedPaths = mapping.value();
    assertEquals("/list", mappedPaths[0]);
}
/**
 * Pins the RAG system constants, the {@code Long} relation fields on
 * {@code RagDocument}, and the simple names of the store entity and the
 * two controllers.
 */
@Test
void ragSourceTypesAndDocumentRelationShouldExist() throws NoSuchFieldException {
    final Class<?> storeIdType = RagDocument.class.getDeclaredField("storeId").getType();
    final Class<?> attachmentIdType = RagDocument.class.getDeclaredField("attachmentId").getType();

    // Constant values.
    assertEquals("RAG_STORE", RagSystemConstants.RAG_STORE);
    assertEquals("RAG_DOCUMENT", RagSystemConstants.RAG_DOCUMENT);
    assertEquals("RAG", RagSystemConstants.SOURCE_TYPE_RAG);

    // Relation fields.
    assertEquals(Long.class, storeIdType);
    assertEquals(Long.class, attachmentIdType);

    // Type names.
    final String storeName = RagStore.class.getSimpleName();
    final String storeControllerName = RagStoreController.class.getSimpleName();
    final String documentControllerName = RagDocumentController.class.getSimpleName();
    assertTrue(storeName.contains("RagStore"));
    assertTrue(storeControllerName.contains("RagStoreController"));
    assertTrue(documentControllerName.contains("RagDocumentController"));
}
/**
 * {@code RagChunk} must declare each expected field with the expected type.
 */
@Test
void ragChunkStructureShouldSupportChunkMetadata() throws NoSuchFieldException {
    // {field name, expected declared type}, checked in this order.
    final Object[][] expectedFields = {
            {"storeId", Long.class},
            {"documentId", Long.class},
            {"chunkIndex", Integer.class},
            {"chunkContent", String.class},
            {"chunkSummary", String.class},
            {"tokenCount", Integer.class},
            {"pageNumber", Integer.class},
            {"sectionTitle", String.class},
            {"headingPath", String.class},
            {"vectorId", String.class},
            {"metadataJson", String.class},
            {"enabled", Boolean.class},
            {"remark", String.class},
    };
    for (Object[] expected : expectedFields) {
        final String fieldName = (String) expected[0];
        assertEquals(expected[1], RagChunk.class.getDeclaredField(fieldName).getType());
    }
}
/**
 * {@code RagChunkEmbedding} must declare each expected field with the
 * expected type.
 */
@Test
void ragChunkEmbeddingStructureShouldSupportPgvectorMetadata() throws NoSuchFieldException {
    // {field name, expected declared type}, checked in this order.
    final Object[][] expectedFields = {
            {"storeId", Long.class},
            {"documentId", Long.class},
            {"chunkId", Long.class},
            {"embeddingModel", String.class},
            {"embeddingDimension", Integer.class},
            {"embedding", String.class},
            {"contentHash", String.class},
            {"enabled", Boolean.class},
            {"remark", String.class},
    };
    for (Object[] expected : expectedFields) {
        final String fieldName = (String) expected[0];
        assertEquals(expected[1], RagChunkEmbedding.class.getDeclaredField(fieldName).getType());
    }
}
/**
 * {@code RagDocumentParseResult} must declare each expected field with
 * the expected type.
 */
@Test
void ragParseResultStructureShouldSupportSnapshotMetadata() throws NoSuchFieldException {
    // {field name, expected declared type}, checked in this order.
    final Object[][] expectedFields = {
            {"storeId", Long.class},
            {"documentId", Long.class},
            {"parsedText", String.class},
            {"textLength", Integer.class},
            {"pageCount", Integer.class},
            {"sheetCount", Integer.class},
            {"metadataJson", String.class},
            {"contentHash", String.class},
            {"parseVersion", Integer.class},
            {"enabled", Boolean.class},
    };
    for (Object[] expected : expectedFields) {
        final String fieldName = (String) expected[0];
        assertEquals(expected[1], RagDocumentParseResult.class.getDeclaredField(fieldName).getType());
    }
}
/**
 * The {@code metadataJson} fields of {@code RagChunk} and
 * {@code RagDocumentParseResult} must be mapped with
 * {@code PgJsonbStringTypeHandler}, and both entities must enable
 * {@code autoResultMap} on their {@code @TableName}.
 */
@Test
void ragMetadataJsonFieldsShouldUseJsonbTypeHandler() throws NoSuchFieldException {
    final TableName chunkTableName = RagChunk.class.getAnnotation(TableName.class);
    final TableName parseResultTableName = RagDocumentParseResult.class.getAnnotation(TableName.class);
    final TableField chunkMetadata = RagChunk.class
            .getDeclaredField("metadataJson")
            .getAnnotation(TableField.class);
    final TableField parseResultMetadata = RagDocumentParseResult.class
            .getDeclaredField("metadataJson")
            .getAnnotation(TableField.class);

    // Table-level mapping.
    assertNotNull(chunkTableName);
    assertNotNull(parseResultTableName);
    assertTrue(chunkTableName.autoResultMap());
    assertTrue(parseResultTableName.autoResultMap());

    // Field-level type handlers.
    assertNotNull(chunkMetadata);
    assertNotNull(parseResultMetadata);
    assertEquals(PgJsonbStringTypeHandler.class, chunkMetadata.typeHandler());
    assertEquals(PgJsonbStringTypeHandler.class, parseResultMetadata.typeHandler());
}
}

View File

@@ -0,0 +1,180 @@
package com.bruce.rag;
import com.bruce.common.config.AttachmentProperties;
import com.bruce.common.document.parse.DocumentParseContext;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.common.document.parse.DocumentParser;
import com.bruce.common.document.parse.DocumentParserFactory;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.dto.request.RagDocumentParseRequest;
import com.bruce.rag.dto.response.RagDocumentParseResponse;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.mapper.RagDocumentMapper;
import com.bruce.rag.service.IRagDocumentParseResultService;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.service.impl.RagDocumentParseServiceImpl;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.mock;
@ExtendWith(MockitoExtension.class)
class RagDocumentParseServiceImplTests {
// Per-test temporary directory; the tests use it as the attachment base path.
@TempDir
private Path tempDir;
// Mocked mapper handed to the service under test; stubbed for selectById/updateById.
@Mock
private RagDocumentMapper ragDocumentMapper;
// Mocked attachment service; stubbed to return the attachment for a document.
@Mock
private ISysAttachmentService sysAttachmentService;
/**
 * Parsing a single uploaded document must return a PARSED response that
 * carries the parser's text length and metadata, and must update the
 * document twice: first to PARSING, then to PARSED.
 */
@Test
void parseShouldUpdateStatusAndReturnParseResponse() throws Exception {
    // Arrange: a real file below the temporary attachment base path.
    final Path sourceFile = tempDir.resolve("rag").resolve("people.txt");
    Files.createDirectories(sourceFile.getParent());
    Files.writeString(sourceFile, "people profiles");

    final RagDocument storedDocument = new RagDocument();
    storedDocument.setId(1001L);
    storedDocument.setStoreId(2002L);
    storedDocument.setAttachmentId(3003L);
    storedDocument.setParseStatus(RagParseStatusEnum.UPLOADED.name());

    final SysAttachment storedAttachment = new SysAttachment();
    storedAttachment.setId(3003L);
    storedAttachment.setOriginalName("people.txt");
    storedAttachment.setFileSuffix("txt");
    storedAttachment.setContentType("text/plain");
    storedAttachment.setFilePath("rag/people.txt");

    final AttachmentProperties properties = new AttachmentProperties();
    properties.setBasePath(tempDir.toString());

    final DocumentParser parser = new FixedDocumentParser("people profiles");
    final RagDocumentParseServiceImpl service = new RagDocumentParseServiceImpl(
            ragDocumentMapper,
            sysAttachmentService,
            properties,
            new DocumentParserFactory(List.of(parser)),
            mock(IRagDocumentParseResultService.class)
    );

    when(ragDocumentMapper.selectById(1001L)).thenReturn(storedDocument);
    when(sysAttachmentService.getById(3003L)).thenReturn(storedAttachment);
    when(ragDocumentMapper.updateById(any(RagDocument.class))).thenReturn(1);

    // Act.
    final RagDocumentParseResponse parsed = service.parse(1001L);

    // Assert: response content.
    assertEquals(1001L, parsed.getDocumentId());
    assertEquals(RagParseStatusEnum.PARSED.name(), parsed.getParseStatus());
    assertEquals(15, parsed.getTextLength());
    assertEquals("fixed", parsed.getMetadata().get("parser"));

    // Assert: status transitions written through the mapper.
    final ArgumentCaptor<RagDocument> updateCaptor = ArgumentCaptor.forClass(RagDocument.class);
    verify(ragDocumentMapper, times(2)).updateById(updateCaptor.capture());
    final List<RagDocument> writtenDocuments = updateCaptor.getAllValues();
    assertEquals(RagParseStatusEnum.PARSING.name(), writtenDocuments.get(0).getParseStatus());
    assertEquals(RagParseStatusEnum.PARSED.name(), writtenDocuments.get(1).getParseStatus());
    assertTrue(parser.supports(new DocumentParseContext()));
}
/**
 * Parsing through a batch request with one document ID must return
 * exactly one PARSED response for that document.
 */
@Test
void parseShouldSupportBatchRequest() throws Exception {
    // Arrange: a real file below the temporary attachment base path.
    final Path sourceFile = tempDir.resolve("rag").resolve("batch.txt");
    Files.createDirectories(sourceFile.getParent());
    Files.writeString(sourceFile, "batch profiles");

    final RagDocument storedDocument = new RagDocument();
    storedDocument.setId(1002L);
    storedDocument.setStoreId(2002L);
    storedDocument.setAttachmentId(3004L);
    storedDocument.setParseStatus(RagParseStatusEnum.UPLOADED.name());

    final SysAttachment storedAttachment = new SysAttachment();
    storedAttachment.setId(3004L);
    storedAttachment.setOriginalName("batch.txt");
    storedAttachment.setFileSuffix("txt");
    storedAttachment.setContentType("text/plain");
    storedAttachment.setFilePath("rag/batch.txt");

    final AttachmentProperties properties = new AttachmentProperties();
    properties.setBasePath(tempDir.toString());

    final RagDocumentParseServiceImpl service = new RagDocumentParseServiceImpl(
            ragDocumentMapper,
            sysAttachmentService,
            properties,
            new DocumentParserFactory(List.of(new FixedDocumentParser("batch profiles"))),
            mock(IRagDocumentParseResultService.class)
    );

    final RagDocumentParseRequest batchRequest = new RagDocumentParseRequest();
    batchRequest.setDocumentIds(List.of(1002L));

    when(ragDocumentMapper.selectById(1002L)).thenReturn(storedDocument);
    when(sysAttachmentService.getById(3004L)).thenReturn(storedAttachment);
    when(ragDocumentMapper.updateById(any(RagDocument.class))).thenReturn(1);

    // Act.
    final List<RagDocumentParseResponse> parsedDocuments = service.parse(batchRequest);

    // Assert.
    assertEquals(1, parsedDocuments.size());
    final RagDocumentParseResponse onlyResponse = parsedDocuments.getFirst();
    assertEquals(1002L, onlyResponse.getDocumentId());
    assertEquals(RagParseStatusEnum.PARSED.name(), onlyResponse.getParseStatus());
}
/**
 * A batch request with an empty ID list must be rejected with
 * {@code IllegalArgumentException}.
 */
@Test
void parseShouldRejectEmptyDocumentIds() {
    final AttachmentProperties properties = new AttachmentProperties();
    properties.setBasePath(tempDir.toString());

    final DocumentParserFactory parserFactory =
            new DocumentParserFactory(List.of(new FixedDocumentParser("batch profiles")));
    final RagDocumentParseServiceImpl service = new RagDocumentParseServiceImpl(
            ragDocumentMapper,
            sysAttachmentService,
            properties,
            parserFactory,
            mock(IRagDocumentParseResultService.class)
    );

    final RagDocumentParseRequest emptyRequest = new RagDocumentParseRequest();
    emptyRequest.setDocumentIds(List.of());

    assertThrows(IllegalArgumentException.class, () -> service.parse(emptyRequest));
}
/**
 * Test double for {@link DocumentParser}: claims support for every context and always
 * produces a result built from one preset text.
 */
private static class FixedDocumentParser implements DocumentParser {

    /** Text handed back as the parse result on every call. */
    private final String presetText;

    private FixedDocumentParser(String presetText) {
        this.presetText = presetText;
    }

    /** Reports support unconditionally; the context is not inspected. */
    @Override
    public boolean supports(DocumentParseContext context) {
        return true;
    }

    /**
     * Returns a fresh result holding the preset text, its character length and a single
     * {@code parser=fixed} metadata entry. The context is ignored.
     */
    @Override
    public DocumentParseResult parse(DocumentParseContext context) {
        final int presetLength = presetText.length();
        DocumentParseResult parsed = new DocumentParseResult();
        parsed.setText(presetText);
        parsed.setTextLength(presetLength);
        parsed.setMetadata(Map.of("parser", "fixed"));
        return parsed;
    }
}
}

View File

@@ -0,0 +1,164 @@
package com.bruce.rag;
import com.bruce.common.domain.entity.SysAttachment;
import com.bruce.common.dto.request.SysAttachmentUploadRequest;
import com.bruce.common.service.ISysAttachmentService;
import com.bruce.rag.constant.RagSystemConstants;
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.service.IRagDocumentAutoParseService;
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockMultipartFile;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link RagDocumentServiceImpl}: batch upload wiring and the field mapping
 * performed by {@code saveOrUpdate}.
 *
 * <p>The service is a Mockito spy so that its own {@code save}, {@code getById} and
 * {@code updateById} calls can be stubbed and captured by the tests.
 */
@ExtendWith(MockitoExtension.class)
class RagDocumentServiceImplTests {

    /** Spied service under test; the mocks declared below are offered to it for injection. */
    @Spy
    @InjectMocks
    private RagDocumentServiceImpl ragDocumentService;

    @Mock
    private ISysAttachmentService sysAttachmentService;

    /** Never referenced directly by the tests in this class; presumably a collaborator of the service. */
    @Mock
    private IRagDocumentAutoParseService ragDocumentAutoParseService;

    /**
     * Batch upload must forward the file to the attachment service with the RAG source type and
     * the store id as source id, then save a document in UPLOADED / PENDING state.
     */
    @Test
    void batchUploadShouldUseRagSourceTypeAndStoreIdAsSourceId() {
        // Arrange: the attachment service hands back a stored attachment, saving succeeds.
        SysAttachment storedAttachment = new SysAttachment();
        storedAttachment.setId(2002L);
        when(sysAttachmentService.upload(any(SysAttachmentUploadRequest.class))).thenReturn(storedAttachment);
        doAnswer(invocation -> true).when(ragDocumentService).save(any(RagDocument.class));

        MockMultipartFile uploadedFile = new MockMultipartFile(
                "files",
                "knowledge.txt",
                "text/plain",
                "hello rag".getBytes()
        );
        RagDocumentBatchUploadRequest batchRequest = new RagDocumentBatchUploadRequest();
        batchRequest.setStoreId(1001L);
        batchRequest.setSourceType(RagSystemConstants.SOURCE_TYPE_RAG);
        batchRequest.setFiles(new MockMultipartFile[]{uploadedFile});
        batchRequest.setDocumentSummary("批量摘要");
        batchRequest.setRemark("批量备注");

        // Act
        var responses = ragDocumentService.batchUpload(batchRequest);

        // Assert: what was sent to the attachment service.
        ArgumentCaptor<SysAttachmentUploadRequest> uploadCaptor =
                ArgumentCaptor.forClass(SysAttachmentUploadRequest.class);
        verify(sysAttachmentService).upload(uploadCaptor.capture());
        SysAttachmentUploadRequest forwardedUpload = uploadCaptor.getValue();
        assertEquals(RagSystemConstants.SOURCE_TYPE_RAG, forwardedUpload.getSourceType());
        assertEquals(1001L, forwardedUpload.getSourceId());
        assertEquals(uploadedFile, forwardedUpload.getFile());

        // Assert: the document entity handed to save(...).
        ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
        verify(ragDocumentService).save(documentCaptor.capture());
        RagDocument persisted = documentCaptor.getValue();
        assertEquals(1001L, persisted.getStoreId());
        assertEquals(2002L, persisted.getAttachmentId());
        assertEquals("knowledge.txt", persisted.getDocumentTitle());
        assertEquals("批量摘要", persisted.getDocumentSummary());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), persisted.getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), persisted.getIndexStatus());
        assertTrue(persisted.getEnabled());
        assertNull(persisted.getErrorMessage());
        assertEquals("批量备注", persisted.getRemark());

        // Assert: the returned response list.
        assertEquals(1, responses.size());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), responses.getFirst().getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), responses.getFirst().getIndexStatus());
    }

    /**
     * When every editable field is supplied, the updated entity must carry all of them, with
     * surrounding whitespace removed from the text fields.
     */
    @Test
    void saveOrUpdateShouldWriteAllEditableFields() {
        RagDocument storedDocument = new RagDocument();
        storedDocument.setId(3003L);
        stubExistingDocument(storedDocument);

        RagDocumentSaveRequest fullRequest = new RagDocumentSaveRequest();
        fullRequest.setId(3003L);
        fullRequest.setStoreId(1001L);
        fullRequest.setAttachmentId(2002L);
        fullRequest.setDocumentTitle(" 新标题 ");
        fullRequest.setDocumentSummary(" 新摘要 ");
        fullRequest.setParseStatus(RagParseStatusEnum.PARSED.name());
        fullRequest.setIndexStatus(RagIndexStatusEnum.INDEXED.name());
        fullRequest.setEnabled(false);
        fullRequest.setErrorMessage(" 已修复 ");
        fullRequest.setRemark(" 备注信息 ");

        boolean updated = ragDocumentService.saveOrUpdate(fullRequest);

        assertTrue(updated);
        RagDocument written = captureUpdatedDocument();
        assertEquals(3003L, written.getId());
        assertEquals(1001L, written.getStoreId());
        assertEquals(2002L, written.getAttachmentId());
        assertEquals("新标题", written.getDocumentTitle());
        assertEquals("新摘要", written.getDocumentSummary());
        assertEquals(RagParseStatusEnum.PARSED.name(), written.getParseStatus());
        assertEquals(RagIndexStatusEnum.INDEXED.name(), written.getIndexStatus());
        assertEquals(false, written.getEnabled());
        assertEquals("已修复", written.getErrorMessage());
        assertEquals("备注信息", written.getRemark());
    }

    /**
     * A request that omits some fields must leave the stored values of those fields untouched
     * while still applying the fields it does carry.
     */
    @Test
    void saveOrUpdateShouldPreserveExistingFieldsForPartialUpdate() {
        RagDocument storedDocument = new RagDocument();
        storedDocument.setId(3003L);
        storedDocument.setStoreId(1001L);
        storedDocument.setAttachmentId(2002L);
        storedDocument.setDocumentTitle("people_profiles.txt");
        storedDocument.setDocumentSummary("测试人员信息,有多条人员信息");
        storedDocument.setParseStatus(RagParseStatusEnum.UPLOADED.name());
        storedDocument.setIndexStatus(RagIndexStatusEnum.PENDING.name());
        storedDocument.setEnabled(true);
        storedDocument.setRemark("测试人员信息");
        stubExistingDocument(storedDocument);

        // Only id, store, title and the enabled flag are supplied.
        RagDocumentSaveRequest partialRequest = new RagDocumentSaveRequest();
        partialRequest.setId(3003L);
        partialRequest.setStoreId(1001L);
        partialRequest.setDocumentTitle("people_profiles.txt");
        partialRequest.setEnabled(false);

        boolean updated = ragDocumentService.saveOrUpdate(partialRequest);

        assertTrue(updated);
        RagDocument written = captureUpdatedDocument();
        assertEquals(2002L, written.getAttachmentId());
        assertEquals("测试人员信息,有多条人员信息", written.getDocumentSummary());
        assertEquals(RagParseStatusEnum.UPLOADED.name(), written.getParseStatus());
        assertEquals(RagIndexStatusEnum.PENDING.name(), written.getIndexStatus());
        assertEquals(false, written.getEnabled());
        assertEquals("测试人员信息", written.getRemark());
    }

    /** Stubs the spy so that id 3003 resolves to the given document and any update reports success. */
    private void stubExistingDocument(RagDocument storedDocument) {
        doReturn(storedDocument).when(ragDocumentService).getById(3003L);
        doReturn(true).when(ragDocumentService).updateById(any(RagDocument.class));
    }

    /** Verifies exactly one {@code updateById} call on the spy and returns the entity passed to it. */
    private RagDocument captureUpdatedDocument() {
        ArgumentCaptor<RagDocument> updateCaptor = ArgumentCaptor.forClass(RagDocument.class);
        verify(ragDocumentService).updateById(updateCaptor.capture());
        return updateCaptor.getValue();
    }
}

View File

@@ -0,0 +1,126 @@
package com.bruce.rag;
import com.bruce.common.enums.EnableStatusEnum;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.enums.RagIndexStatusEnum;
import com.bruce.rag.enums.RagParseStatusEnum;
import com.bruce.rag.service.IRagDocumentService;
import com.bruce.rag.service.impl.RagStoreServiceImpl;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Date;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
/**
 * Unit tests for the overview queries of {@link RagStoreServiceImpl}.
 *
 * <p>The service is a Mockito spy so its own {@code list} and {@code getById} calls can be
 * stubbed; document data comes from a mocked {@link IRagDocumentService}.
 */
@ExtendWith(MockitoExtension.class)
class RagStoreOverviewServiceTests {

    /** Spied service under test. */
    @Spy
    @InjectMocks
    private RagStoreServiceImpl ragStoreService;

    @Mock
    private IRagDocumentService ragDocumentService;

    /** The global overview must count stores and documents and report a null chunk total. */
    @Test
    void getOverviewShouldAggregateStoreAndDocumentCounts() {
        RagStore activeStore = new RagStore();
        activeStore.setId(1L);
        activeStore.setStatus(EnableStatusEnum.ENABLED.getLabel());
        RagStore inactiveStore = new RagStore();
        inactiveStore.setId(2L);
        inactiveStore.setStatus("停用");
        doReturn(List.of(activeStore, inactiveStore)).when(ragStoreService).list();
        when(ragDocumentService.listResponses()).thenReturn(List.of(
                createDocumentResponse("11", "1", true, RagParseStatusEnum.UPLOADED.name(), RagIndexStatusEnum.PENDING.name(), new Date()),
                createDocumentResponse("22", "2", false, RagParseStatusEnum.PARSED.name(), RagIndexStatusEnum.INDEXED.name(), new Date())
        ));

        RagStoreOverviewResponse overview = ragStoreService.getOverview();

        assertEquals(2, overview.getTotalStores());
        assertEquals(2, overview.getTotalDocuments());
        assertNull(overview.getTotalChunks());
        assertEquals(1, overview.getRetrievableStores());
    }

    /**
     * The per-store overview must count total / enabled / parsed / indexed documents and report
     * the most recent creation time as the last upload time.
     */
    @Test
    void getDocumentOverviewShouldAggregateCurrentStoreDocumentMetrics() {
        stubProductPolicyStore();
        // Three creation times one hour apart; the middle document is the newest.
        Date newestUpload = new Date(1747820096000L);
        when(ragDocumentService.query(org.mockito.ArgumentMatchers.any())).thenReturn(List.of(
                createDocumentResponse("11", "1", true, RagParseStatusEnum.UPLOADED.name(), RagIndexStatusEnum.PENDING.name(), new Date(1747816496000L)),
                createDocumentResponse("12", "1", true, RagParseStatusEnum.PARSED.name(), RagIndexStatusEnum.INDEXED.name(), newestUpload),
                createDocumentResponse("13", "1", false, RagParseStatusEnum.FAILED.name(), RagIndexStatusEnum.FAILED.name(), new Date(1747812896000L))
        ));

        RagStoreDocumentOverviewResponse overview = ragStoreService.getDocumentOverview(1L);

        assertEquals(1L, overview.getStoreId());
        assertEquals("产品制度库", overview.getStoreName());
        assertEquals(3, overview.getDocumentCount());
        assertEquals(2, overview.getEnabledDocumentCount());
        assertEquals(1, overview.getParsedDocumentCount());
        assertEquals(1, overview.getIndexedDocumentCount());
        assertEquals(new Date(1747820096000L), overview.getLastUploadTime());
    }

    /** The document query issued for the overview must filter by store id and nothing else. */
    @Test
    void getDocumentOverviewShouldQueryDocumentsByStoreIdOnly() {
        stubProductPolicyStore();
        when(ragDocumentService.query(org.mockito.ArgumentMatchers.any())).thenReturn(List.of());

        ragStoreService.getDocumentOverview(1L);

        var queryCaptor =
                org.mockito.ArgumentCaptor.forClass(com.bruce.rag.dto.request.RagDocumentQueryRequest.class);
        org.mockito.Mockito.verify(ragDocumentService).query(queryCaptor.capture());
        var issuedQuery = queryCaptor.getValue();
        assertEquals(1L, issuedQuery.getStoreId());
        assertNull(issuedQuery.getParseStatus());
        assertNull(issuedQuery.getIndexStatus());
    }

    /** Asking for the overview of a store that cannot be found must fail with {@link IllegalArgumentException}. */
    @Test
    void getDocumentOverviewShouldRejectUnknownStore() {
        doReturn(null).when(ragStoreService).getById(999L);

        assertThrows(IllegalArgumentException.class, () -> ragStoreService.getDocumentOverview(999L));
    }

    /** Stubs the spy so that store id 1 resolves to a store named "产品制度库". */
    private void stubProductPolicyStore() {
        RagStore policyStore = new RagStore();
        policyStore.setId(1L);
        policyStore.setStoreName("产品制度库");
        doReturn(policyStore).when(ragStoreService).getById(1L);
    }

    /**
     * Builds a document response fixture.
     *
     * @param id          document id as a decimal string, converted with {@link Long#valueOf(String)}
     * @param storeId     owning store id as a decimal string, converted the same way
     * @param enabled     enabled flag
     * @param parseStatus parse status value
     * @param indexStatus index status value
     * @param createTime  creation timestamp
     * @return the populated response
     */
    private RagDocumentResponse createDocumentResponse(
            String id,
            String storeId,
            boolean enabled,
            String parseStatus,
            String indexStatus,
            Date createTime
    ) {
        RagDocumentResponse fixture = new RagDocumentResponse();
        fixture.setId(Long.valueOf(id));
        fixture.setStoreId(Long.valueOf(storeId));
        fixture.setEnabled(enabled);
        fixture.setParseStatus(parseStatus);
        fixture.setIndexStatus(indexStatus);
        fixture.setCreateTime(createTime);
        return fixture;
    }
}

View File

@@ -0,0 +1,45 @@
package com.bruce.rag;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.bruce.rag.dto.response.RagDocumentResponse;
import com.bruce.rag.dto.response.RagStoreResponse;
import org.junit.jupiter.api.Test;
import java.util.Date;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Checks the JSON shape of {@link RagStoreResponse} and {@link RagDocumentResponse} when
 * written by a default-configured Jackson {@link ObjectMapper}.
 */
class RagStoreResponseSerializationTests {

    /** Sample instant used for all timestamp fields; equals 2025-05-21T08:34:56Z. */
    private static final long SAMPLE_EPOCH_MILLIS = 1747816496000L;

    /** Plain mapper with no extra modules or configuration. */
    private final ObjectMapper objectMapper = new ObjectMapper();

    /** A 19-digit id must be emitted as a JSON string rather than a JSON number. */
    @Test
    void idShouldSerializeAsStringForFrontendPrecisionSafety() throws Exception {
        RagStoreResponse store = new RagStoreResponse();
        store.setId(2057302206052372481L);
        store.setStoreCode("TEXT-1");
        store.setStoreName("测试库1");

        String json = objectMapper.writeValueAsString(store);

        assertTrue(json.contains("\"id\":\"2057302206052372481\""));
    }

    /**
     * Create/update times of both response types must be rendered as {@code yyyy-MM-dd HH:mm:ss}.
     *
     * <p>NOTE(review): the expected wall-clock value is eight hours ahead of UTC for the sample
     * instant, so the response DTOs presumably pin a +08:00 time zone in their JSON format
     * configuration — confirm against the DTO classes.
     */
    @Test
    void responseTimeShouldSerializeWithUnifiedFormat() throws Exception {
        RagStoreResponse store = new RagStoreResponse();
        store.setCreateTime(new Date(SAMPLE_EPOCH_MILLIS));
        store.setUpdateTime(new Date(SAMPLE_EPOCH_MILLIS));
        RagDocumentResponse document = new RagDocumentResponse();
        document.setCreateTime(new Date(SAMPLE_EPOCH_MILLIS));
        document.setUpdateTime(new Date(SAMPLE_EPOCH_MILLIS));

        String storeJson = objectMapper.writeValueAsString(store);
        String documentJson = objectMapper.writeValueAsString(document);

        String expectedCreateTime = "\"createTime\":\"2025-05-21 16:34:56\"";
        String expectedUpdateTime = "\"updateTime\":\"2025-05-21 16:34:56\"";
        assertTrue(storeJson.contains(expectedCreateTime));
        assertTrue(storeJson.contains(expectedUpdateTime));
        assertTrue(documentJson.contains(expectedCreateTime));
        assertTrue(documentJson.contains(expectedUpdateTime));
    }
}

View File

@@ -0,0 +1,54 @@
package com.bruce.rag;
import com.bruce.common.enums.EnableStatusEnum;
import com.bruce.rag.dto.request.RagStoreSaveRequest;
import com.bruce.rag.entity.RagStore;
import com.bruce.rag.service.impl.RagStoreServiceImpl;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
 * Tests for the save-request validation and entity-building helpers of
 * {@link RagStoreServiceImpl}. The service is constructed directly, without any container.
 */
class RagStoreSaveValidationTests {

    /** Service under test, created with its no-argument constructor. */
    private final RagStoreServiceImpl service = new RagStoreServiceImpl();

    /** A request with a name but no store code must be rejected. */
    @Test
    void saveShouldRejectBlankStoreCode() {
        RagStoreSaveRequest withoutCode = new RagStoreSaveRequest();
        withoutCode.setStoreName("产品制度库");

        assertThrows(IllegalArgumentException.class, () -> service.validateSaveRequest(withoutCode));
    }

    /** A request with a store code but no name must be rejected. */
    @Test
    void saveShouldRejectBlankStoreName() {
        RagStoreSaveRequest withoutName = new RagStoreSaveRequest();
        withoutName.setStoreCode("PROD_DOC");

        assertThrows(IllegalArgumentException.class, () -> service.validateSaveRequest(withoutName));
    }

    /** Store code plus store name is enough for validation to pass. */
    @Test
    void saveShouldAcceptMinimalValidRequest() {
        RagStoreSaveRequest minimal = new RagStoreSaveRequest();
        minimal.setStoreCode("PROD_DOC");
        minimal.setStoreName("产品制度库");

        assertDoesNotThrow(() -> service.validateSaveRequest(minimal));
    }

    /** When the request carries no status, the built entity must use the ENABLED label. */
    @Test
    void saveShouldDefaultStatusToEnabledEnumLabel() {
        RagStoreSaveRequest withoutStatus = new RagStoreSaveRequest();
        withoutStatus.setStoreCode("PROD_DOC");
        withoutStatus.setStoreName("产品制度库");

        RagStore built = service.buildEntity(withoutStatus);

        assertEquals(EnableStatusEnum.ENABLED.getLabel(), built.getStatus());
    }
}

View File

@@ -0,0 +1,50 @@
package com.bruce.rag.parse;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
 * Tests strategy resolution in {@link ChunkerFactory} using stub chunkers that each support
 * exactly one strategy.
 */
class ChunkerFactoryTests {

    /** The factory must hand back the very chunker instance that supports the requested strategy. */
    @Test
    void resolveShouldReturnMatchingChunker() {
        Chunker fixedLengthStub = new StubChunker(RagChunkStrategyEnum.FIXED_LENGTH);
        Chunker delimiterStub = new StubChunker(RagChunkStrategyEnum.DELIMITER);
        ChunkerFactory factory = new ChunkerFactory(List.of(fixedLengthStub, delimiterStub));

        assertSame(fixedLengthStub, factory.resolve(RagChunkStrategyEnum.FIXED_LENGTH));
    }

    /** Requesting a strategy no registered chunker supports must raise {@link IllegalArgumentException}. */
    @Test
    void resolveShouldRejectUnsupportedStrategy() {
        Chunker onlyChunker = new StubChunker(RagChunkStrategyEnum.FIXED_LENGTH);
        ChunkerFactory factory = new ChunkerFactory(List.of(onlyChunker));

        assertThrows(IllegalArgumentException.class, () -> factory.resolve(RagChunkStrategyEnum.SEMANTIC));
    }

    /** Chunker stub that supports a single strategy and never produces chunks. */
    private static class StubChunker implements Chunker {

        /** The one strategy this stub reports as supported. */
        private final RagChunkStrategyEnum supportedStrategy;

        private StubChunker(RagChunkStrategyEnum supportedStrategy) {
            this.supportedStrategy = supportedStrategy;
        }

        /** True only when the argument is the identical enum constant given at construction. */
        @Override
        public boolean supports(RagChunkStrategyEnum strategy) {
            return strategy == supportedStrategy;
        }

        /** Always returns an empty list; the command is ignored. */
        @Override
        public List<RagChunk> chunk(RagChunkCommand command) {
            return List.of();
        }
    }
}

View File

@@ -0,0 +1,64 @@
package com.bruce.rag.parse;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import com.bruce.rag.parse.impl.DelimiterChunker;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Tests for {@link DelimiterChunker}.
 *
 * <p>NOTE(review): every command here is built with an empty-string delimiter, yet the first
 * test expects the text to be split at each "。". Presumably the chunker falls back to a
 * default delimiter when none is given — confirm against {@code DelimiterChunker}.
 */
class DelimiterChunkerTests {

    /** Text with a doubled delimiter must yield three non-blank chunks indexed 0, 1 and 2. */
    @Test
    void chunkShouldSplitByDelimiterAndIgnoreBlankSegments() {
        RagChunkCommand command = newDelimiterCommand("第一段。第二段。。第三段");

        List<RagChunk> chunks = new DelimiterChunker().chunk(command);

        List<String> expectedContents = List.of("第一段", "第二段", "第三段");
        assertEquals(expectedContents.size(), chunks.size());
        for (int position = 0; position < expectedContents.size(); position++) {
            assertEquals(expectedContents.get(position), chunks.get(position).getChunkContent());
            assertEquals(position, chunks.get(position).getChunkIndex());
        }
    }

    /** Whitespace-only text must produce no chunks at all. */
    @Test
    void chunkShouldReturnEmptyListForBlankText() {
        RagChunkCommand command = newDelimiterCommand(" ");

        assertTrue(new DelimiterChunker().chunk(command).isEmpty());
    }

    /** Builds a DELIMITER-strategy command for the fixture document with the given parsed text. */
    private static RagChunkCommand newDelimiterCommand(String text) {
        RagChunkCommand command = new RagChunkCommand();
        command.setDocument(buildDocument());
        command.setParseResult(buildParseResult(text));
        command.setChunkStrategy(RagChunkStrategyEnum.DELIMITER.getValue());
        command.setDelimiter("");
        return command;
    }

    /** Fixture document with id 66 in store 55. */
    private static RagDocument buildDocument() {
        RagDocument fixture = new RagDocument();
        fixture.setId(66L);
        fixture.setStoreId(55L);
        return fixture;
    }

    /** Wraps the text and its character length in a parse result. */
    private static DocumentParseResult buildParseResult(String text) {
        DocumentParseResult parsed = new DocumentParseResult();
        parsed.setText(text);
        parsed.setTextLength(text.length());
        return parsed;
    }
}

View File

@@ -0,0 +1,69 @@
package com.bruce.rag.parse;
import com.bruce.common.document.parse.DocumentParseResult;
import com.bruce.rag.entity.RagChunk;
import com.bruce.rag.entity.RagDocument;
import com.bruce.rag.enums.RagChunkStrategyEnum;
import com.bruce.rag.parse.impl.FixedLengthChunker;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Tests for {@link FixedLengthChunker} with a chunk size of 4 and an overlap of 1.
 */
class FixedLengthChunkerTests {

    /**
     * Ten characters must yield three chunks, each starting three characters after the previous
     * one, and the first chunk must carry the document id, store id and an enabled flag.
     */
    @Test
    void chunkShouldSplitTextByChunkSizeAndOverlap() {
        RagChunkCommand command = newFixedLengthCommand("abcdefghij");

        List<RagChunk> chunks = new FixedLengthChunker().chunk(command);

        List<String> expectedContents = List.of("abcd", "defg", "ghij");
        assertEquals(expectedContents.size(), chunks.size());
        for (int position = 0; position < expectedContents.size(); position++) {
            assertEquals(expectedContents.get(position), chunks.get(position).getChunkContent());
            assertEquals(position, chunks.get(position).getChunkIndex());
        }
        RagChunk firstChunk = chunks.get(0);
        assertEquals(99L, firstChunk.getDocumentId());
        assertEquals(88L, firstChunk.getStoreId());
        assertTrue(Boolean.TRUE.equals(firstChunk.getEnabled()));
    }

    /** Whitespace-only text must produce no chunks at all. */
    @Test
    void chunkShouldReturnEmptyListForBlankText() {
        RagChunkCommand command = newFixedLengthCommand(" ");

        assertTrue(new FixedLengthChunker().chunk(command).isEmpty());
    }

    /** Builds a FIXED_LENGTH command (size 4, overlap 1) for the fixture document and given text. */
    private static RagChunkCommand newFixedLengthCommand(String text) {
        RagChunkCommand command = new RagChunkCommand();
        command.setDocument(buildDocument());
        command.setParseResult(buildParseResult(text));
        command.setChunkStrategy(RagChunkStrategyEnum.FIXED_LENGTH.getValue());
        command.setChunkSize(4);
        command.setChunkOverlap(1);
        return command;
    }

    /** Fixture document with id 99 in store 88. */
    private static RagDocument buildDocument() {
        RagDocument fixture = new RagDocument();
        fixture.setId(99L);
        fixture.setStoreId(88L);
        return fixture;
    }

    /** Wraps the text and its character length in a parse result. */
    private static DocumentParseResult buildParseResult(String text) {
        DocumentParseResult parsed = new DocumentParseResult();
        parsed.setText(text);
        parsed.setTextLength(text.length());
        return parsed;
    }
}