refactor(modules): 拆分多模块工程并收口common基础模块
This commit is contained in:
60
common-agent-agent/pom.xml
Normal file
60
common-agent-agent/pom.xml
Normal file
@@ -0,0 +1,60 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>com.bruce</groupId>
|
||||
<artifactId>common-agent-parent</artifactId>
|
||||
<version>0.0.1-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>common-agent-agent</artifactId>
|
||||
<name>common-agent-agent</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.bruce</groupId>
|
||||
<artifactId>common-agent-common</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.bruce</groupId>
|
||||
<artifactId>common-agent-rag</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.bruce</groupId>
|
||||
<artifactId>common-agent-modelprovider</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.baomidou</groupId>
|
||||
<artifactId>mybatis-plus-spring-boot4-starter</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.postgresql</groupId>
|
||||
<artifactId>postgresql</artifactId>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -0,0 +1,58 @@
|
||||
package com.bruce.agent.controller;
|
||||
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.dto.response.AgentDefinitionResponse;
|
||||
import com.bruce.agent.service.IAgentDefinitionService;
|
||||
import com.bruce.common.domain.model.RequestResult;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/agents")
|
||||
@RequiredArgsConstructor
|
||||
public class AgentDefinitionController {
|
||||
|
||||
private final IAgentDefinitionService agentDefinitionService;
|
||||
|
||||
@PostMapping("/list")
|
||||
public RequestResult<List<AgentDefinitionResponse>> list() {
|
||||
return RequestResult.success(agentDefinitionService.listResponses());
|
||||
}
|
||||
|
||||
@PostMapping("/query")
|
||||
public RequestResult<List<AgentDefinitionResponse>> query(@RequestBody(required = false) AgentDefinitionQueryRequest request) {
|
||||
return RequestResult.success(agentDefinitionService.query(request));
|
||||
}
|
||||
|
||||
@GetMapping("/detail")
|
||||
public RequestResult<AgentDefinitionResponse> detail(@RequestParam("id") Long id) {
|
||||
return RequestResult.success(agentDefinitionService.getResponseById(id));
|
||||
}
|
||||
|
||||
@PostMapping("/save")
|
||||
public RequestResult<Boolean> save(@RequestBody AgentDefinitionSaveRequest request) {
|
||||
return RequestResult.success(agentDefinitionService.saveOrUpdate(request));
|
||||
}
|
||||
|
||||
@PostMapping("/delete")
|
||||
public RequestResult<Boolean> delete(@RequestParam("id") Long id) {
|
||||
return RequestResult.success(agentDefinitionService.removeById(id));
|
||||
}
|
||||
|
||||
@PostMapping("/{agentId}/chat")
|
||||
public RequestResult<AgentChatResponse> chat(@PathVariable("agentId") Long agentId,
|
||||
@RequestBody AgentChatRequest request) {
|
||||
return RequestResult.success(agentDefinitionService.chat(agentId, request));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.bruce.agent.dto.request;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class AgentChatRequest {
|
||||
private List<AgentMessage> messages;
|
||||
private Boolean ragEnabled;
|
||||
|
||||
@Data
|
||||
public static class AgentMessage {
|
||||
private String role;
|
||||
private String content;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.bruce.agent.dto.request;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AgentDefinitionQueryRequest {
|
||||
private String agentCode;
|
||||
private String agentName;
|
||||
private String status;
|
||||
private Long storeId;
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.bruce.agent.dto.request;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AgentDefinitionSaveRequest {
|
||||
private Long id;
|
||||
private String agentCode;
|
||||
private String agentName;
|
||||
private String systemPrompt;
|
||||
private Long storeId;
|
||||
private String status;
|
||||
private String remark;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.bruce.agent.dto.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class AgentChatResponse {
|
||||
private Long agentId;
|
||||
private String agentCode;
|
||||
private String agentName;
|
||||
private Long storeId;
|
||||
private String storeName;
|
||||
private String answer;
|
||||
private String modelRequestId;
|
||||
private List<ReferenceChunk> references;
|
||||
|
||||
@Data
|
||||
public static class ReferenceChunk {
|
||||
private Long chunkId;
|
||||
private Long documentId;
|
||||
private String chunkContent;
|
||||
private Double score;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.bruce.agent.dto.response;
|
||||
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@Data
|
||||
public class AgentDefinitionResponse {
|
||||
private Long id;
|
||||
private String agentCode;
|
||||
private String agentName;
|
||||
private String systemPrompt;
|
||||
private Long storeId;
|
||||
private String status;
|
||||
private String remark;
|
||||
|
||||
public static AgentDefinitionResponse fromEntity(AgentDefinition entity) {
|
||||
if (entity == null) {
|
||||
return null;
|
||||
}
|
||||
AgentDefinitionResponse response = new AgentDefinitionResponse();
|
||||
BeanUtils.copyProperties(entity, response);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.bruce.agent.entity;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.bruce.common.domain.model.BaseEntity;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("agent_definition")
|
||||
public class AgentDefinition extends BaseEntity {
|
||||
|
||||
@TableField("agent_code")
|
||||
private String agentCode;
|
||||
|
||||
@TableField("agent_name")
|
||||
private String agentName;
|
||||
|
||||
@TableField("system_prompt")
|
||||
private String systemPrompt;
|
||||
|
||||
@TableField("store_id")
|
||||
private Long storeId;
|
||||
|
||||
private String status;
|
||||
|
||||
private String remark;
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.bruce.agent.mapper;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface AgentDefinitionMapper extends BaseMapper<AgentDefinition> {
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.bruce.agent.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.dto.response.AgentDefinitionResponse;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface IAgentDefinitionService extends IService<AgentDefinition> {
|
||||
List<AgentDefinitionResponse> listResponses();
|
||||
|
||||
List<AgentDefinitionResponse> query(AgentDefinitionQueryRequest request);
|
||||
|
||||
AgentDefinitionResponse getResponseById(Long id);
|
||||
|
||||
boolean saveOrUpdate(AgentDefinitionSaveRequest request);
|
||||
|
||||
AgentChatResponse chat(Long agentId, AgentChatRequest request);
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package com.bruce.agent.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.dto.response.AgentDefinitionResponse;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import com.bruce.agent.mapper.AgentDefinitionMapper;
|
||||
import com.bruce.agent.service.IAgentDefinitionService;
|
||||
import com.bruce.common.enums.EnableStatusEnum;
|
||||
import com.bruce.modelprovider.client.OpenAiChatMessage;
|
||||
import com.bruce.modelprovider.entity.RagStoreModelConfig;
|
||||
import com.bruce.modelprovider.gateway.ChatModelGateway;
|
||||
import com.bruce.modelprovider.gateway.ChatRequest;
|
||||
import com.bruce.modelprovider.gateway.ChatResult;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingRequest;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingResult;
|
||||
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
|
||||
import com.bruce.rag.dto.response.RagChunkRecallResponse;
|
||||
import com.bruce.rag.entity.RagStore;
|
||||
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
|
||||
import com.bruce.rag.service.IRagStoreService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class AgentDefinitionServiceImpl extends ServiceImpl<AgentDefinitionMapper, AgentDefinition>
|
||||
implements IAgentDefinitionService {
|
||||
|
||||
private static final int DEFAULT_TOP_K = 5;
|
||||
|
||||
private final IRagStoreService ragStoreService;
|
||||
private final IRagStoreModelConfigService ragStoreModelConfigService;
|
||||
private final RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
|
||||
private final EmbeddingModelGateway embeddingModelGateway;
|
||||
private final ChatModelGateway chatModelGateway;
|
||||
|
||||
@Override
|
||||
public List<AgentDefinitionResponse> listResponses() {
|
||||
return lambdaQuery()
|
||||
.orderByAsc(AgentDefinition::getAgentCode)
|
||||
.list()
|
||||
.stream()
|
||||
.map(AgentDefinitionResponse::fromEntity)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentDefinitionResponse> query(AgentDefinitionQueryRequest request) {
|
||||
AgentDefinitionQueryRequest queryRequest = request == null ? new AgentDefinitionQueryRequest() : request;
|
||||
return lambdaQuery()
|
||||
.eq(StringUtils.hasText(queryRequest.getAgentCode()), AgentDefinition::getAgentCode, trimToNull(queryRequest.getAgentCode()))
|
||||
.like(StringUtils.hasText(queryRequest.getAgentName()), AgentDefinition::getAgentName, trimToNull(queryRequest.getAgentName()))
|
||||
.eq(StringUtils.hasText(queryRequest.getStatus()), AgentDefinition::getStatus, trimToNull(queryRequest.getStatus()))
|
||||
.eq(queryRequest.getStoreId() != null, AgentDefinition::getStoreId, queryRequest.getStoreId())
|
||||
.orderByAsc(AgentDefinition::getAgentCode)
|
||||
.list()
|
||||
.stream()
|
||||
.map(AgentDefinitionResponse::fromEntity)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentDefinitionResponse getResponseById(Long id) {
|
||||
return AgentDefinitionResponse.fromEntity(getById(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean saveOrUpdate(AgentDefinitionSaveRequest request) {
|
||||
validateSaveRequest(request);
|
||||
if (ragStoreService.getById(request.getStoreId()) == null) {
|
||||
throw new IllegalArgumentException("绑定知识库不存在,ID: " + request.getStoreId());
|
||||
}
|
||||
AgentDefinition duplicate = lambdaQuery()
|
||||
.eq(AgentDefinition::getAgentCode, request.getAgentCode().trim())
|
||||
.ne(request.getId() != null, AgentDefinition::getId, request.getId())
|
||||
.one();
|
||||
if (duplicate != null) {
|
||||
throw new IllegalArgumentException("Agent编码已存在: " + request.getAgentCode().trim());
|
||||
}
|
||||
AgentDefinition entity = request.getId() == null ? new AgentDefinition() : getById(request.getId());
|
||||
if (entity == null) {
|
||||
throw new IllegalArgumentException("Agent不存在,ID: " + request.getId());
|
||||
}
|
||||
entity.setAgentCode(request.getAgentCode().trim());
|
||||
entity.setAgentName(request.getAgentName().trim());
|
||||
entity.setSystemPrompt(trimToNull(request.getSystemPrompt()));
|
||||
entity.setStoreId(request.getStoreId());
|
||||
entity.setStatus(StringUtils.hasText(request.getStatus())
|
||||
? request.getStatus().trim()
|
||||
: EnableStatusEnum.ENABLED.name());
|
||||
entity.setRemark(trimToNull(request.getRemark()));
|
||||
return request.getId() == null ? save(entity) : updateById(entity);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentChatResponse chat(Long agentId, AgentChatRequest request) {
|
||||
if (agentId == null) {
|
||||
throw new IllegalArgumentException("Agent ID不能为空");
|
||||
}
|
||||
if (request == null || request.getMessages() == null || request.getMessages().isEmpty()) {
|
||||
throw new IllegalArgumentException("对话消息不能为空");
|
||||
}
|
||||
AgentDefinition agent = getById(agentId);
|
||||
if (agent == null) {
|
||||
throw new IllegalArgumentException("Agent不存在,ID: " + agentId);
|
||||
}
|
||||
if (!EnableStatusEnum.ENABLED.name().equals(agent.getStatus())) {
|
||||
throw new IllegalArgumentException("Agent已停用,暂不支持对话");
|
||||
}
|
||||
if (agent.getStoreId() == null) {
|
||||
throw new IllegalArgumentException("Agent未绑定知识库,请先保存知识库配置");
|
||||
}
|
||||
RagStore store = ragStoreService.getById(agent.getStoreId());
|
||||
if (store == null) {
|
||||
throw new IllegalArgumentException("绑定知识库不存在,ID: " + agent.getStoreId());
|
||||
}
|
||||
|
||||
String queryText = resolveLatestUserMessage(request.getMessages());
|
||||
boolean ragEnabled = request.getRagEnabled() == null || request.getRagEnabled();
|
||||
List<RagChunkRecallResponse> recalls = List.of();
|
||||
if (ragEnabled) {
|
||||
RagStoreModelConfig storeModelConfig = ragStoreModelConfigService.getActiveEntity(agent.getStoreId());
|
||||
if (storeModelConfig == null || storeModelConfig.getEmbeddingModelId() == null) {
|
||||
throw new IllegalArgumentException("当前知识库未配置Embedding模型,无法执行检索对话");
|
||||
}
|
||||
EmbeddingRequest embeddingRequest = new EmbeddingRequest();
|
||||
embeddingRequest.setTexts(List.of(queryText));
|
||||
embeddingRequest.setTaskType("RAG_QUERY_EMBEDDING");
|
||||
embeddingRequest.setMatchScope("RAG_STORE");
|
||||
embeddingRequest.setScopeId(agent.getStoreId());
|
||||
embeddingRequest.setBizType("AGENT_CHAT");
|
||||
embeddingRequest.setBizId(String.valueOf(agentId));
|
||||
embeddingRequest.setExpectedDimension(storeModelConfig.getEmbeddingDimension());
|
||||
EmbeddingResult queryEmbedding = embeddingModelGateway.embed(embeddingRequest);
|
||||
if (queryEmbedding.getVectors() == null || queryEmbedding.getVectors().isEmpty()) {
|
||||
throw new IllegalArgumentException("查询向量生成失败,请检查Embedding模型配置");
|
||||
}
|
||||
|
||||
String queryVector = toVectorLiteral(queryEmbedding.getVectors().getFirst());
|
||||
recalls = ragChunkEmbeddingMapper.queryTopKByStore(
|
||||
agent.getStoreId(),
|
||||
queryVector,
|
||||
DEFAULT_TOP_K
|
||||
);
|
||||
if (recalls.isEmpty()) {
|
||||
throw new IllegalArgumentException("未召回到可用知识切片,请先完成知识库切片与向量化");
|
||||
}
|
||||
}
|
||||
|
||||
ChatRequest chatRequest = new ChatRequest();
|
||||
chatRequest.setTaskType(ragEnabled ? "RAG_ANSWER" : "CHAT_SIMPLE");
|
||||
chatRequest.setMatchScope("AGENT");
|
||||
chatRequest.setScopeId(agentId);
|
||||
chatRequest.setBizType("AGENT_CHAT");
|
||||
chatRequest.setBizId(String.valueOf(agentId));
|
||||
chatRequest.setMessages(buildChatMessages(agent, recalls, request.getMessages(), ragEnabled));
|
||||
|
||||
ChatResult chatResult = chatModelGateway.chat(chatRequest);
|
||||
AgentChatResponse response = new AgentChatResponse();
|
||||
response.setAgentId(agent.getId());
|
||||
response.setAgentCode(agent.getAgentCode());
|
||||
response.setAgentName(agent.getAgentName());
|
||||
response.setStoreId(agent.getStoreId());
|
||||
response.setStoreName(store.getStoreName());
|
||||
response.setAnswer(chatResult.getContent());
|
||||
response.setModelRequestId(chatResult.getCallLog().getRequestId());
|
||||
response.setReferences(toReferenceChunks(recalls));
|
||||
return response;
|
||||
}
|
||||
|
||||
private void validateSaveRequest(AgentDefinitionSaveRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("保存请求不能为空");
|
||||
}
|
||||
if (!StringUtils.hasText(request.getAgentCode())) {
|
||||
throw new IllegalArgumentException("Agent编码不能为空");
|
||||
}
|
||||
if (!StringUtils.hasText(request.getAgentName())) {
|
||||
throw new IllegalArgumentException("Agent名称不能为空");
|
||||
}
|
||||
if (request.getStoreId() == null) {
|
||||
throw new IllegalArgumentException("绑定知识库不能为空");
|
||||
}
|
||||
}
|
||||
|
||||
private String resolveLatestUserMessage(List<AgentChatRequest.AgentMessage> messages) {
|
||||
for (int index = messages.size() - 1; index >= 0; index--) {
|
||||
AgentChatRequest.AgentMessage message = messages.get(index);
|
||||
if (message != null
|
||||
&& "user".equalsIgnoreCase(message.getRole())
|
||||
&& StringUtils.hasText(message.getContent())) {
|
||||
return message.getContent();
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("缺少用户提问内容");
|
||||
}
|
||||
|
||||
private List<OpenAiChatMessage> buildChatMessages(AgentDefinition agent,
|
||||
List<RagChunkRecallResponse> recalls,
|
||||
List<AgentChatRequest.AgentMessage> rawMessages,
|
||||
boolean ragEnabled) {
|
||||
List<OpenAiChatMessage> messages = new ArrayList<>();
|
||||
OpenAiChatMessage instructionMessage = new OpenAiChatMessage();
|
||||
instructionMessage.setRole("system");
|
||||
instructionMessage.setContent(buildSystemInstruction(agent));
|
||||
messages.add(instructionMessage);
|
||||
|
||||
if (ragEnabled) {
|
||||
OpenAiChatMessage contextMessage = new OpenAiChatMessage();
|
||||
contextMessage.setRole("system");
|
||||
contextMessage.setContent(buildContextText(recalls));
|
||||
messages.add(contextMessage);
|
||||
}
|
||||
|
||||
for (AgentChatRequest.AgentMessage rawMessage : rawMessages) {
|
||||
if (rawMessage == null || !StringUtils.hasText(rawMessage.getContent())) {
|
||||
continue;
|
||||
}
|
||||
OpenAiChatMessage message = new OpenAiChatMessage();
|
||||
message.setRole(normalizeRole(rawMessage.getRole()));
|
||||
message.setContent(rawMessage.getContent());
|
||||
messages.add(message);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
private String buildSystemInstruction(AgentDefinition agent) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
if (StringUtils.hasText(agent.getSystemPrompt())) {
|
||||
builder.append(agent.getSystemPrompt().trim()).append("\n\n");
|
||||
}
|
||||
builder.append("请优先基于已给出的知识库引用片段回答。");
|
||||
builder.append("如果引用无法支持结论,请明确告知“知识库中暂无直接依据”。");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
private String buildContextText(List<RagChunkRecallResponse> recalls) {
|
||||
StringBuilder builder = new StringBuilder("以下是知识库召回片段:\n");
|
||||
for (int i = 0; i < recalls.size(); i++) {
|
||||
RagChunkRecallResponse recall = recalls.get(i);
|
||||
builder.append(i + 1)
|
||||
.append(". [chunkId=")
|
||||
.append(recall.getChunkId())
|
||||
.append(", score=")
|
||||
.append(String.format("%.4f", recall.getScore() == null ? 0D : recall.getScore()))
|
||||
.append("] ")
|
||||
.append(recall.getChunkContent())
|
||||
.append("\n");
|
||||
}
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
private List<AgentChatResponse.ReferenceChunk> toReferenceChunks(List<RagChunkRecallResponse> recalls) {
|
||||
return recalls.stream().map(recall -> {
|
||||
AgentChatResponse.ReferenceChunk chunk = new AgentChatResponse.ReferenceChunk();
|
||||
chunk.setChunkId(recall.getChunkId());
|
||||
chunk.setDocumentId(recall.getDocumentId());
|
||||
chunk.setChunkContent(recall.getChunkContent());
|
||||
chunk.setScore(recall.getScore());
|
||||
return chunk;
|
||||
}).toList();
|
||||
}
|
||||
|
||||
private String normalizeRole(String role) {
|
||||
if (!StringUtils.hasText(role)) {
|
||||
return "user";
|
||||
}
|
||||
String normalized = role.trim().toLowerCase();
|
||||
if ("system".equals(normalized) || "assistant".equals(normalized) || "user".equals(normalized)) {
|
||||
return normalized;
|
||||
}
|
||||
return "user";
|
||||
}
|
||||
|
||||
private String toVectorLiteral(List<Double> vector) {
|
||||
StringBuilder builder = new StringBuilder("[");
|
||||
for (int index = 0; index < vector.size(); index++) {
|
||||
if (index > 0) {
|
||||
builder.append(',');
|
||||
}
|
||||
builder.append(vector.get(index));
|
||||
}
|
||||
builder.append(']');
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
private String trimToNull(String value) {
|
||||
if (!StringUtils.hasText(value)) {
|
||||
return null;
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.bruce.agent;
|
||||
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.bruce.agent.controller.AgentDefinitionController;
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionQueryRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.dto.response.AgentDefinitionResponse;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import com.bruce.agent.mapper.AgentDefinitionMapper;
|
||||
import com.bruce.agent.service.IAgentDefinitionService;
|
||||
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
|
||||
import com.bruce.common.domain.model.RequestResult;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
class AgentComponentStructureTests {
|
||||
|
||||
@Test
|
||||
void agentComponentsShouldReuseMybatisPlusBaseTypes() {
|
||||
assertTrue(BaseMapper.class.isAssignableFrom(AgentDefinitionMapper.class));
|
||||
assertTrue(IService.class.isAssignableFrom(IAgentDefinitionService.class));
|
||||
assertTrue(ServiceImpl.class.isAssignableFrom(AgentDefinitionServiceImpl.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void agentControllerShouldExposeRequestResultMethods() throws NoSuchMethodException {
|
||||
Method listMethod = AgentDefinitionController.class.getMethod("list");
|
||||
Method queryMethod = AgentDefinitionController.class.getMethod("query", AgentDefinitionQueryRequest.class);
|
||||
Method detailMethod = AgentDefinitionController.class.getMethod("detail", Long.class);
|
||||
Method saveMethod = AgentDefinitionController.class.getMethod("save", AgentDefinitionSaveRequest.class);
|
||||
Method deleteMethod = AgentDefinitionController.class.getMethod("delete", Long.class);
|
||||
Method chatMethod = AgentDefinitionController.class.getMethod("chat", Long.class, AgentChatRequest.class);
|
||||
|
||||
Method listServiceMethod = IAgentDefinitionService.class.getMethod("listResponses");
|
||||
Method queryServiceMethod = IAgentDefinitionService.class.getMethod("query", AgentDefinitionQueryRequest.class);
|
||||
Method detailServiceMethod = IAgentDefinitionService.class.getMethod("getResponseById", Long.class);
|
||||
Method saveServiceMethod = IAgentDefinitionService.class.getMethod("saveOrUpdate", AgentDefinitionSaveRequest.class);
|
||||
Method chatServiceMethod = IAgentDefinitionService.class.getMethod("chat", Long.class, AgentChatRequest.class);
|
||||
|
||||
assertEquals(RequestResult.class, listMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, queryMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, detailMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, saveMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, deleteMethod.getReturnType());
|
||||
assertEquals(RequestResult.class, chatMethod.getReturnType());
|
||||
|
||||
assertEquals(List.class, listServiceMethod.getReturnType());
|
||||
assertEquals(List.class, queryServiceMethod.getReturnType());
|
||||
assertEquals(AgentDefinitionResponse.class, detailServiceMethod.getReturnType());
|
||||
assertEquals(boolean.class, saveServiceMethod.getReturnType());
|
||||
assertEquals(AgentChatResponse.class, chatServiceMethod.getReturnType());
|
||||
assertEquals(AgentDefinitionResponse.class, AgentDefinitionResponse.class.getMethod("fromEntity", AgentDefinition.class).getReturnType());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package com.bruce.agent;
|
||||
|
||||
import com.bruce.agent.dto.request.AgentChatRequest;
|
||||
import com.bruce.agent.dto.request.AgentDefinitionSaveRequest;
|
||||
import com.bruce.agent.dto.response.AgentChatResponse;
|
||||
import com.bruce.agent.entity.AgentDefinition;
|
||||
import com.bruce.agent.service.impl.AgentDefinitionServiceImpl;
|
||||
import com.bruce.modelprovider.entity.ModelCallLog;
|
||||
import com.bruce.modelprovider.entity.RagStoreModelConfig;
|
||||
import com.bruce.modelprovider.gateway.ChatRequest;
|
||||
import com.bruce.modelprovider.gateway.ChatResult;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingRequest;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingResult;
|
||||
import com.bruce.modelprovider.service.IRagStoreModelConfigService;
|
||||
import com.bruce.rag.dto.response.RagChunkRecallResponse;
|
||||
import com.bruce.rag.entity.RagStore;
|
||||
import com.bruce.rag.mapper.RagChunkEmbeddingMapper;
|
||||
import com.bruce.rag.service.IRagStoreService;
|
||||
import com.bruce.modelprovider.gateway.ChatModelGateway;
|
||||
import com.bruce.modelprovider.gateway.EmbeddingModelGateway;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Spy;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class AgentDefinitionServiceImplTests {
|
||||
|
||||
@Spy
|
||||
@InjectMocks
|
||||
private AgentDefinitionServiceImpl agentDefinitionService;
|
||||
|
||||
@Mock
|
||||
private IRagStoreService ragStoreService;
|
||||
|
||||
@Mock
|
||||
private IRagStoreModelConfigService ragStoreModelConfigService;
|
||||
|
||||
@Mock
|
||||
private RagChunkEmbeddingMapper ragChunkEmbeddingMapper;
|
||||
|
||||
@Mock
|
||||
private EmbeddingModelGateway embeddingModelGateway;
|
||||
|
||||
@Mock
|
||||
private ChatModelGateway chatModelGateway;
|
||||
|
||||
@Test
|
||||
void saveOrUpdateShouldValidateBoundStoreExists() {
|
||||
AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest();
|
||||
request.setAgentCode("A_1");
|
||||
request.setAgentName("Agent 1");
|
||||
request.setStoreId(1001L);
|
||||
|
||||
when(ragStoreService.getById(1001L)).thenReturn(null);
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request));
|
||||
assertTrue(exception.getMessage().contains("绑定知识库不存在"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldRejectDisabledAgent() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("DISABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("你好");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
|
||||
assertTrue(exception.getMessage().contains("停用"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldRejectAgentWithoutStore() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setStatus("ENABLED");
|
||||
agent.setStoreId(null);
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("你好");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request));
|
||||
assertTrue(exception.getMessage().contains("未绑定知识库"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldUseStoreScopedRecallAndReturnAnswer() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setAgentCode("AGENT_1");
|
||||
agent.setAgentName("知识助手");
|
||||
agent.setSystemPrompt("你是企业知识助手");
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("ENABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
RagStore store = new RagStore();
|
||||
store.setId(2001L);
|
||||
store.setStoreName("企业知识库");
|
||||
when(ragStoreService.getById(2001L)).thenReturn(store);
|
||||
|
||||
RagStoreModelConfig modelConfig = new RagStoreModelConfig();
|
||||
modelConfig.setStoreId(2001L);
|
||||
modelConfig.setEmbeddingModelId(3001L);
|
||||
modelConfig.setEmbeddingDimension(1024);
|
||||
when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig);
|
||||
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56)));
|
||||
when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult);
|
||||
|
||||
RagChunkRecallResponse recall = new RagChunkRecallResponse();
|
||||
recall.setChunkId(4001L);
|
||||
recall.setDocumentId(5001L);
|
||||
recall.setChunkContent("公司请假流程:先提交审批单。");
|
||||
recall.setScore(0.91);
|
||||
when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt()))
|
||||
.thenReturn(List.of(recall));
|
||||
|
||||
ModelCallLog callLog = new ModelCallLog();
|
||||
callLog.setRequestId("req_001");
|
||||
ChatResult chatResult = new ChatResult();
|
||||
chatResult.setContent("根据知识库,先在OA提交请假审批。");
|
||||
chatResult.setCallLog(callLog);
|
||||
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("请假流程是什么?");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
|
||||
|
||||
assertEquals(1001L, response.getAgentId());
|
||||
assertEquals(2001L, response.getStoreId());
|
||||
assertEquals("企业知识库", response.getStoreName());
|
||||
assertEquals("根据知识库,先在OA提交请假审批。", response.getAnswer());
|
||||
assertEquals("req_001", response.getModelRequestId());
|
||||
assertEquals(1, response.getReferences().size());
|
||||
assertEquals(4001L, response.getReferences().getFirst().getChunkId());
|
||||
|
||||
ArgumentCaptor<EmbeddingRequest> embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class);
|
||||
verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture());
|
||||
EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue();
|
||||
assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType());
|
||||
assertEquals("RAG_STORE", embeddingRequest.getMatchScope());
|
||||
assertEquals(2001L, embeddingRequest.getScopeId());
|
||||
|
||||
verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatShouldSupportSimpleModeWithoutRagRecall() {
|
||||
AgentDefinition agent = new AgentDefinition();
|
||||
agent.setId(1001L);
|
||||
agent.setAgentCode("AGENT_1");
|
||||
agent.setAgentName("知识助手");
|
||||
agent.setStoreId(2001L);
|
||||
agent.setStatus("ENABLED");
|
||||
doReturn(agent).when(agentDefinitionService).getById(1001L);
|
||||
|
||||
RagStore store = new RagStore();
|
||||
store.setId(2001L);
|
||||
store.setStoreName("企业知识库");
|
||||
when(ragStoreService.getById(2001L)).thenReturn(store);
|
||||
|
||||
ModelCallLog callLog = new ModelCallLog();
|
||||
callLog.setRequestId("req_simple_001");
|
||||
ChatResult chatResult = new ChatResult();
|
||||
chatResult.setContent("这是普通对话回答。");
|
||||
chatResult.setCallLog(callLog);
|
||||
when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult);
|
||||
|
||||
AgentChatRequest request = new AgentChatRequest();
|
||||
request.setRagEnabled(false);
|
||||
AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage();
|
||||
message.setRole("user");
|
||||
message.setContent("直接聊聊今天安排");
|
||||
request.setMessages(List.of(message));
|
||||
|
||||
AgentChatResponse response = agentDefinitionService.chat(1001L, request);
|
||||
|
||||
assertEquals("这是普通对话回答。", response.getAnswer());
|
||||
assertTrue(response.getReferences().isEmpty());
|
||||
verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class));
|
||||
verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt());
|
||||
|
||||
ArgumentCaptor<ChatRequest> chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
|
||||
verify(chatModelGateway).chat(chatRequestCaptor.capture());
|
||||
assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user