feat(rag): add document parsing structures
This commit is contained in:
17
src/main/java/com/bruce/common/config/MybatisPlusConfig.java
Normal file
17
src/main/java/com/bruce/common/config/MybatisPlusConfig.java
Normal file
@@ -0,0 +1,17 @@
|
||||
package com.bruce.common.config;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class MybatisPlusConfig {
|
||||
|
||||
@Bean
|
||||
public MybatisPlusInterceptor mybatisPlusInterceptor() {
|
||||
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
|
||||
interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor());
|
||||
return interceptor;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
@Data
|
||||
public class DocumentParseContext {
|
||||
|
||||
private Long documentId;
|
||||
|
||||
private Long attachmentId;
|
||||
|
||||
private String originalName;
|
||||
|
||||
private String suffix;
|
||||
|
||||
private String contentType;
|
||||
|
||||
private Path filePath;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
public class DocumentParseException extends RuntimeException {
|
||||
|
||||
public DocumentParseException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public DocumentParseException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class DocumentParseResult {
|
||||
|
||||
private String text;
|
||||
|
||||
private Integer textLength;
|
||||
|
||||
private Integer pageCount;
|
||||
|
||||
private Integer sheetCount;
|
||||
|
||||
private Map<String, Object> metadata = new LinkedHashMap<>();
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
public interface DocumentParser {
|
||||
|
||||
boolean supports(DocumentParseContext context);
|
||||
|
||||
DocumentParseResult parse(DocumentParseContext context);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
@Component
|
||||
public class DocumentParserFactory {
|
||||
|
||||
private final List<DocumentParser> parsers;
|
||||
|
||||
public DocumentParserFactory(List<DocumentParser> parsers) {
|
||||
this.parsers = parsers;
|
||||
}
|
||||
|
||||
public DocumentParser resolve(DocumentParseContext context) {
|
||||
return parsers.stream()
|
||||
.filter(parser -> parser.supports(context))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new DocumentParseException("不支持的文档类型: " + resolveType(context)));
|
||||
}
|
||||
|
||||
private String resolveType(DocumentParseContext context) {
|
||||
if (context == null) {
|
||||
return "unknown";
|
||||
}
|
||||
if (StringUtils.hasText(context.getSuffix())) {
|
||||
return context.getSuffix().trim().toLowerCase(Locale.ROOT);
|
||||
}
|
||||
if (StringUtils.hasText(context.getContentType())) {
|
||||
return context.getContentType().trim();
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package com.bruce.common.document.parse.impl;
|
||||
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParseException;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import org.apache.tika.Tika;
|
||||
import org.apache.tika.exception.TikaException;
|
||||
import org.apache.tika.metadata.Metadata;
|
||||
import org.apache.tika.metadata.TikaCoreProperties;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.util.Locale;
|
||||
import java.util.Set;
|
||||
|
||||
abstract class AbstractTikaDocumentParser {
|
||||
|
||||
private static final int MAX_TEXT_LENGTH = -1;
|
||||
|
||||
private final Tika tika = new Tika();
|
||||
|
||||
boolean supportsSuffix(DocumentParseContext context, Set<String> suffixes) {
|
||||
return context != null
|
||||
&& StringUtils.hasText(context.getSuffix())
|
||||
&& suffixes.contains(context.getSuffix().trim().toLowerCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
boolean supportsContentType(DocumentParseContext context, String prefix) {
|
||||
return context != null
|
||||
&& StringUtils.hasText(context.getContentType())
|
||||
&& context.getContentType().trim().toLowerCase(Locale.ROOT).startsWith(prefix);
|
||||
}
|
||||
|
||||
DocumentParseResult parseWithTika(DocumentParseContext context) {
|
||||
if (context == null || context.getFilePath() == null) {
|
||||
throw new DocumentParseException("解析文件不能为空");
|
||||
}
|
||||
try {
|
||||
Metadata metadata = new Metadata();
|
||||
metadata.set(TikaCoreProperties.RESOURCE_NAME_KEY, context.getOriginalName());
|
||||
if (StringUtils.hasText(context.getContentType())) {
|
||||
metadata.set(Metadata.CONTENT_TYPE, context.getContentType());
|
||||
}
|
||||
String text;
|
||||
try (InputStream inputStream = Files.newInputStream(context.getFilePath())) {
|
||||
text = tika.parseToString(inputStream, metadata, MAX_TEXT_LENGTH);
|
||||
}
|
||||
DocumentParseResult result = new DocumentParseResult();
|
||||
result.setText(text == null ? "" : text.trim());
|
||||
result.setTextLength(result.getText().length());
|
||||
result.getMetadata().put("contentType", firstNonBlank(metadata.get(Metadata.CONTENT_TYPE), context.getContentType()));
|
||||
result.getMetadata().put("resourceName", firstNonBlank(metadata.get(TikaCoreProperties.RESOURCE_NAME_KEY), context.getOriginalName()));
|
||||
return result;
|
||||
} catch (IOException | TikaException e) {
|
||||
throw new DocumentParseException("文档解析失败: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private String firstNonBlank(String first, String fallback) {
|
||||
return StringUtils.hasText(first) ? first : fallback;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.bruce.common.document.parse.impl;
|
||||
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Component
|
||||
public class ExcelDocumentParser extends AbstractTikaDocumentParser implements DocumentParser {
|
||||
|
||||
private static final Set<String> SUFFIXES = Set.of("xls", "xlsx");
|
||||
|
||||
@Override
|
||||
public boolean supports(DocumentParseContext context) {
|
||||
return supportsSuffix(context, SUFFIXES)
|
||||
|| supportsContentType(context, "application/vnd.ms-excel")
|
||||
|| supportsContentType(context, "application/vnd.openxmlformats-officedocument.spreadsheetml");
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocumentParseResult parse(DocumentParseContext context) {
|
||||
return parseWithTika(context);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.bruce.common.document.parse.impl;
|
||||
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Component
|
||||
public class PdfDocumentParser extends AbstractTikaDocumentParser implements DocumentParser {
|
||||
|
||||
private static final Set<String> SUFFIXES = Set.of("pdf");
|
||||
|
||||
@Override
|
||||
public boolean supports(DocumentParseContext context) {
|
||||
return supportsSuffix(context, SUFFIXES) || supportsContentType(context, "application/pdf");
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocumentParseResult parse(DocumentParseContext context) {
|
||||
return parseWithTika(context);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.bruce.common.document.parse.impl;
|
||||
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Component
|
||||
public class TxtDocumentParser extends AbstractTikaDocumentParser implements DocumentParser {
|
||||
|
||||
private static final Set<String> SUFFIXES = Set.of("txt", "md", "log");
|
||||
|
||||
@Override
|
||||
public boolean supports(DocumentParseContext context) {
|
||||
return supportsSuffix(context, SUFFIXES) || supportsContentType(context, "text/");
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocumentParseResult parse(DocumentParseContext context) {
|
||||
return parseWithTika(context);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.bruce.common.document.parse.impl;
|
||||
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Component
|
||||
public class WordDocumentParser extends AbstractTikaDocumentParser implements DocumentParser {
|
||||
|
||||
private static final Set<String> SUFFIXES = Set.of("doc", "docx");
|
||||
|
||||
@Override
|
||||
public boolean supports(DocumentParseContext context) {
|
||||
return supportsSuffix(context, SUFFIXES)
|
||||
|| supportsContentType(context, "application/msword")
|
||||
|| supportsContentType(context, "application/vnd.openxmlformats-officedocument.wordprocessingml");
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocumentParseResult parse(DocumentParseContext context) {
|
||||
return parseWithTika(context);
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,12 @@ package com.bruce.rag.controller;
|
||||
|
||||
import com.bruce.common.domain.model.RequestResult;
|
||||
import com.bruce.rag.dto.request.RagDocumentBatchUploadRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentParseRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentSaveRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentParseResponse;
|
||||
import com.bruce.rag.dto.response.RagDocumentResponse;
|
||||
import com.bruce.rag.service.IRagDocumentParseService;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
@@ -29,6 +32,9 @@ public class RagDocumentController {
|
||||
@Autowired
|
||||
private IRagDocumentService ragDocumentService;
|
||||
|
||||
@Autowired
|
||||
private IRagDocumentParseService ragDocumentParseService;
|
||||
|
||||
@Operation(summary = "查询全部知识库文档")
|
||||
@PostMapping("/list")
|
||||
public RequestResult<List<RagDocumentResponse>> list() {
|
||||
@@ -85,4 +91,13 @@ public class RagDocumentController {
|
||||
request.getStoreId(), responses.size());
|
||||
return RequestResult.success(responses);
|
||||
}
|
||||
|
||||
@Operation(summary = "解析知识库文档")
|
||||
@PostMapping("/parse")
|
||||
public RequestResult<List<RagDocumentParseResponse>> parse(@RequestBody RagDocumentParseRequest request) {
|
||||
log.info("RagDocumentController.parse start, request={}", request);
|
||||
List<RagDocumentParseResponse> responses = ragDocumentParseService.parse(request);
|
||||
log.info("RagDocumentController.parse success, count={}", responses.size());
|
||||
return RequestResult.success(responses);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.bruce.rag.dto.request;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Schema(description = "RAG知识库文档解析请求")
|
||||
public class RagDocumentParseRequest {
|
||||
|
||||
@Schema(description = "文档ID列表")
|
||||
private List<Long> documentIds;
|
||||
|
||||
@Schema(description = "切片方式")
|
||||
private String chunkStrategy;
|
||||
|
||||
@Schema(description = "切片长度")
|
||||
private Integer chunkSize;
|
||||
|
||||
@Schema(description = "重叠长度")
|
||||
private Integer chunkOverlap;
|
||||
|
||||
@Schema(description = "分隔符")
|
||||
private String delimiter;
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.bruce.rag.dto.response;
|
||||
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@Schema(description = "RAG知识库文档解析响应")
|
||||
public class RagDocumentParseResponse {
|
||||
|
||||
@Schema(description = "文档ID")
|
||||
@JsonSerialize(using = ToStringSerializer.class)
|
||||
private Long documentId;
|
||||
|
||||
@Schema(description = "解析状态")
|
||||
private String parseStatus;
|
||||
|
||||
@Schema(description = "文本长度")
|
||||
private Integer textLength;
|
||||
|
||||
@Schema(description = "页数")
|
||||
private Integer pageCount;
|
||||
|
||||
@Schema(description = "工作表数量")
|
||||
private Integer sheetCount;
|
||||
|
||||
@Schema(description = "解析元数据")
|
||||
private Map<String, Object> metadata = new LinkedHashMap<>();
|
||||
}
|
||||
67
src/main/java/com/bruce/rag/entity/RagChunk.java
Normal file
67
src/main/java/com/bruce/rag/entity/RagChunk.java
Normal file
@@ -0,0 +1,67 @@
|
||||
package com.bruce.rag.entity;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.bruce.common.domain.model.BaseEntity;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("rag_chunk")
|
||||
@Schema(description = "RAG知识切片")
|
||||
public class RagChunk extends BaseEntity {
|
||||
|
||||
@Schema(description = "知识库ID")
|
||||
@TableField("store_id")
|
||||
private Long storeId;
|
||||
|
||||
@Schema(description = "文档ID")
|
||||
@TableField("document_id")
|
||||
private Long documentId;
|
||||
|
||||
@Schema(description = "文档内切片序号")
|
||||
@TableField("chunk_index")
|
||||
private Integer chunkIndex;
|
||||
|
||||
@Schema(description = "切片内容")
|
||||
@TableField("chunk_content")
|
||||
private String chunkContent;
|
||||
|
||||
@Schema(description = "切片摘要")
|
||||
@TableField("chunk_summary")
|
||||
private String chunkSummary;
|
||||
|
||||
@Schema(description = "Token数量")
|
||||
@TableField("token_count")
|
||||
private Integer tokenCount;
|
||||
|
||||
@Schema(description = "页码")
|
||||
@TableField("page_number")
|
||||
private Integer pageNumber;
|
||||
|
||||
@Schema(description = "章节标题")
|
||||
@TableField("section_title")
|
||||
private String sectionTitle;
|
||||
|
||||
@Schema(description = "标题路径")
|
||||
@TableField("heading_path")
|
||||
private String headingPath;
|
||||
|
||||
@Schema(description = "向量ID")
|
||||
@TableField("vector_id")
|
||||
private String vectorId;
|
||||
|
||||
@Schema(description = "切片级扩展元数据JSON")
|
||||
@TableField("metadata_json")
|
||||
private String metadataJson;
|
||||
|
||||
@Schema(description = "是否启用")
|
||||
private Boolean enabled;
|
||||
|
||||
@Schema(description = "备注")
|
||||
private String remark;
|
||||
}
|
||||
50
src/main/java/com/bruce/rag/entity/RagChunkEmbedding.java
Normal file
50
src/main/java/com/bruce/rag/entity/RagChunkEmbedding.java
Normal file
@@ -0,0 +1,50 @@
|
||||
package com.bruce.rag.entity;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.bruce.common.domain.model.BaseEntity;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("rag_chunk_embedding")
|
||||
@Schema(description = "RAG切片向量")
|
||||
public class RagChunkEmbedding extends BaseEntity {
|
||||
|
||||
@Schema(description = "知识库ID")
|
||||
@TableField("store_id")
|
||||
private Long storeId;
|
||||
|
||||
@Schema(description = "文档ID")
|
||||
@TableField("document_id")
|
||||
private Long documentId;
|
||||
|
||||
@Schema(description = "切片ID")
|
||||
@TableField("chunk_id")
|
||||
private Long chunkId;
|
||||
|
||||
@Schema(description = "向量模型")
|
||||
@TableField("embedding_model")
|
||||
private String embeddingModel;
|
||||
|
||||
@Schema(description = "向量维度")
|
||||
@TableField("embedding_dimension")
|
||||
private Integer embeddingDimension;
|
||||
|
||||
@Schema(description = "向量内容")
|
||||
private String embedding;
|
||||
|
||||
@Schema(description = "向量生成内容哈希")
|
||||
@TableField("content_hash")
|
||||
private String contentHash;
|
||||
|
||||
@Schema(description = "是否启用")
|
||||
private Boolean enabled;
|
||||
|
||||
@Schema(description = "备注")
|
||||
private String remark;
|
||||
}
|
||||
20
src/main/java/com/bruce/rag/enums/RagChunkStrategyEnum.java
Normal file
20
src/main/java/com/bruce/rag/enums/RagChunkStrategyEnum.java
Normal file
@@ -0,0 +1,20 @@
|
||||
package com.bruce.rag.enums;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum RagChunkStrategyEnum {
|
||||
|
||||
FIXED_LENGTH(1, "固定长度切片"),
|
||||
PARAGRAPH(2, "按段落切片"),
|
||||
HEADING(3, "按标题层级切片"),
|
||||
TABLE_ROW(4, "按表格行切片"),
|
||||
DELIMITER(5, "按分隔符切片"),
|
||||
SEMANTIC(6, "语义切片");
|
||||
|
||||
private final Integer value;
|
||||
|
||||
private final String label;
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.bruce.rag.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.bruce.rag.entity.RagChunkEmbedding;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface RagChunkEmbeddingMapper extends BaseMapper<RagChunkEmbedding> {
|
||||
}
|
||||
9
src/main/java/com/bruce/rag/mapper/RagChunkMapper.java
Normal file
9
src/main/java/com/bruce/rag/mapper/RagChunkMapper.java
Normal file
@@ -0,0 +1,9 @@
|
||||
package com.bruce.rag.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.bruce.rag.entity.RagChunk;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface RagChunkMapper extends BaseMapper<RagChunk> {
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.bruce.rag.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.bruce.rag.entity.RagChunkEmbedding;
|
||||
|
||||
public interface IRagChunkEmbeddingService extends IService<RagChunkEmbedding> {
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.bruce.rag.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.bruce.rag.entity.RagChunk;
|
||||
|
||||
public interface IRagChunkService extends IService<RagChunk> {
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.bruce.rag.service;
|
||||
|
||||
import com.bruce.rag.dto.response.RagDocumentParseResponse;
|
||||
import com.bruce.rag.dto.request.RagDocumentParseRequest;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface IRagDocumentParseService {
|
||||
|
||||
RagDocumentParseResponse parse(Long documentId);
|
||||
|
||||
List<RagDocumentParseResponse> parse(RagDocumentParseRequest request);
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.bruce.rag.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.bruce.rag.entity.RagChunkEmbedding;
|
||||
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
|
||||
import com.bruce.rag.service.IRagChunkEmbeddingService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class RagChunkEmbeddingServiceImpl extends ServiceImpl<RagChunkEmbeddingMapper, RagChunkEmbedding>
|
||||
implements IRagChunkEmbeddingService {
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.bruce.rag.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.bruce.rag.entity.RagChunk;
|
||||
import com.bruce.rag.mapper.RagChunkMapper;
|
||||
import com.bruce.rag.service.IRagChunkService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class RagChunkServiceImpl extends ServiceImpl<RagChunkMapper, RagChunk> implements IRagChunkService {
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package com.bruce.rag.service.impl;
|
||||
|
||||
import com.bruce.common.config.AttachmentProperties;
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParseException;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParserFactory;
|
||||
import com.bruce.common.domain.entity.SysAttachment;
|
||||
import com.bruce.common.service.ISysAttachmentService;
|
||||
import com.bruce.rag.dto.request.RagDocumentParseRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentParseResponse;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.bruce.rag.enums.RagChunkStrategyEnum;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import com.bruce.rag.service.IRagDocumentParseService;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class RagDocumentParseServiceImpl implements IRagDocumentParseService {
|
||||
|
||||
private final IRagDocumentService ragDocumentService;
|
||||
|
||||
private final ISysAttachmentService sysAttachmentService;
|
||||
|
||||
private final AttachmentProperties attachmentProperties;
|
||||
|
||||
private final DocumentParserFactory documentParserFactory;
|
||||
|
||||
@Override
|
||||
public List<RagDocumentParseResponse> parse(RagDocumentParseRequest request) {
|
||||
log.info("RagDocumentParseServiceImpl.parse batch start, request={}", request);
|
||||
validateParseRequest(request);
|
||||
List<RagDocumentParseResponse> responses = request.getDocumentIds().stream()
|
||||
.map(this::parse)
|
||||
.toList();
|
||||
log.info("RagDocumentParseServiceImpl.parse batch success, count={}", responses.size());
|
||||
return responses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RagDocumentParseResponse parse(Long documentId) {
|
||||
log.info("RagDocumentParseServiceImpl.parse start, documentId={}", documentId);
|
||||
if (documentId == null) {
|
||||
throw new IllegalArgumentException("文档ID不能为空");
|
||||
}
|
||||
|
||||
RagDocument document = ragDocumentService.getById(documentId);
|
||||
if (document == null) {
|
||||
throw new IllegalArgumentException("文档不存在,ID: " + documentId);
|
||||
}
|
||||
if (document.getAttachmentId() == null) {
|
||||
throw new IllegalArgumentException("文档附件ID不能为空");
|
||||
}
|
||||
|
||||
SysAttachment attachment = sysAttachmentService.getById(document.getAttachmentId());
|
||||
if (attachment == null) {
|
||||
throw new IllegalArgumentException("附件不存在,ID: " + document.getAttachmentId());
|
||||
}
|
||||
|
||||
updateParseStatus(documentId, RagParseStatusEnum.PARSING, null);
|
||||
try {
|
||||
DocumentParseContext context = buildParseContext(document, attachment);
|
||||
DocumentParser parser = documentParserFactory.resolve(context);
|
||||
DocumentParseResult result = parser.parse(context);
|
||||
updateParseStatus(documentId, RagParseStatusEnum.PARSED, null);
|
||||
RagDocumentParseResponse response = toResponse(documentId, result);
|
||||
log.info("RagDocumentParseServiceImpl.parse success, documentId={}, textLength={}",
|
||||
documentId, response.getTextLength());
|
||||
return response;
|
||||
} catch (RuntimeException e) {
|
||||
updateParseStatus(documentId, RagParseStatusEnum.FAILED, e.getMessage());
|
||||
log.warn("RagDocumentParseServiceImpl.parse failed, documentId={}, message={}", documentId, e.getMessage());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private void validateParseRequest(RagDocumentParseRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("解析请求不能为空");
|
||||
}
|
||||
if (request.getDocumentIds() == null || request.getDocumentIds().isEmpty()) {
|
||||
throw new IllegalArgumentException("文档ID列表不能为空");
|
||||
}
|
||||
Set<String> strategies = Arrays.stream(RagChunkStrategyEnum.values())
|
||||
.map(Enum::name)
|
||||
.collect(Collectors.toSet());
|
||||
if (request.getChunkStrategy() == null || !strategies.contains(request.getChunkStrategy())) {
|
||||
throw new IllegalArgumentException("不支持的切片方式: " + request.getChunkStrategy());
|
||||
}
|
||||
}
|
||||
|
||||
private DocumentParseContext buildParseContext(RagDocument document, SysAttachment attachment) {
|
||||
Path filePath = resolveFilePath(attachment);
|
||||
if (!Files.isRegularFile(filePath)) {
|
||||
throw new DocumentParseException("解析文件不存在: " + filePath);
|
||||
}
|
||||
|
||||
DocumentParseContext context = new DocumentParseContext();
|
||||
context.setDocumentId(document.getId());
|
||||
context.setAttachmentId(attachment.getId());
|
||||
context.setOriginalName(attachment.getOriginalName());
|
||||
context.setSuffix(attachment.getFileSuffix());
|
||||
context.setContentType(attachment.getContentType());
|
||||
context.setFilePath(filePath);
|
||||
return context;
|
||||
}
|
||||
|
||||
private Path resolveFilePath(SysAttachment attachment) {
|
||||
if (!StringUtils.hasText(attachment.getFilePath())) {
|
||||
throw new DocumentParseException("附件文件路径不能为空");
|
||||
}
|
||||
Path filePath = Path.of(attachment.getFilePath());
|
||||
if (filePath.isAbsolute()) {
|
||||
return filePath.normalize();
|
||||
}
|
||||
return Path.of(attachmentProperties.getBasePath()).resolve(filePath).normalize();
|
||||
}
|
||||
|
||||
private void updateParseStatus(Long documentId, RagParseStatusEnum status, String errorMessage) {
|
||||
RagDocument update = new RagDocument();
|
||||
update.setId(documentId);
|
||||
update.setParseStatus(status.name());
|
||||
update.setErrorMessage(StringUtils.hasText(errorMessage) ? errorMessage : null);
|
||||
ragDocumentService.updateById(update);
|
||||
}
|
||||
|
||||
private RagDocumentParseResponse toResponse(Long documentId, DocumentParseResult result) {
|
||||
RagDocumentParseResponse response = new RagDocumentParseResponse();
|
||||
response.setDocumentId(documentId);
|
||||
response.setParseStatus(RagParseStatusEnum.PARSED.name());
|
||||
response.setTextLength(result.getTextLength());
|
||||
response.setPageCount(result.getPageCount());
|
||||
response.setSheetCount(result.getSheetCount());
|
||||
response.setMetadata(result.getMetadata());
|
||||
return response;
|
||||
}
|
||||
}
|
||||
@@ -82,10 +82,18 @@ public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDo
|
||||
document.setIndexStatus(RagIndexStatusEnum.PENDING.name());
|
||||
}
|
||||
|
||||
document.setStoreId(request.getStoreId());
|
||||
document.setAttachmentId(request.getAttachmentId());
|
||||
document.setDocumentTitle(request.getDocumentTitle().trim());
|
||||
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
|
||||
if (request.getStoreId() != null) {
|
||||
document.setStoreId(request.getStoreId());
|
||||
}
|
||||
if (request.getAttachmentId() != null) {
|
||||
document.setAttachmentId(request.getAttachmentId());
|
||||
}
|
||||
if (StringUtils.hasText(request.getDocumentTitle())) {
|
||||
document.setDocumentTitle(request.getDocumentTitle().trim());
|
||||
}
|
||||
if (request.getDocumentSummary() != null) {
|
||||
document.setDocumentSummary(trimToNull(request.getDocumentSummary()));
|
||||
}
|
||||
if (StringUtils.hasText(request.getParseStatus())) {
|
||||
document.setParseStatus(request.getParseStatus().trim());
|
||||
}
|
||||
@@ -95,10 +103,14 @@ public class RagDocumentServiceImpl extends ServiceImpl<RagDocumentMapper, RagDo
|
||||
if (request.getEnabled() != null) {
|
||||
document.setEnabled(request.getEnabled());
|
||||
}
|
||||
document.setErrorMessage(trimToNull(request.getErrorMessage()));
|
||||
document.setRemark(trimToNull(request.getRemark()));
|
||||
if (request.getErrorMessage() != null) {
|
||||
document.setErrorMessage(trimToNull(request.getErrorMessage()));
|
||||
}
|
||||
if (request.getRemark() != null) {
|
||||
document.setRemark(trimToNull(request.getRemark()));
|
||||
}
|
||||
|
||||
boolean result = saveOrUpdate(document);
|
||||
boolean result = request.getId() == null ? save(document) : updateById(document);
|
||||
log.info("RagDocumentServiceImpl.saveOrUpdate success, requestId={}, savedId={}, result={}",
|
||||
request.getId(), document.getId(), result);
|
||||
return result;
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.bruce.common.config;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
class MybatisPlusConfigTests {
|
||||
|
||||
@Test
|
||||
void mybatisPlusInterceptorShouldRegisterOptimisticLocker() {
|
||||
MybatisPlusConfig config = new MybatisPlusConfig();
|
||||
|
||||
var interceptor = config.mybatisPlusInterceptor();
|
||||
|
||||
assertTrue(interceptor.getInterceptors().stream()
|
||||
.anyMatch(OptimisticLockerInnerInterceptor.class::isInstance));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
import com.bruce.common.document.parse.impl.ExcelDocumentParser;
|
||||
import com.bruce.common.document.parse.impl.PdfDocumentParser;
|
||||
import com.bruce.common.document.parse.impl.TxtDocumentParser;
|
||||
import com.bruce.common.document.parse.impl.WordDocumentParser;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
class DocumentParserFactoryTests {
|
||||
|
||||
@Test
|
||||
void resolveShouldChooseParserByFileSuffix() {
|
||||
DocumentParserFactory factory = new DocumentParserFactory(List.of(
|
||||
new TxtDocumentParser(),
|
||||
new WordDocumentParser(),
|
||||
new PdfDocumentParser(),
|
||||
new ExcelDocumentParser()
|
||||
));
|
||||
|
||||
assertEquals(TxtDocumentParser.class, factory.resolve(context("txt")).getClass());
|
||||
assertEquals(WordDocumentParser.class, factory.resolve(context("docx")).getClass());
|
||||
assertEquals(PdfDocumentParser.class, factory.resolve(context("pdf")).getClass());
|
||||
assertEquals(ExcelDocumentParser.class, factory.resolve(context("xlsx")).getClass());
|
||||
}
|
||||
|
||||
@Test
|
||||
void resolveShouldRejectUnsupportedSuffix() {
|
||||
DocumentParserFactory factory = new DocumentParserFactory(List.of(new TxtDocumentParser()));
|
||||
|
||||
DocumentParseException exception = assertThrows(
|
||||
DocumentParseException.class,
|
||||
() -> factory.resolve(context("zip"))
|
||||
);
|
||||
|
||||
assertEquals("不支持的文档类型: zip", exception.getMessage());
|
||||
}
|
||||
|
||||
private DocumentParseContext context(String suffix) {
|
||||
DocumentParseContext context = new DocumentParseContext();
|
||||
context.setSuffix(suffix);
|
||||
context.setFilePath(Path.of("sample." + suffix));
|
||||
return context;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.bruce.common.document.parse;
|
||||
|
||||
import com.bruce.common.document.parse.impl.TxtDocumentParser;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
class TxtDocumentParserTests {
|
||||
|
||||
@TempDir
|
||||
private Path tempDir;
|
||||
|
||||
@Test
|
||||
void parseShouldReadPlainTextContent() throws Exception {
|
||||
Path file = tempDir.resolve("people.txt");
|
||||
Files.writeString(file, "张三 是 产品经理\n李四 是 后端工程师", StandardCharsets.UTF_8);
|
||||
DocumentParseContext context = new DocumentParseContext();
|
||||
context.setOriginalName("people.txt");
|
||||
context.setSuffix("txt");
|
||||
context.setContentType("text/plain");
|
||||
context.setFilePath(file);
|
||||
|
||||
DocumentParseResult result = new TxtDocumentParser().parse(context);
|
||||
|
||||
assertEquals("张三 是 产品经理\n李四 是 后端工程师", result.getText());
|
||||
assertEquals(result.getText().length(), result.getTextLength());
|
||||
assertTrue(result.getMetadata().get("contentType").toString().startsWith("text/plain"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void supportsShouldAcceptTextSuffixAndContentType() {
|
||||
TxtDocumentParser parser = new TxtDocumentParser();
|
||||
DocumentParseContext suffixContext = new DocumentParseContext();
|
||||
suffixContext.setSuffix("TXT");
|
||||
DocumentParseContext contentTypeContext = new DocumentParseContext();
|
||||
contentTypeContext.setContentType("text/plain");
|
||||
|
||||
assertTrue(parser.supports(suffixContext));
|
||||
assertTrue(parser.supports(contentTypeContext));
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.bruce.common.enumconfig;
|
||||
import com.bruce.common.enums.CommonStatusEnum;
|
||||
import com.bruce.common.enums.EnableStatusEnum;
|
||||
import com.bruce.rag.enums.RagIndexStatusEnum;
|
||||
import com.bruce.rag.enums.RagChunkStrategyEnum;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -24,6 +25,9 @@ class EnumDefinitionTests {
|
||||
assertEquals(4, RagParseStatusEnum.FAILED.getValue());
|
||||
assertEquals(1, RagIndexStatusEnum.PENDING.getValue());
|
||||
assertEquals(3, RagIndexStatusEnum.INDEXED.getValue());
|
||||
assertEquals(1, RagChunkStrategyEnum.FIXED_LENGTH.getValue());
|
||||
assertEquals(5, RagChunkStrategyEnum.DELIMITER.getValue());
|
||||
assertEquals(6, RagChunkStrategyEnum.SEMANTIC.getValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -38,5 +42,8 @@ class EnumDefinitionTests {
|
||||
assertEquals("解析失败", RagParseStatusEnum.FAILED.getLabel());
|
||||
assertEquals("待索引", RagIndexStatusEnum.PENDING.getLabel());
|
||||
assertEquals("已索引", RagIndexStatusEnum.INDEXED.getLabel());
|
||||
assertEquals("固定长度切片", RagChunkStrategyEnum.FIXED_LENGTH.getLabel());
|
||||
assertEquals("按分隔符切片", RagChunkStrategyEnum.DELIMITER.getLabel());
|
||||
assertEquals("语义切片", RagChunkStrategyEnum.SEMANTIC.getLabel());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.bruce.common.domain.entity.SysEnum;
|
||||
import com.bruce.common.enums.CommonStatusEnum;
|
||||
import com.bruce.common.enums.EnableStatusEnum;
|
||||
import com.bruce.common.service.ISysEnumService;
|
||||
import com.bruce.rag.enums.RagChunkStrategyEnum;
|
||||
import com.bruce.rag.enums.RagIndexStatusEnum;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -40,6 +41,13 @@ class SysEnumDataInitTests {
|
||||
saveOrUpdate("rag", "index_status", RagIndexStatusEnum.INDEXING.getLabel(), RagIndexStatusEnum.INDEXING.getValue(), 2, "RAG文档索引状态");
|
||||
saveOrUpdate("rag", "index_status", RagIndexStatusEnum.INDEXED.getLabel(), RagIndexStatusEnum.INDEXED.getValue(), 3, "RAG文档索引状态");
|
||||
saveOrUpdate("rag", "index_status", RagIndexStatusEnum.FAILED.getLabel(), RagIndexStatusEnum.FAILED.getValue(), 4, "RAG文档索引状态");
|
||||
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.FIXED_LENGTH.getLabel(), RagChunkStrategyEnum.FIXED_LENGTH.getValue(), 1, "RAG文档切片方式");
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.PARAGRAPH.getLabel(), RagChunkStrategyEnum.PARAGRAPH.getValue(), 2, "RAG文档切片方式");
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.HEADING.getLabel(), RagChunkStrategyEnum.HEADING.getValue(), 3, "RAG文档切片方式");
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.TABLE_ROW.getLabel(), RagChunkStrategyEnum.TABLE_ROW.getValue(), 4, "RAG文档切片方式");
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.DELIMITER.getLabel(), RagChunkStrategyEnum.DELIMITER.getValue(), 5, "RAG文档切片方式");
|
||||
saveOrUpdate("rag", "chunk_strategy", RagChunkStrategyEnum.SEMANTIC.getLabel(), RagChunkStrategyEnum.SEMANTIC.getValue(), 6, "RAG文档切片方式");
|
||||
}
|
||||
|
||||
private void saveOrUpdate(String catalog, String type, String name, Integer value, Integer sort, String remark) {
|
||||
|
||||
@@ -8,18 +8,29 @@ import com.bruce.rag.constant.RagSystemConstants;
|
||||
import com.bruce.rag.controller.RagDocumentController;
|
||||
import com.bruce.rag.controller.RagStoreController;
|
||||
import com.bruce.rag.dto.request.RagDocumentQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagDocumentParseRequest;
|
||||
import com.bruce.rag.dto.request.RagStoreQueryRequest;
|
||||
import com.bruce.rag.dto.request.RagStoreSaveRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentParseResponse;
|
||||
import com.bruce.rag.dto.response.RagStoreDocumentOverviewResponse;
|
||||
import com.bruce.rag.dto.response.RagStoreOverviewResponse;
|
||||
import com.bruce.rag.dto.response.RagDocumentResponse;
|
||||
import com.bruce.rag.dto.response.RagStoreResponse;
|
||||
import com.bruce.rag.entity.RagChunk;
|
||||
import com.bruce.rag.entity.RagChunkEmbedding;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.bruce.rag.entity.RagStore;
|
||||
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
|
||||
import com.bruce.rag.mapper.RagChunkMapper;
|
||||
import com.bruce.rag.mapper.RagDocumentMapper;
|
||||
import com.bruce.rag.mapper.RagStoreMapper;
|
||||
import com.bruce.rag.service.IRagChunkEmbeddingService;
|
||||
import com.bruce.rag.service.IRagChunkService;
|
||||
import com.bruce.rag.service.IRagDocumentParseService;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import com.bruce.rag.service.IRagStoreService;
|
||||
import com.bruce.rag.service.impl.RagChunkEmbeddingServiceImpl;
|
||||
import com.bruce.rag.service.impl.RagChunkServiceImpl;
|
||||
import com.bruce.rag.service.impl.RagDocumentServiceImpl;
|
||||
import com.bruce.rag.service.impl.RagStoreServiceImpl;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -39,10 +50,16 @@ class RagComponentStructureTests {
|
||||
void ragComponentsShouldReuseMybatisPlusBaseTypes() {
|
||||
assertTrue(BaseMapper.class.isAssignableFrom(RagStoreMapper.class));
|
||||
assertTrue(BaseMapper.class.isAssignableFrom(RagDocumentMapper.class));
|
||||
assertTrue(BaseMapper.class.isAssignableFrom(RagChunkMapper.class));
|
||||
assertTrue(BaseMapper.class.isAssignableFrom(RagChunkEmbeddingMapper.class));
|
||||
assertTrue(IService.class.isAssignableFrom(IRagStoreService.class));
|
||||
assertTrue(IService.class.isAssignableFrom(IRagDocumentService.class));
|
||||
assertTrue(IService.class.isAssignableFrom(IRagChunkService.class));
|
||||
assertTrue(IService.class.isAssignableFrom(IRagChunkEmbeddingService.class));
|
||||
assertTrue(ServiceImpl.class.isAssignableFrom(RagStoreServiceImpl.class));
|
||||
assertTrue(ServiceImpl.class.isAssignableFrom(RagDocumentServiceImpl.class));
|
||||
assertTrue(ServiceImpl.class.isAssignableFrom(RagChunkServiceImpl.class));
|
||||
assertTrue(ServiceImpl.class.isAssignableFrom(RagChunkEmbeddingServiceImpl.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -63,8 +80,10 @@ class RagComponentStructureTests {
|
||||
|
||||
Method documentListMethod = RagDocumentController.class.getMethod("list");
|
||||
Method documentQueryMethod = RagDocumentController.class.getMethod("query", RagDocumentQueryRequest.class);
|
||||
Method documentParseMethod = RagDocumentController.class.getMethod("parse", RagDocumentParseRequest.class);
|
||||
Method documentResponseListMethod = IRagDocumentService.class.getMethod("listResponses");
|
||||
Method documentServiceQueryMethod = IRagDocumentService.class.getMethod("query", RagDocumentQueryRequest.class);
|
||||
Method documentParseServiceMethod = IRagDocumentParseService.class.getMethod("parse", RagDocumentParseRequest.class);
|
||||
|
||||
assertEquals(RequestResult.class, storeListMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, storeQueryMethod.getReturnType());
|
||||
@@ -89,11 +108,14 @@ class RagComponentStructureTests {
|
||||
|
||||
assertEquals(RequestResult.class, documentListMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, documentQueryMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, documentParseMethod.getReturnType());
|
||||
assertEquals(List.class, documentServiceQueryMethod.getReturnType());
|
||||
assertEquals(List.class, documentParseServiceMethod.getReturnType());
|
||||
assertTrue(documentResponseListMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
|
||||
assertTrue(documentServiceQueryMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
|
||||
assertTrue(documentListMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
|
||||
assertTrue(documentQueryMethod.getGenericReturnType().getTypeName().contains("RagDocumentResponse"));
|
||||
assertTrue(documentParseMethod.getGenericReturnType().getTypeName().contains("RagDocumentParseResponse"));
|
||||
assertEquals(RagDocumentResponse.class, RagDocumentResponse.class.getMethod("fromEntity", RagDocument.class).getReturnType());
|
||||
}
|
||||
|
||||
@@ -121,4 +143,34 @@ class RagComponentStructureTests {
|
||||
assertTrue(RagStoreController.class.getSimpleName().contains("RagStoreController"));
|
||||
assertTrue(RagDocumentController.class.getSimpleName().contains("RagDocumentController"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void ragChunkStructureShouldSupportChunkMetadata() throws NoSuchFieldException {
|
||||
assertEquals(Long.class, RagChunk.class.getDeclaredField("storeId").getType());
|
||||
assertEquals(Long.class, RagChunk.class.getDeclaredField("documentId").getType());
|
||||
assertEquals(Integer.class, RagChunk.class.getDeclaredField("chunkIndex").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("chunkContent").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("chunkSummary").getType());
|
||||
assertEquals(Integer.class, RagChunk.class.getDeclaredField("tokenCount").getType());
|
||||
assertEquals(Integer.class, RagChunk.class.getDeclaredField("pageNumber").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("sectionTitle").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("headingPath").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("vectorId").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("metadataJson").getType());
|
||||
assertEquals(Boolean.class, RagChunk.class.getDeclaredField("enabled").getType());
|
||||
assertEquals(String.class, RagChunk.class.getDeclaredField("remark").getType());
|
||||
}
|
||||
|
||||
@Test
|
||||
void ragChunkEmbeddingStructureShouldSupportPgvectorMetadata() throws NoSuchFieldException {
|
||||
assertEquals(Long.class, RagChunkEmbedding.class.getDeclaredField("storeId").getType());
|
||||
assertEquals(Long.class, RagChunkEmbedding.class.getDeclaredField("documentId").getType());
|
||||
assertEquals(Long.class, RagChunkEmbedding.class.getDeclaredField("chunkId").getType());
|
||||
assertEquals(String.class, RagChunkEmbedding.class.getDeclaredField("embeddingModel").getType());
|
||||
assertEquals(Integer.class, RagChunkEmbedding.class.getDeclaredField("embeddingDimension").getType());
|
||||
assertEquals(String.class, RagChunkEmbedding.class.getDeclaredField("embedding").getType());
|
||||
assertEquals(String.class, RagChunkEmbedding.class.getDeclaredField("contentHash").getType());
|
||||
assertEquals(Boolean.class, RagChunkEmbedding.class.getDeclaredField("enabled").getType());
|
||||
assertEquals(String.class, RagChunkEmbedding.class.getDeclaredField("remark").getType());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
package com.bruce.rag;
|
||||
|
||||
import com.bruce.common.config.AttachmentProperties;
|
||||
import com.bruce.common.document.parse.DocumentParseContext;
|
||||
import com.bruce.common.document.parse.DocumentParseResult;
|
||||
import com.bruce.common.document.parse.DocumentParser;
|
||||
import com.bruce.common.document.parse.DocumentParserFactory;
|
||||
import com.bruce.common.domain.entity.SysAttachment;
|
||||
import com.bruce.common.service.ISysAttachmentService;
|
||||
import com.bruce.rag.dto.request.RagDocumentParseRequest;
|
||||
import com.bruce.rag.dto.response.RagDocumentParseResponse;
|
||||
import com.bruce.rag.entity.RagDocument;
|
||||
import com.bruce.rag.enums.RagParseStatusEnum;
|
||||
import com.bruce.rag.service.IRagDocumentService;
|
||||
import com.bruce.rag.service.impl.RagDocumentParseServiceImpl;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class RagDocumentParseServiceImplTests {
|
||||
|
||||
@TempDir
|
||||
private Path tempDir;
|
||||
|
||||
@Mock
|
||||
private IRagDocumentService ragDocumentService;
|
||||
|
||||
@Mock
|
||||
private ISysAttachmentService sysAttachmentService;
|
||||
|
||||
@Test
|
||||
void parseShouldUpdateStatusAndReturnParseResponse() throws Exception {
|
||||
Path file = tempDir.resolve("rag").resolve("people.txt");
|
||||
Files.createDirectories(file.getParent());
|
||||
Files.writeString(file, "people profiles");
|
||||
|
||||
RagDocument document = new RagDocument();
|
||||
document.setId(1001L);
|
||||
document.setStoreId(2002L);
|
||||
document.setAttachmentId(3003L);
|
||||
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
|
||||
|
||||
SysAttachment attachment = new SysAttachment();
|
||||
attachment.setId(3003L);
|
||||
attachment.setOriginalName("people.txt");
|
||||
attachment.setFileSuffix("txt");
|
||||
attachment.setContentType("text/plain");
|
||||
attachment.setFilePath("rag/people.txt");
|
||||
|
||||
AttachmentProperties attachmentProperties = new AttachmentProperties();
|
||||
attachmentProperties.setBasePath(tempDir.toString());
|
||||
DocumentParser parser = new FixedDocumentParser("people profiles");
|
||||
RagDocumentParseServiceImpl service = new RagDocumentParseServiceImpl(
|
||||
ragDocumentService,
|
||||
sysAttachmentService,
|
||||
attachmentProperties,
|
||||
new DocumentParserFactory(List.of(parser))
|
||||
);
|
||||
|
||||
when(ragDocumentService.getById(1001L)).thenReturn(document);
|
||||
when(sysAttachmentService.getById(3003L)).thenReturn(attachment);
|
||||
when(ragDocumentService.updateById(any(RagDocument.class))).thenReturn(true);
|
||||
|
||||
RagDocumentParseResponse response = service.parse(1001L);
|
||||
|
||||
assertEquals(1001L, response.getDocumentId());
|
||||
assertEquals(RagParseStatusEnum.PARSED.name(), response.getParseStatus());
|
||||
assertEquals(15, response.getTextLength());
|
||||
assertEquals("fixed", response.getMetadata().get("parser"));
|
||||
|
||||
ArgumentCaptor<RagDocument> captor = ArgumentCaptor.forClass(RagDocument.class);
|
||||
verify(ragDocumentService, times(2)).updateById(captor.capture());
|
||||
List<RagDocument> updates = captor.getAllValues();
|
||||
assertEquals(RagParseStatusEnum.PARSING.name(), updates.get(0).getParseStatus());
|
||||
assertEquals(RagParseStatusEnum.PARSED.name(), updates.get(1).getParseStatus());
|
||||
assertTrue(parser.supports(new DocumentParseContext()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void parseShouldSupportBatchRequestAndChunkStrategyStructure() throws Exception {
|
||||
Path file = tempDir.resolve("rag").resolve("batch.txt");
|
||||
Files.createDirectories(file.getParent());
|
||||
Files.writeString(file, "batch profiles");
|
||||
|
||||
RagDocument document = new RagDocument();
|
||||
document.setId(1002L);
|
||||
document.setStoreId(2002L);
|
||||
document.setAttachmentId(3004L);
|
||||
document.setParseStatus(RagParseStatusEnum.UPLOADED.name());
|
||||
|
||||
SysAttachment attachment = new SysAttachment();
|
||||
attachment.setId(3004L);
|
||||
attachment.setOriginalName("batch.txt");
|
||||
attachment.setFileSuffix("txt");
|
||||
attachment.setContentType("text/plain");
|
||||
attachment.setFilePath("rag/batch.txt");
|
||||
|
||||
AttachmentProperties attachmentProperties = new AttachmentProperties();
|
||||
attachmentProperties.setBasePath(tempDir.toString());
|
||||
RagDocumentParseServiceImpl service = new RagDocumentParseServiceImpl(
|
||||
ragDocumentService,
|
||||
sysAttachmentService,
|
||||
attachmentProperties,
|
||||
new DocumentParserFactory(List.of(new FixedDocumentParser("batch profiles")))
|
||||
);
|
||||
RagDocumentParseRequest request = new RagDocumentParseRequest();
|
||||
request.setDocumentIds(List.of(1002L));
|
||||
request.setChunkStrategy("DELIMITER");
|
||||
request.setDelimiter("。");
|
||||
|
||||
when(ragDocumentService.getById(1002L)).thenReturn(document);
|
||||
when(sysAttachmentService.getById(3004L)).thenReturn(attachment);
|
||||
when(ragDocumentService.updateById(any(RagDocument.class))).thenReturn(true);
|
||||
|
||||
List<RagDocumentParseResponse> responses = service.parse(request);
|
||||
|
||||
assertEquals(1, responses.size());
|
||||
assertEquals(1002L, responses.getFirst().getDocumentId());
|
||||
assertEquals(RagParseStatusEnum.PARSED.name(), responses.getFirst().getParseStatus());
|
||||
}
|
||||
|
||||
private static class FixedDocumentParser implements DocumentParser {
|
||||
|
||||
private final String text;
|
||||
|
||||
private FixedDocumentParser(String text) {
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(DocumentParseContext context) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocumentParseResult parse(DocumentParseContext context) {
|
||||
DocumentParseResult result = new DocumentParseResult();
|
||||
result.setText(text);
|
||||
result.setTextLength(text.length());
|
||||
result.setMetadata(Map.of("parser", "fixed"));
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,13 +102,13 @@ class RagDocumentServiceImplTests {
|
||||
request.setRemark(" 备注信息 ");
|
||||
|
||||
doReturn(existingDocument).when(ragDocumentService).getById(3003L);
|
||||
doReturn(true).when(ragDocumentService).saveOrUpdate(any(RagDocument.class));
|
||||
doReturn(true).when(ragDocumentService).updateById(any(RagDocument.class));
|
||||
|
||||
boolean result = ragDocumentService.saveOrUpdate(request);
|
||||
|
||||
assertTrue(result);
|
||||
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
|
||||
verify(ragDocumentService).saveOrUpdate(documentCaptor.capture());
|
||||
verify(ragDocumentService).updateById(documentCaptor.capture());
|
||||
RagDocument savedDocument = documentCaptor.getValue();
|
||||
assertEquals(3003L, savedDocument.getId());
|
||||
assertEquals(1001L, savedDocument.getStoreId());
|
||||
@@ -121,4 +121,40 @@ class RagDocumentServiceImplTests {
|
||||
assertEquals("已修复", savedDocument.getErrorMessage());
|
||||
assertEquals("备注信息", savedDocument.getRemark());
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveOrUpdateShouldPreserveExistingFieldsForPartialUpdate() {
|
||||
RagDocument existingDocument = new RagDocument();
|
||||
existingDocument.setId(3003L);
|
||||
existingDocument.setStoreId(1001L);
|
||||
existingDocument.setAttachmentId(2002L);
|
||||
existingDocument.setDocumentTitle("people_profiles.txt");
|
||||
existingDocument.setDocumentSummary("测试人员信息,有多条人员信息");
|
||||
existingDocument.setParseStatus(RagParseStatusEnum.UPLOADED.name());
|
||||
existingDocument.setIndexStatus(RagIndexStatusEnum.PENDING.name());
|
||||
existingDocument.setEnabled(true);
|
||||
existingDocument.setRemark("测试人员信息");
|
||||
|
||||
RagDocumentSaveRequest request = new RagDocumentSaveRequest();
|
||||
request.setId(3003L);
|
||||
request.setStoreId(1001L);
|
||||
request.setDocumentTitle("people_profiles.txt");
|
||||
request.setEnabled(false);
|
||||
|
||||
doReturn(existingDocument).when(ragDocumentService).getById(3003L);
|
||||
doReturn(true).when(ragDocumentService).updateById(any(RagDocument.class));
|
||||
|
||||
boolean result = ragDocumentService.saveOrUpdate(request);
|
||||
|
||||
assertTrue(result);
|
||||
ArgumentCaptor<RagDocument> documentCaptor = ArgumentCaptor.forClass(RagDocument.class);
|
||||
verify(ragDocumentService).updateById(documentCaptor.capture());
|
||||
RagDocument savedDocument = documentCaptor.getValue();
|
||||
assertEquals(2002L, savedDocument.getAttachmentId());
|
||||
assertEquals("测试人员信息,有多条人员信息", savedDocument.getDocumentSummary());
|
||||
assertEquals(RagParseStatusEnum.UPLOADED.name(), savedDocument.getParseStatus());
|
||||
assertEquals(RagIndexStatusEnum.PENDING.name(), savedDocument.getIndexStatus());
|
||||
assertEquals(false, savedDocument.getEnabled());
|
||||
assertEquals("测试人员信息", savedDocument.getRemark());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user