diff --git a/agent-page-apis.md b/agent-page-apis.md new file mode 100644 index 0000000..394e0d4 --- /dev/null +++ b/agent-page-apis.md @@ -0,0 +1,159 @@ +# Agent 页面后端接口清单 + +本文对应前端页面:[AgentManagePage.vue](frontend/src/pages/agent/AgentManagePage.vue) 和 [AgentDebugPage.vue](frontend/src/pages/agent/AgentDebugPage.vue)。 + +## 1. 页面目标 + +Agent 页面分为两块: + +- Agent 管理:维护 `agent_definition` 基础配置(编码、名称、知识库绑定、状态、系统提示词)。 +- Agent 调试:选择 Agent 发起对话,支持普通对话与 RAG 对话切换,并回显引用切片。 + +## 2. Agent 管理接口 + +### 2.1 查询全部 Agent + +- `POST /api/agents/list` + +返回类型: + +- `RequestResult>` + +### 2.2 条件查询 Agent + +- `POST /api/agents/query` + +请求体示例: + +```json +{ + "agentCode": "AGENT_RAG_HELPER", + "agentName": "知识助手", + "status": "ENABLED", + "storeId": 1001 +} +``` + +### 2.3 查询 Agent 详情 + +- `GET /api/agents/detail?id={id}` + +### 2.4 新增或更新 Agent + +- `POST /api/agents/save` + +请求体示例: + +```json +{ + "id": 1, + "agentCode": "AGENT_RAG_HELPER", + "agentName": "知识问答助手", + "systemPrompt": "你是企业知识助手,请优先基于知识库回答。", + "storeId": 1001, + "status": "ENABLED", + "remark": "客服场景" +} +``` + +说明: + +- `id` 为空时新增,非空时更新。 +- `agentCode` 全局唯一。 +- `storeId` 必须指向已存在的 `rag_store`。 +- `status` 默认 `ENABLED`,可选 `ENABLED` / `DISABLED`。 + +### 2.5 删除 Agent + +- `POST /api/agents/delete?id={id}` + +## 3. Agent 调试接口 + +### 3.1 发起对话 + +- `POST /api/agents/{agentId}/chat` + +请求体示例: + +```json +{ + "messages": [ + { "role": "user", "content": "请说明请假流程" } + ], + "ragEnabled": true +} +``` + +返回示例: + +```json +{ + "resultcode": "0", + "message": null, + "data": { + "agentId": 1, + "agentCode": "AGENT_RAG_HELPER", + "agentName": "知识问答助手", + "storeId": 1001, + "storeName": "企业知识库", + "answer": "根据知识库,先提交 OA 审批单。", + "modelRequestId": "f4215d13d0b3493e963297f15428e2f2", + "references": [ + { + "chunkId": 9001, + "documentId": 8001, + "chunkContent": "请假流程:员工先在OA提交审批单...", + "score": 0.9123 + } + ] + } +} +``` + +## 4. 对话模式说明 + +### 4.1 `ragEnabled=true`(默认) + +执行路径: + +1. 从消息列表中提取最后一条 `role=user` 的问题。 +2. 读取该 Agent 绑定知识库的生效 Embedding 配置。 +3. 生成查询向量并在 `rag_chunk_embedding` 按知识库 TopK 召回切片。 +4. 将系统提示词、召回片段和会话消息组装后调用 Chat 模型。 +5. 返回回答 + 引用切片 + `modelRequestId`。 + +### 4.2 `ragEnabled=false` + +执行路径: + +- 跳过向量化与召回,直接使用会话消息调用 Chat 模型,返回普通对话结果。 + +## 5. 调试联调前置条件 + +### 5.1 普通对话前置条件 + +- Agent 状态为 `ENABLED`。 +- Agent 已绑定存在的知识库。 +- 已配置可用的 Chat 路由(`taskType=CHAT_SIMPLE` 或 `RAG_ANSWER`)。 + +### 5.2 RAG 对话前置条件 + +- 满足普通对话前置条件。 +- 知识库存在生效 `rag_store_model_config` 且已绑定 Embedding 模型。 +- 目标知识库至少有可用向量数据(`rag_chunk_embedding`)。 + +## 6. 常见失败提示 + +- `Agent已停用,暂不支持对话`:需启用 Agent。 +- `当前知识库未配置Embedding模型,无法执行检索对话`:需先配置知识库 Embedding 模型。 +- `未召回到可用知识切片,请先完成知识库切片与向量化`:需补齐切片向量化流程。 + +## 7. 相关代码入口 + +- `src/main/java/com/bruce/agent/controller/AgentDefinitionController.java` +- `src/main/java/com/bruce/agent/service/impl/AgentDefinitionServiceImpl.java` +- `src/main/java/com/bruce/agent/entity/AgentDefinition.java` +- `src/main/java/com/bruce/modelprovider/gateway/ChatModelGatewayImpl.java` +- `frontend/src/api/agent.ts` +- `frontend/src/pages/agent/AgentManagePage.vue` +- `frontend/src/pages/agent/AgentDebugPage.vue` diff --git a/frontend/src/api/__tests__/agent.spec.ts b/frontend/src/api/__tests__/agent.spec.ts new file mode 100644 index 0000000..f7e0b0d --- /dev/null +++ b/frontend/src/api/__tests__/agent.spec.ts @@ -0,0 +1,43 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { + chatWithAgent, + deleteAgent, + getAgentById, + listAgents, + queryAgents, + saveAgent, +} from '../agent'; +import { get, post } from '../request'; + +vi.mock('../request', () => ({ + get: vi.fn(), + post: vi.fn(), +})); + +describe('agent api', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('maps agent endpoints correctly', () => { + listAgents(); + queryAgents({ agentCode: 'demo' }); + getAgentById('1001'); + saveAgent({ agentCode: 'agent_1', agentName: 'Agent 1', storeId: '2001', status: 'ENABLED' }); + deleteAgent('1001'); + chatWithAgent('1001', { messages: [{ role: 'user', content: '你好' }] }); + + expect(post).toHaveBeenCalledWith('/agents/list'); + expect(post).toHaveBeenCalledWith('/agents/query', { agentCode: 'demo' }); + expect(get).toHaveBeenCalledWith('/agents/detail', { params: { id: '1001' } }); + expect(post).toHaveBeenCalledWith('/agents/save', { + agentCode: 'agent_1', + agentName: 'Agent 1', + storeId: '2001', + status: 'ENABLED', + }); + expect(post).toHaveBeenCalledWith('/agents/delete', undefined, { params: { id: '1001' } }); + expect(post).toHaveBeenCalledWith('/agents/1001/chat', { messages: [{ role: 'user', content: '你好' }] }); + }); +}); diff --git a/frontend/src/api/agent.ts b/frontend/src/api/agent.ts new file mode 100644 index 0000000..23ec587 --- /dev/null +++ b/frontend/src/api/agent.ts @@ -0,0 +1,70 @@ +import { get, post } from './request'; + +export interface AgentDefinition { + id?: string; + agentCode: string; + agentName: string; + systemPrompt?: string; + storeId: string; + status: string; + remark?: string; +} + +export interface AgentDefinitionQueryRequest { + agentCode?: string; + agentName?: string; + status?: string; + storeId?: string; +} + +export interface AgentMessage { + role: 'system' | 'user' | 'assistant'; + content: string; +} + +export interface AgentChatRequest { + messages: AgentMessage[]; + ragEnabled?: boolean; +} + +export interface AgentReferenceChunk { + chunkId: string; + documentId: string; + chunkContent: string; + score?: number; +} + +export interface AgentChatResponse { + agentId: string; + agentCode: string; + agentName: string; + storeId: string; + storeName?: string; + answer: string; + modelRequestId: string; + references: AgentReferenceChunk[]; +} + +export function listAgents() { + return post('/agents/list'); +} + +export function queryAgents(query?: AgentDefinitionQueryRequest) { + return post('/agents/query', query); +} + +export function getAgentById(id: string) { + return get('/agents/detail', { params: { id } }); +} + +export function saveAgent(data: Partial & { id?: string }) { + return post('/agents/save', data); +} + +export function deleteAgent(id: string) { + return post('/agents/delete', undefined, { params: { id } }); +} + +export function chatWithAgent(agentId: string, data: AgentChatRequest) { + return post(`/agents/${agentId}/chat`, data); +} diff --git a/frontend/src/pages/agent/AgentDebugPage.vue b/frontend/src/pages/agent/AgentDebugPage.vue new file mode 100644 index 0000000..39ad817 --- /dev/null +++ b/frontend/src/pages/agent/AgentDebugPage.vue @@ -0,0 +1,270 @@ + + + + + diff --git a/frontend/src/pages/agent/AgentManagePage.vue b/frontend/src/pages/agent/AgentManagePage.vue new file mode 100644 index 0000000..ac5bd34 --- /dev/null +++ b/frontend/src/pages/agent/AgentManagePage.vue @@ -0,0 +1,195 @@ + + + + + diff --git a/script/sql/agent_definition.sql b/script/sql/agent_definition.sql new file mode 100644 index 0000000..31d06aa --- /dev/null +++ b/script/sql/agent_definition.sql @@ -0,0 +1,35 @@ +DROP TABLE IF EXISTS agent_definition; + +CREATE TABLE agent_definition ( + id BIGSERIAL PRIMARY KEY, + agent_code VARCHAR(100) NOT NULL, + agent_name VARCHAR(200) NOT NULL, + system_prompt TEXT, + store_id BIGINT NOT NULL, + status VARCHAR(50) NOT NULL DEFAULT 'ENABLED', + version INTEGER NOT NULL DEFAULT 1, + create_time TIMESTAMP, + update_time TIMESTAMP, + remark VARCHAR(500) DEFAULT '', + create_by VARCHAR(64), + update_by VARCHAR(64), + CONSTRAINT uk_agent_definition_code UNIQUE (agent_code), + CONSTRAINT fk_agent_definition_store_id FOREIGN KEY (store_id) REFERENCES rag_store (id) +); + +CREATE INDEX idx_agent_definition_store_id ON agent_definition (store_id); +CREATE INDEX idx_agent_definition_status ON agent_definition (status); + +COMMENT ON TABLE agent_definition IS 'Agent定义表'; +COMMENT ON COLUMN agent_definition.id IS 'ID'; +COMMENT ON COLUMN agent_definition.agent_code IS 'Agent编码'; +COMMENT ON COLUMN agent_definition.agent_name IS 'Agent名称'; +COMMENT ON COLUMN agent_definition.system_prompt IS '系统提示词'; +COMMENT ON COLUMN agent_definition.store_id IS '绑定知识库ID'; +COMMENT ON COLUMN agent_definition.status IS '状态'; +COMMENT ON COLUMN agent_definition.version IS '版本'; +COMMENT ON COLUMN agent_definition.create_time IS '创建时间'; +COMMENT ON COLUMN agent_definition.update_time IS '更新时间'; +COMMENT ON COLUMN agent_definition.remark IS '备注'; +COMMENT ON COLUMN agent_definition.create_by IS '创建者'; +COMMENT ON COLUMN agent_definition.update_by IS '更新者'; diff --git a/script/sql/model_call_log_patch.sql b/script/sql/model_call_log_patch.sql new file mode 100644 index 0000000..b7a0c1b --- /dev/null +++ b/script/sql/model_call_log_patch.sql @@ -0,0 +1,20 @@ +-- model_call_log 补丁脚本 +-- 目的:对齐 BaseEntity 字段,避免 MyBatis 查询 create_by / update_by / update_time / version 报错 + +ALTER TABLE model_call_log + ADD COLUMN IF NOT EXISTS create_by VARCHAR(64); + +ALTER TABLE model_call_log + ADD COLUMN IF NOT EXISTS update_by VARCHAR(64); + +ALTER TABLE model_call_log + ADD COLUMN IF NOT EXISTS update_time TIMESTAMP; + +ALTER TABLE model_call_log + ADD COLUMN IF NOT EXISTS version INTEGER NOT NULL DEFAULT 1; + +COMMENT ON COLUMN model_call_log.create_by IS '创建者'; +COMMENT ON COLUMN model_call_log.update_by IS '更新者'; +COMMENT ON COLUMN model_call_log.update_time IS '更新时间'; +COMMENT ON COLUMN model_call_log.version IS '版本'; + diff --git a/src/main/java/com/bruce/agent/controller/AgentDefinitionController.java b/src/main/java/com/bruce/agent/controller/AgentDefinitionController.java new file mode 100644 index 0000000..4e93de6 --- /dev/null +++ b/src/main/java/com/bruce/agent/controller/AgentDefinitionController.java @@ -0,0 +1,58 @@ +package com.bruce.agent.controller; + +import com.bruce.agent.dto.request.AgentChatRequest; +import com.bruce.agent.dto.request.AgentDefinitionQueryRequest; +import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; +import com.bruce.agent.dto.response.AgentChatResponse; +import com.bruce.agent.dto.response.AgentDefinitionResponse; +import com.bruce.agent.service.IAgentDefinitionService; +import com.bruce.common.domain.model.RequestResult; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +import java.util.List; + +@RestController +@RequestMapping("/api/agents") +@RequiredArgsConstructor +public class AgentDefinitionController { + + private final IAgentDefinitionService agentDefinitionService; + + @PostMapping("/list") + public RequestResult> list() { + return RequestResult.success(agentDefinitionService.listResponses()); + } + + @PostMapping("/query") + public RequestResult> query(@RequestBody(required = false) AgentDefinitionQueryRequest request) { + return RequestResult.success(agentDefinitionService.query(request)); + } + + @GetMapping("/detail") + public RequestResult detail(@RequestParam("id") Long id) { + return RequestResult.success(agentDefinitionService.getResponseById(id)); + } + + @PostMapping("/save") + public RequestResult save(@RequestBody AgentDefinitionSaveRequest request) { + return RequestResult.success(agentDefinitionService.saveOrUpdate(request)); + } + + @PostMapping("/delete") + public RequestResult delete(@RequestParam("id") Long id) { + return RequestResult.success(agentDefinitionService.removeById(id)); + } + + @PostMapping("/{agentId}/chat") + public RequestResult chat(@PathVariable("agentId") Long agentId, + @RequestBody AgentChatRequest request) { + return RequestResult.success(agentDefinitionService.chat(agentId, request)); + } +} diff --git a/src/main/java/com/bruce/agent/dto/request/AgentChatRequest.java b/src/main/java/com/bruce/agent/dto/request/AgentChatRequest.java new file mode 100644 index 0000000..3e50e94 --- /dev/null +++ b/src/main/java/com/bruce/agent/dto/request/AgentChatRequest.java @@ -0,0 +1,17 @@ +package com.bruce.agent.dto.request; + +import lombok.Data; + +import java.util.List; + +@Data +public class AgentChatRequest { + private List messages; + private Boolean ragEnabled; + + @Data + public static class AgentMessage { + private String role; + private String content; + } +} diff --git a/src/main/java/com/bruce/agent/dto/request/AgentDefinitionQueryRequest.java b/src/main/java/com/bruce/agent/dto/request/AgentDefinitionQueryRequest.java new file mode 100644 index 0000000..cc5cfd9 --- /dev/null +++ b/src/main/java/com/bruce/agent/dto/request/AgentDefinitionQueryRequest.java @@ -0,0 +1,11 @@ +package com.bruce.agent.dto.request; + +import lombok.Data; + +@Data +public class AgentDefinitionQueryRequest { + private String agentCode; + private String agentName; + private String status; + private Long storeId; +} diff --git a/src/main/java/com/bruce/agent/dto/request/AgentDefinitionSaveRequest.java b/src/main/java/com/bruce/agent/dto/request/AgentDefinitionSaveRequest.java new file mode 100644 index 0000000..e439b90 --- /dev/null +++ b/src/main/java/com/bruce/agent/dto/request/AgentDefinitionSaveRequest.java @@ -0,0 +1,14 @@ +package com.bruce.agent.dto.request; + +import lombok.Data; + +@Data +public class AgentDefinitionSaveRequest { + private Long id; + private String agentCode; + private String agentName; + private String systemPrompt; + private Long storeId; + private String status; + private String remark; +} diff --git a/src/main/java/com/bruce/agent/dto/response/AgentChatResponse.java b/src/main/java/com/bruce/agent/dto/response/AgentChatResponse.java new file mode 100644 index 0000000..56f717d --- /dev/null +++ b/src/main/java/com/bruce/agent/dto/response/AgentChatResponse.java @@ -0,0 +1,25 @@ +package com.bruce.agent.dto.response; + +import lombok.Data; + +import java.util.List; + +@Data +public class AgentChatResponse { + private Long agentId; + private String agentCode; + private String agentName; + private Long storeId; + private String storeName; + private String answer; + private String modelRequestId; + private List references; + + @Data + public static class ReferenceChunk { + private Long chunkId; + private Long documentId; + private String chunkContent; + private Double score; + } +} diff --git a/src/main/java/com/bruce/agent/dto/response/AgentDefinitionResponse.java b/src/main/java/com/bruce/agent/dto/response/AgentDefinitionResponse.java new file mode 100644 index 0000000..e9896d0 --- /dev/null +++ b/src/main/java/com/bruce/agent/dto/response/AgentDefinitionResponse.java @@ -0,0 +1,25 @@ +package com.bruce.agent.dto.response; + +import com.bruce.agent.entity.AgentDefinition; +import lombok.Data; +import org.springframework.beans.BeanUtils; + +@Data +public class AgentDefinitionResponse { + private Long id; + private String agentCode; + private String agentName; + private String systemPrompt; + private Long storeId; + private String status; + private String remark; + + public static AgentDefinitionResponse fromEntity(AgentDefinition entity) { + if (entity == null) { + return null; + } + AgentDefinitionResponse response = new AgentDefinitionResponse(); + BeanUtils.copyProperties(entity, response); + return response; + } +} diff --git a/src/main/java/com/bruce/agent/entity/AgentDefinition.java b/src/main/java/com/bruce/agent/entity/AgentDefinition.java new file mode 100644 index 0000000..14cf889 --- /dev/null +++ b/src/main/java/com/bruce/agent/entity/AgentDefinition.java @@ -0,0 +1,29 @@ +package com.bruce.agent.entity; + +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableName; +import com.bruce.common.domain.model.BaseEntity; +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +@TableName("agent_definition") +public class AgentDefinition extends BaseEntity { + + @TableField("agent_code") + private String agentCode; + + @TableField("agent_name") + private String agentName; + + @TableField("system_prompt") + private String systemPrompt; + + @TableField("store_id") + private Long storeId; + + private String status; + + private String remark; +} diff --git a/src/main/java/com/bruce/agent/mapper/AgentDefinitionMapper.java b/src/main/java/com/bruce/agent/mapper/AgentDefinitionMapper.java new file mode 100644 index 0000000..dfa3100 --- /dev/null +++ b/src/main/java/com/bruce/agent/mapper/AgentDefinitionMapper.java @@ -0,0 +1,9 @@ +package com.bruce.agent.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.bruce.agent.entity.AgentDefinition; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface AgentDefinitionMapper extends BaseMapper { +} diff --git a/src/main/java/com/bruce/agent/service/IAgentDefinitionService.java b/src/main/java/com/bruce/agent/service/IAgentDefinitionService.java new file mode 100644 index 0000000..f96a48e --- /dev/null +++ b/src/main/java/com/bruce/agent/service/IAgentDefinitionService.java @@ -0,0 +1,23 @@ +package com.bruce.agent.service; + +import com.baomidou.mybatisplus.extension.service.IService; +import com.bruce.agent.dto.request.AgentChatRequest; +import com.bruce.agent.dto.request.AgentDefinitionQueryRequest; +import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; +import com.bruce.agent.dto.response.AgentChatResponse; +import com.bruce.agent.dto.response.AgentDefinitionResponse; +import com.bruce.agent.entity.AgentDefinition; + +import java.util.List; + +public interface IAgentDefinitionService extends IService { + List listResponses(); + + List query(AgentDefinitionQueryRequest request); + + AgentDefinitionResponse getResponseById(Long id); + + boolean saveOrUpdate(AgentDefinitionSaveRequest request); + + AgentChatResponse chat(Long agentId, AgentChatRequest request); +} diff --git a/src/main/java/com/bruce/agent/service/impl/AgentDefinitionServiceImpl.java b/src/main/java/com/bruce/agent/service/impl/AgentDefinitionServiceImpl.java new file mode 100644 index 0000000..7e454d5 --- /dev/null +++ b/src/main/java/com/bruce/agent/service/impl/AgentDefinitionServiceImpl.java @@ -0,0 +1,304 @@ +package com.bruce.agent.service.impl; + +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.bruce.agent.dto.request.AgentChatRequest; +import com.bruce.agent.dto.request.AgentDefinitionQueryRequest; +import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; +import com.bruce.agent.dto.response.AgentChatResponse; +import com.bruce.agent.dto.response.AgentDefinitionResponse; +import com.bruce.agent.entity.AgentDefinition; +import com.bruce.agent.mapper.AgentDefinitionMapper; +import com.bruce.agent.service.IAgentDefinitionService; +import com.bruce.common.enums.EnableStatusEnum; +import com.bruce.modelprovider.client.OpenAiChatMessage; +import com.bruce.modelprovider.entity.RagStoreModelConfig; +import com.bruce.modelprovider.gateway.ChatModelGateway; +import com.bruce.modelprovider.gateway.ChatRequest; +import com.bruce.modelprovider.gateway.ChatResult; +import com.bruce.modelprovider.gateway.EmbeddingModelGateway; +import com.bruce.modelprovider.gateway.EmbeddingRequest; +import com.bruce.modelprovider.gateway.EmbeddingResult; +import com.bruce.modelprovider.service.IRagStoreModelConfigService; +import com.bruce.rag.dto.response.RagChunkRecallResponse; +import com.bruce.rag.entity.RagStore; +import com.bruce.rag.mapper.RagChunkEmbeddingMapper; +import com.bruce.rag.service.IRagStoreService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Service +@RequiredArgsConstructor +public class AgentDefinitionServiceImpl extends ServiceImpl + implements IAgentDefinitionService { + + private static final int DEFAULT_TOP_K = 5; + + private final IRagStoreService ragStoreService; + private final IRagStoreModelConfigService ragStoreModelConfigService; + private final RagChunkEmbeddingMapper ragChunkEmbeddingMapper; + private final EmbeddingModelGateway embeddingModelGateway; + private final ChatModelGateway chatModelGateway; + + @Override + public List listResponses() { + return lambdaQuery() + .orderByAsc(AgentDefinition::getAgentCode) + .list() + .stream() + .map(AgentDefinitionResponse::fromEntity) + .toList(); + } + + @Override + public List query(AgentDefinitionQueryRequest request) { + AgentDefinitionQueryRequest queryRequest = request == null ? new AgentDefinitionQueryRequest() : request; + return lambdaQuery() + .eq(StringUtils.hasText(queryRequest.getAgentCode()), AgentDefinition::getAgentCode, trimToNull(queryRequest.getAgentCode())) + .like(StringUtils.hasText(queryRequest.getAgentName()), AgentDefinition::getAgentName, trimToNull(queryRequest.getAgentName())) + .eq(StringUtils.hasText(queryRequest.getStatus()), AgentDefinition::getStatus, trimToNull(queryRequest.getStatus())) + .eq(queryRequest.getStoreId() != null, AgentDefinition::getStoreId, queryRequest.getStoreId()) + .orderByAsc(AgentDefinition::getAgentCode) + .list() + .stream() + .map(AgentDefinitionResponse::fromEntity) + .toList(); + } + + @Override + public AgentDefinitionResponse getResponseById(Long id) { + return AgentDefinitionResponse.fromEntity(getById(id)); + } + + @Override + public boolean saveOrUpdate(AgentDefinitionSaveRequest request) { + validateSaveRequest(request); + if (ragStoreService.getById(request.getStoreId()) == null) { + throw new IllegalArgumentException("绑定知识库不存在,ID: " + request.getStoreId()); + } + AgentDefinition duplicate = lambdaQuery() + .eq(AgentDefinition::getAgentCode, request.getAgentCode().trim()) + .ne(request.getId() != null, AgentDefinition::getId, request.getId()) + .one(); + if (duplicate != null) { + throw new IllegalArgumentException("Agent编码已存在: " + request.getAgentCode().trim()); + } + AgentDefinition entity = request.getId() == null ? new AgentDefinition() : getById(request.getId()); + if (entity == null) { + throw new IllegalArgumentException("Agent不存在,ID: " + request.getId()); + } + entity.setAgentCode(request.getAgentCode().trim()); + entity.setAgentName(request.getAgentName().trim()); + entity.setSystemPrompt(trimToNull(request.getSystemPrompt())); + entity.setStoreId(request.getStoreId()); + entity.setStatus(StringUtils.hasText(request.getStatus()) + ? request.getStatus().trim() + : EnableStatusEnum.ENABLED.name()); + entity.setRemark(trimToNull(request.getRemark())); + return request.getId() == null ? save(entity) : updateById(entity); + } + + @Override + public AgentChatResponse chat(Long agentId, AgentChatRequest request) { + if (agentId == null) { + throw new IllegalArgumentException("Agent ID不能为空"); + } + if (request == null || request.getMessages() == null || request.getMessages().isEmpty()) { + throw new IllegalArgumentException("对话消息不能为空"); + } + AgentDefinition agent = getById(agentId); + if (agent == null) { + throw new IllegalArgumentException("Agent不存在,ID: " + agentId); + } + if (!EnableStatusEnum.ENABLED.name().equals(agent.getStatus())) { + throw new IllegalArgumentException("Agent已停用,暂不支持对话"); + } + if (agent.getStoreId() == null) { + throw new IllegalArgumentException("Agent未绑定知识库,请先保存知识库配置"); + } + RagStore store = ragStoreService.getById(agent.getStoreId()); + if (store == null) { + throw new IllegalArgumentException("绑定知识库不存在,ID: " + agent.getStoreId()); + } + + String queryText = resolveLatestUserMessage(request.getMessages()); + boolean ragEnabled = request.getRagEnabled() == null || request.getRagEnabled(); + List recalls = List.of(); + if (ragEnabled) { + RagStoreModelConfig storeModelConfig = ragStoreModelConfigService.getActiveEntity(agent.getStoreId()); + if (storeModelConfig == null || storeModelConfig.getEmbeddingModelId() == null) { + throw new IllegalArgumentException("当前知识库未配置Embedding模型,无法执行检索对话"); + } + EmbeddingRequest embeddingRequest = new EmbeddingRequest(); + embeddingRequest.setTexts(List.of(queryText)); + embeddingRequest.setTaskType("RAG_QUERY_EMBEDDING"); + embeddingRequest.setMatchScope("RAG_STORE"); + embeddingRequest.setScopeId(agent.getStoreId()); + embeddingRequest.setBizType("AGENT_CHAT"); + embeddingRequest.setBizId(String.valueOf(agentId)); + embeddingRequest.setExpectedDimension(storeModelConfig.getEmbeddingDimension()); + EmbeddingResult queryEmbedding = embeddingModelGateway.embed(embeddingRequest); + if (queryEmbedding.getVectors() == null || queryEmbedding.getVectors().isEmpty()) { + throw new IllegalArgumentException("查询向量生成失败,请检查Embedding模型配置"); + } + + String queryVector = toVectorLiteral(queryEmbedding.getVectors().getFirst()); + recalls = ragChunkEmbeddingMapper.queryTopKByStore( + agent.getStoreId(), + queryVector, + DEFAULT_TOP_K + ); + if (recalls.isEmpty()) { + throw new IllegalArgumentException("未召回到可用知识切片,请先完成知识库切片与向量化"); + } + } + + ChatRequest chatRequest = new ChatRequest(); + chatRequest.setTaskType(ragEnabled ? "RAG_ANSWER" : "CHAT_SIMPLE"); + chatRequest.setMatchScope("AGENT"); + chatRequest.setScopeId(agentId); + chatRequest.setBizType("AGENT_CHAT"); + chatRequest.setBizId(String.valueOf(agentId)); + chatRequest.setMessages(buildChatMessages(agent, recalls, request.getMessages(), ragEnabled)); + + ChatResult chatResult = chatModelGateway.chat(chatRequest); + AgentChatResponse response = new AgentChatResponse(); + response.setAgentId(agent.getId()); + response.setAgentCode(agent.getAgentCode()); + response.setAgentName(agent.getAgentName()); + response.setStoreId(agent.getStoreId()); + response.setStoreName(store.getStoreName()); + response.setAnswer(chatResult.getContent()); + response.setModelRequestId(chatResult.getCallLog().getRequestId()); + response.setReferences(toReferenceChunks(recalls)); + return response; + } + + private void validateSaveRequest(AgentDefinitionSaveRequest request) { + if (request == null) { + throw new IllegalArgumentException("保存请求不能为空"); + } + if (!StringUtils.hasText(request.getAgentCode())) { + throw new IllegalArgumentException("Agent编码不能为空"); + } + if (!StringUtils.hasText(request.getAgentName())) { + throw new IllegalArgumentException("Agent名称不能为空"); + } + if (request.getStoreId() == null) { + throw new IllegalArgumentException("绑定知识库不能为空"); + } + } + + private String resolveLatestUserMessage(List messages) { + for (int index = messages.size() - 1; index >= 0; index--) { + AgentChatRequest.AgentMessage message = messages.get(index); + if (message != null + && "user".equalsIgnoreCase(message.getRole()) + && StringUtils.hasText(message.getContent())) { + return message.getContent(); + } + } + throw new IllegalArgumentException("缺少用户提问内容"); + } + + private List buildChatMessages(AgentDefinition agent, + List recalls, + List rawMessages, + boolean ragEnabled) { + List messages = new ArrayList<>(); + OpenAiChatMessage instructionMessage = new OpenAiChatMessage(); + instructionMessage.setRole("system"); + instructionMessage.setContent(buildSystemInstruction(agent)); + messages.add(instructionMessage); + + if (ragEnabled) { + OpenAiChatMessage contextMessage = new OpenAiChatMessage(); + contextMessage.setRole("system"); + contextMessage.setContent(buildContextText(recalls)); + messages.add(contextMessage); + } + + for (AgentChatRequest.AgentMessage rawMessage : rawMessages) { + if (rawMessage == null || !StringUtils.hasText(rawMessage.getContent())) { + continue; + } + OpenAiChatMessage message = new OpenAiChatMessage(); + message.setRole(normalizeRole(rawMessage.getRole())); + message.setContent(rawMessage.getContent()); + messages.add(message); + } + return messages; + } + + private String buildSystemInstruction(AgentDefinition agent) { + StringBuilder builder = new StringBuilder(); + if (StringUtils.hasText(agent.getSystemPrompt())) { + builder.append(agent.getSystemPrompt().trim()).append("\n\n"); + } + builder.append("请优先基于已给出的知识库引用片段回答。"); + builder.append("如果引用无法支持结论,请明确告知“知识库中暂无直接依据”。"); + return builder.toString(); + } + + private String buildContextText(List recalls) { + StringBuilder builder = new StringBuilder("以下是知识库召回片段:\n"); + for (int i = 0; i < recalls.size(); i++) { + RagChunkRecallResponse recall = recalls.get(i); + builder.append(i + 1) + .append(". [chunkId=") + .append(recall.getChunkId()) + .append(", score=") + .append(String.format("%.4f", recall.getScore() == null ? 0D : recall.getScore())) + .append("] ") + .append(recall.getChunkContent()) + .append("\n"); + } + return builder.toString(); + } + + private List toReferenceChunks(List recalls) { + return recalls.stream().map(recall -> { + AgentChatResponse.ReferenceChunk chunk = new AgentChatResponse.ReferenceChunk(); + chunk.setChunkId(recall.getChunkId()); + chunk.setDocumentId(recall.getDocumentId()); + chunk.setChunkContent(recall.getChunkContent()); + chunk.setScore(recall.getScore()); + return chunk; + }).toList(); + } + + private String normalizeRole(String role) { + if (!StringUtils.hasText(role)) { + return "user"; + } + String normalized = role.trim().toLowerCase(); + if ("system".equals(normalized) || "assistant".equals(normalized) || "user".equals(normalized)) { + return normalized; + } + return "user"; + } + + private String toVectorLiteral(List vector) { + StringBuilder builder = new StringBuilder("["); + for (int index = 0; index < vector.size(); index++) { + if (index > 0) { + builder.append(','); + } + builder.append(vector.get(index)); + } + builder.append(']'); + return builder.toString(); + } + + private String trimToNull(String value) { + if (!StringUtils.hasText(value)) { + return null; + } + return value.trim(); + } +} diff --git a/src/main/java/com/bruce/modelprovider/client/OpenAiChatCompletionResult.java b/src/main/java/com/bruce/modelprovider/client/OpenAiChatCompletionResult.java new file mode 100644 index 0000000..bc4c239 --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/client/OpenAiChatCompletionResult.java @@ -0,0 +1,12 @@ +package com.bruce.modelprovider.client; + +import lombok.Data; + +@Data +public class OpenAiChatCompletionResult { + private String upstreamRequestId; + private String content; + private Integer promptTokens; + private Integer completionTokens; + private Integer totalTokens; +} diff --git a/src/main/java/com/bruce/modelprovider/client/OpenAiChatMessage.java b/src/main/java/com/bruce/modelprovider/client/OpenAiChatMessage.java new file mode 100644 index 0000000..6dc26a2 --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/client/OpenAiChatMessage.java @@ -0,0 +1,9 @@ +package com.bruce.modelprovider.client; + +import lombok.Data; + +@Data +public class OpenAiChatMessage { + private String role; + private String content; +} diff --git a/src/main/java/com/bruce/modelprovider/config/AiConfigFilePropertySourceConfig.java b/src/main/java/com/bruce/modelprovider/config/AiConfigFilePropertySourceConfig.java new file mode 100644 index 0000000..af5d2dc --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/config/AiConfigFilePropertySourceConfig.java @@ -0,0 +1,18 @@ +package com.bruce.modelprovider.config; + +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.PropertySource; + +/** + * 加载独立 AI 配置文件。 + *

+ * 说明: + * 1. 该文件使用 INI 扩展名,但内容采用 key=value 形式,Spring 可直接按 Properties 解析; + * 2. ignoreResourceNotFound=true,允许某些环境不提供该文件,避免启动失败; + * 3. 具体键值由 {@link AiSecretProperties} 统一绑定与读取。 + */ +@Configuration +@PropertySource(value = "classpath:ai-config.ini", ignoreResourceNotFound = true) +public class AiConfigFilePropertySourceConfig { +} + diff --git a/src/main/java/com/bruce/modelprovider/config/AiSecretProperties.java b/src/main/java/com/bruce/modelprovider/config/AiSecretProperties.java new file mode 100644 index 0000000..d6a6876 --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/config/AiSecretProperties.java @@ -0,0 +1,41 @@ +package com.bruce.modelprovider.config; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +import java.util.HashMap; +import java.util.Map; + +/** + * AI 密钥配置绑定。 + *

+ * 支持从 ai-config.ini 读取如下配置: + * ai.secret-refs[SILICONFLOW_API_KEY]=your-key + */ +@Data +@Component +@ConfigurationProperties(prefix = "ai") +public class AiSecretProperties { + + /** + * key 为 secretRef(例如 SILICONFLOW_API_KEY),value 为实际密钥。 + */ + private Map secretRefs = new HashMap<>(); + + /** + * 根据 secretRef 获取配置文件中的密钥,并做空白清理。 + */ + public String getApiKeyBySecretRef(String secretRef) { + if (!StringUtils.hasText(secretRef)) { + return null; + } + String value = secretRefs.get(secretRef.trim()); + if (!StringUtils.hasText(value)) { + return null; + } + return value.trim(); + } +} + diff --git a/src/main/java/com/bruce/modelprovider/gateway/ChatModelGateway.java b/src/main/java/com/bruce/modelprovider/gateway/ChatModelGateway.java new file mode 100644 index 0000000..d26d0bc --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/gateway/ChatModelGateway.java @@ -0,0 +1,5 @@ +package com.bruce.modelprovider.gateway; + +public interface ChatModelGateway { + ChatResult chat(ChatRequest request); +} diff --git a/src/main/java/com/bruce/modelprovider/gateway/ChatModelGatewayImpl.java b/src/main/java/com/bruce/modelprovider/gateway/ChatModelGatewayImpl.java new file mode 100644 index 0000000..cb70f8a --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/gateway/ChatModelGatewayImpl.java @@ -0,0 +1,135 @@ +package com.bruce.modelprovider.gateway; + +import com.bruce.modelprovider.client.OpenAiChatCompletionResult; +import com.bruce.modelprovider.client.OpenAiChatMessage; +import com.bruce.modelprovider.client.OpenAiCompatibleModelClient; +import com.bruce.modelprovider.entity.ModelCallLog; +import com.bruce.modelprovider.entity.ModelConfig; +import com.bruce.modelprovider.entity.ModelProvider; +import com.bruce.modelprovider.enums.ModelCallStatusEnum; +import com.bruce.modelprovider.route.ModelRouteContext; +import com.bruce.modelprovider.route.ModelRouteDecision; +import com.bruce.modelprovider.service.IModelCallLogService; +import com.bruce.modelprovider.service.IModelProviderService; +import com.bruce.modelprovider.service.IModelRouteService; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; +import org.springframework.util.DigestUtils; +import org.springframework.util.StringUtils; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.UUID; + +@Component +@RequiredArgsConstructor +public class ChatModelGatewayImpl implements ChatModelGateway { + + private final IModelRouteService modelRouteService; + private final IModelProviderService modelProviderService; + private final IModelCallLogService modelCallLogService; + private final OpenAiCompatibleModelClient openAiCompatibleModelClient; + + @Override + public ChatResult chat(ChatRequest request) { + if (request == null || request.getMessages() == null || request.getMessages().isEmpty()) { + throw new IllegalArgumentException("聊天请求不能为空"); + } + long start = System.currentTimeMillis(); + ModelCallLog callLog = new ModelCallLog(); + callLog.setRequestId(UUID.randomUUID().toString().replace("-", "")); + callLog.setTaskType(request.getTaskType()); + callLog.setBizType(request.getBizType()); + callLog.setBizId(request.getBizId()); + callLog.setCallType("CHAT"); + callLog.setRequestHash(buildRequestHash(request.getMessages())); + try { + ModelRouteContext routeContext = new ModelRouteContext(); + routeContext.setTaskType(request.getTaskType()); + routeContext.setMatchScope(request.getMatchScope()); + routeContext.setScopeId(request.getScopeId()); + routeContext.setRequiredModelType("CHAT"); + routeContext.setBizType(request.getBizType()); + routeContext.setBizId(request.getBizId()); + ModelRouteDecision decision = modelRouteService.route(routeContext); + + ModelCallExecution execution = executeWithFallback( + decision.getPrimaryModel(), + decision.getFallbackModels(), + request.getMessages() + ); + + callLog.setProviderId(execution.provider().getId()); + callLog.setModelId(execution.model().getId()); + callLog.setStatus(ModelCallStatusEnum.SUCCESS.name()); + callLog.setPromptTokens(execution.result().getPromptTokens()); + callLog.setCompletionTokens(execution.result().getCompletionTokens()); + callLog.setTotalTokens(execution.result().getTotalTokens()); + callLog.setDurationMs((int) (System.currentTimeMillis() - start)); + modelCallLogService.save(callLog); + + ChatResult result = new ChatResult(); + result.setModelId(execution.model().getId()); + result.setModelName(execution.model().getModelName()); + result.setContent(execution.result().getContent()); + result.setUpstreamRequestId(execution.result().getUpstreamRequestId()); + result.setPromptTokens(execution.result().getPromptTokens()); + result.setCompletionTokens(execution.result().getCompletionTokens()); + result.setTotalTokens(execution.result().getTotalTokens()); + result.setCallLog(callLog); + return result; + } catch (Exception ex) { + callLog.setStatus(ModelCallStatusEnum.FAILED.name()); + callLog.setDurationMs((int) (System.currentTimeMillis() - start)); + callLog.setErrorCode("CHAT_COMPLETION_FAILED"); + String message = ex.getMessage(); + callLog.setErrorMessage(message == null ? "unknown" : message.substring(0, Math.min(message.length(), 1000))); + modelCallLogService.save(callLog); + throw ex; + } + } + + private ModelCallExecution executeWithFallback(ModelConfig primaryModel, + List fallbackModels, + List messages) { + ModelProvider primaryProvider = requireAvailableProvider(primaryModel.getProviderId()); + try { + OpenAiChatCompletionResult result = openAiCompatibleModelClient.chatCompletions(primaryProvider, primaryModel, messages); + return new ModelCallExecution(primaryProvider, primaryModel, result); + } catch (Exception primaryEx) { + for (ModelConfig fallbackModel : fallbackModels) { + try { + ModelProvider fallbackProvider = requireAvailableProvider(fallbackModel.getProviderId()); + OpenAiChatCompletionResult result = openAiCompatibleModelClient.chatCompletions( + fallbackProvider, + fallbackModel, + messages + ); + return new ModelCallExecution(fallbackProvider, fallbackModel, result); + } catch (Exception ignored) { + // continue fallback chain + } + } + throw primaryEx; + } + } + + private ModelProvider requireAvailableProvider(Long providerId) { + ModelProvider provider = modelProviderService.getById(providerId); + if (provider == null || !Boolean.TRUE.equals(provider.getEnabled())) { + throw new IllegalStateException("模型服务商不可用"); + } + return provider; + } + + private String buildRequestHash(List messages) { + String plainText = messages.stream() + .map(message -> (StringUtils.hasText(message.getRole()) ? message.getRole() : "user") + ":" + message.getContent()) + .reduce((left, right) -> left + "|" + right) + .orElse(""); + return DigestUtils.md5DigestAsHex(plainText.getBytes(StandardCharsets.UTF_8)); + } + + private record ModelCallExecution(ModelProvider provider, ModelConfig model, OpenAiChatCompletionResult result) { + } +} diff --git a/src/main/java/com/bruce/modelprovider/gateway/ChatRequest.java b/src/main/java/com/bruce/modelprovider/gateway/ChatRequest.java new file mode 100644 index 0000000..cc0cfbc --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/gateway/ChatRequest.java @@ -0,0 +1,16 @@ +package com.bruce.modelprovider.gateway; + +import com.bruce.modelprovider.client.OpenAiChatMessage; +import lombok.Data; + +import java.util.List; + +@Data +public class ChatRequest { + private List messages; + private String taskType; + private String matchScope; + private Long scopeId; + private String bizType; + private String bizId; +} diff --git a/src/main/java/com/bruce/modelprovider/gateway/ChatResult.java b/src/main/java/com/bruce/modelprovider/gateway/ChatResult.java new file mode 100644 index 0000000..864ce80 --- /dev/null +++ b/src/main/java/com/bruce/modelprovider/gateway/ChatResult.java @@ -0,0 +1,16 @@ +package com.bruce.modelprovider.gateway; + +import com.bruce.modelprovider.entity.ModelCallLog; +import lombok.Data; + +@Data +public class ChatResult { + private Long modelId; + private String modelName; + private String content; + private String upstreamRequestId; + private Integer promptTokens; + private Integer completionTokens; + private Integer totalTokens; + private ModelCallLog callLog; +} diff --git a/src/main/java/com/bruce/rag/dto/response/RagChunkRecallResponse.java b/src/main/java/com/bruce/rag/dto/response/RagChunkRecallResponse.java new file mode 100644 index 0000000..2119709 --- /dev/null +++ b/src/main/java/com/bruce/rag/dto/response/RagChunkRecallResponse.java @@ -0,0 +1,11 @@ +package com.bruce.rag.dto.response; + +import lombok.Data; + +@Data +public class RagChunkRecallResponse { + private Long chunkId; + private Long documentId; + private String chunkContent; + private Double score; +} diff --git a/src/main/resources/ai-config.ini b/src/main/resources/ai-config.ini new file mode 100644 index 0000000..5af2d33 --- /dev/null +++ b/src/main/resources/ai-config.ini @@ -0,0 +1,5 @@ +# AI 独立配置文件(建议仅本地/环境覆盖使用,不提交真实密钥) +# 格式:ai.secret-refs[]= + +ai.secret-refs[SILICONFLOW_API_KEY]=your-key + diff --git a/src/test/java/com/bruce/agent/AgentComponentStructureTests.java b/src/test/java/com/bruce/agent/AgentComponentStructureTests.java new file mode 100644 index 0000000..8f59254 --- /dev/null +++ b/src/test/java/com/bruce/agent/AgentComponentStructureTests.java @@ -0,0 +1,63 @@ +package com.bruce.agent; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.baomidou.mybatisplus.extension.service.IService; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.bruce.agent.controller.AgentDefinitionController; +import com.bruce.agent.dto.request.AgentChatRequest; +import com.bruce.agent.dto.request.AgentDefinitionQueryRequest; +import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; +import com.bruce.agent.dto.response.AgentChatResponse; +import com.bruce.agent.dto.response.AgentDefinitionResponse; +import com.bruce.agent.entity.AgentDefinition; +import com.bruce.agent.mapper.AgentDefinitionMapper; +import com.bruce.agent.service.IAgentDefinitionService; +import com.bruce.agent.service.impl.AgentDefinitionServiceImpl; +import com.bruce.common.domain.model.RequestResult; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class AgentComponentStructureTests { + + @Test + void agentComponentsShouldReuseMybatisPlusBaseTypes() { + assertTrue(BaseMapper.class.isAssignableFrom(AgentDefinitionMapper.class)); + assertTrue(IService.class.isAssignableFrom(IAgentDefinitionService.class)); + assertTrue(ServiceImpl.class.isAssignableFrom(AgentDefinitionServiceImpl.class)); + } + + @Test + void agentControllerShouldExposeRequestResultMethods() throws NoSuchMethodException { + Method listMethod = AgentDefinitionController.class.getMethod("list"); + Method queryMethod = AgentDefinitionController.class.getMethod("query", AgentDefinitionQueryRequest.class); + Method detailMethod = AgentDefinitionController.class.getMethod("detail", Long.class); + Method saveMethod = AgentDefinitionController.class.getMethod("save", AgentDefinitionSaveRequest.class); + Method deleteMethod = AgentDefinitionController.class.getMethod("delete", Long.class); + Method chatMethod = AgentDefinitionController.class.getMethod("chat", Long.class, AgentChatRequest.class); + + Method listServiceMethod = IAgentDefinitionService.class.getMethod("listResponses"); + Method queryServiceMethod = IAgentDefinitionService.class.getMethod("query", AgentDefinitionQueryRequest.class); + Method detailServiceMethod = IAgentDefinitionService.class.getMethod("getResponseById", Long.class); + Method saveServiceMethod = IAgentDefinitionService.class.getMethod("saveOrUpdate", AgentDefinitionSaveRequest.class); + Method chatServiceMethod = IAgentDefinitionService.class.getMethod("chat", Long.class, AgentChatRequest.class); + + assertEquals(RequestResult.class, listMethod.getReturnType()); + assertEquals(RequestResult.class, queryMethod.getReturnType()); + assertEquals(RequestResult.class, detailMethod.getReturnType()); + assertEquals(RequestResult.class, saveMethod.getReturnType()); + assertEquals(RequestResult.class, deleteMethod.getReturnType()); + assertEquals(RequestResult.class, chatMethod.getReturnType()); + + assertEquals(List.class, listServiceMethod.getReturnType()); + assertEquals(List.class, queryServiceMethod.getReturnType()); + assertEquals(AgentDefinitionResponse.class, detailServiceMethod.getReturnType()); + assertEquals(boolean.class, saveServiceMethod.getReturnType()); + assertEquals(AgentChatResponse.class, chatServiceMethod.getReturnType()); + assertEquals(AgentDefinitionResponse.class, AgentDefinitionResponse.class.getMethod("fromEntity", AgentDefinition.class).getReturnType()); + } +} diff --git a/src/test/java/com/bruce/agent/AgentDefinitionServiceImplTests.java b/src/test/java/com/bruce/agent/AgentDefinitionServiceImplTests.java new file mode 100644 index 0000000..da92180 --- /dev/null +++ b/src/test/java/com/bruce/agent/AgentDefinitionServiceImplTests.java @@ -0,0 +1,221 @@ +package com.bruce.agent; + +import com.bruce.agent.dto.request.AgentChatRequest; +import com.bruce.agent.dto.request.AgentDefinitionSaveRequest; +import com.bruce.agent.dto.response.AgentChatResponse; +import com.bruce.agent.entity.AgentDefinition; +import com.bruce.agent.service.impl.AgentDefinitionServiceImpl; +import com.bruce.modelprovider.entity.ModelCallLog; +import com.bruce.modelprovider.entity.RagStoreModelConfig; +import com.bruce.modelprovider.gateway.ChatRequest; +import com.bruce.modelprovider.gateway.ChatResult; +import com.bruce.modelprovider.gateway.EmbeddingRequest; +import com.bruce.modelprovider.gateway.EmbeddingResult; +import com.bruce.modelprovider.service.IRagStoreModelConfigService; +import com.bruce.rag.dto.response.RagChunkRecallResponse; +import com.bruce.rag.entity.RagStore; +import com.bruce.rag.mapper.RagChunkEmbeddingMapper; +import com.bruce.rag.service.IRagStoreService; +import com.bruce.modelprovider.gateway.ChatModelGateway; +import com.bruce.modelprovider.gateway.EmbeddingModelGateway; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class AgentDefinitionServiceImplTests { + + @Spy + @InjectMocks + private AgentDefinitionServiceImpl agentDefinitionService; + + @Mock + private IRagStoreService ragStoreService; + + @Mock + private IRagStoreModelConfigService ragStoreModelConfigService; + + @Mock + private RagChunkEmbeddingMapper ragChunkEmbeddingMapper; + + @Mock + private EmbeddingModelGateway embeddingModelGateway; + + @Mock + private ChatModelGateway chatModelGateway; + + @Test + void saveOrUpdateShouldValidateBoundStoreExists() { + AgentDefinitionSaveRequest request = new AgentDefinitionSaveRequest(); + request.setAgentCode("A_1"); + request.setAgentName("Agent 1"); + request.setStoreId(1001L); + + when(ragStoreService.getById(1001L)).thenReturn(null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.saveOrUpdate(request)); + assertTrue(exception.getMessage().contains("绑定知识库不存在")); + } + + @Test + void chatShouldRejectDisabledAgent() { + AgentDefinition agent = new AgentDefinition(); + agent.setId(1001L); + agent.setStoreId(2001L); + agent.setStatus("DISABLED"); + doReturn(agent).when(agentDefinitionService).getById(1001L); + + AgentChatRequest request = new AgentChatRequest(); + AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); + message.setRole("user"); + message.setContent("你好"); + request.setMessages(List.of(message)); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request)); + assertTrue(exception.getMessage().contains("停用")); + } + + @Test + void chatShouldRejectAgentWithoutStore() { + AgentDefinition agent = new AgentDefinition(); + agent.setId(1001L); + agent.setStatus("ENABLED"); + agent.setStoreId(null); + doReturn(agent).when(agentDefinitionService).getById(1001L); + + AgentChatRequest request = new AgentChatRequest(); + AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); + message.setRole("user"); + message.setContent("你好"); + request.setMessages(List.of(message)); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> agentDefinitionService.chat(1001L, request)); + assertTrue(exception.getMessage().contains("未绑定知识库")); + } + + @Test + void chatShouldUseStoreScopedRecallAndReturnAnswer() { + AgentDefinition agent = new AgentDefinition(); + agent.setId(1001L); + agent.setAgentCode("AGENT_1"); + agent.setAgentName("知识助手"); + agent.setSystemPrompt("你是企业知识助手"); + agent.setStoreId(2001L); + agent.setStatus("ENABLED"); + doReturn(agent).when(agentDefinitionService).getById(1001L); + + RagStore store = new RagStore(); + store.setId(2001L); + store.setStoreName("企业知识库"); + when(ragStoreService.getById(2001L)).thenReturn(store); + + RagStoreModelConfig modelConfig = new RagStoreModelConfig(); + modelConfig.setStoreId(2001L); + modelConfig.setEmbeddingModelId(3001L); + modelConfig.setEmbeddingDimension(1024); + when(ragStoreModelConfigService.getActiveEntity(2001L)).thenReturn(modelConfig); + + EmbeddingResult embeddingResult = new EmbeddingResult(); + embeddingResult.setVectors(List.of(List.of(0.12, 0.34, 0.56))); + when(embeddingModelGateway.embed(any(EmbeddingRequest.class))).thenReturn(embeddingResult); + + RagChunkRecallResponse recall = new RagChunkRecallResponse(); + recall.setChunkId(4001L); + recall.setDocumentId(5001L); + recall.setChunkContent("公司请假流程:先提交审批单。"); + recall.setScore(0.91); + when(ragChunkEmbeddingMapper.queryTopKByStore(anyLong(), anyString(), anyInt())) + .thenReturn(List.of(recall)); + + ModelCallLog callLog = new ModelCallLog(); + callLog.setRequestId("req_001"); + ChatResult chatResult = new ChatResult(); + chatResult.setContent("根据知识库,先在OA提交请假审批。"); + chatResult.setCallLog(callLog); + when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult); + + AgentChatRequest request = new AgentChatRequest(); + AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); + message.setRole("user"); + message.setContent("请假流程是什么?"); + request.setMessages(List.of(message)); + + AgentChatResponse response = agentDefinitionService.chat(1001L, request); + + assertEquals(1001L, response.getAgentId()); + assertEquals(2001L, response.getStoreId()); + assertEquals("企业知识库", response.getStoreName()); + assertEquals("根据知识库,先在OA提交请假审批。", response.getAnswer()); + assertEquals("req_001", response.getModelRequestId()); + assertEquals(1, response.getReferences().size()); + assertEquals(4001L, response.getReferences().getFirst().getChunkId()); + + ArgumentCaptor embeddingRequestCaptor = ArgumentCaptor.forClass(EmbeddingRequest.class); + verify(embeddingModelGateway).embed(embeddingRequestCaptor.capture()); + EmbeddingRequest embeddingRequest = embeddingRequestCaptor.getValue(); + assertEquals("RAG_QUERY_EMBEDDING", embeddingRequest.getTaskType()); + assertEquals("RAG_STORE", embeddingRequest.getMatchScope()); + assertEquals(2001L, embeddingRequest.getScopeId()); + + verify(ragChunkEmbeddingMapper).queryTopKByStore(anyLong(), anyString(), anyInt()); + } + + @Test + void chatShouldSupportSimpleModeWithoutRagRecall() { + AgentDefinition agent = new AgentDefinition(); + agent.setId(1001L); + agent.setAgentCode("AGENT_1"); + agent.setAgentName("知识助手"); + agent.setStoreId(2001L); + agent.setStatus("ENABLED"); + doReturn(agent).when(agentDefinitionService).getById(1001L); + + RagStore store = new RagStore(); + store.setId(2001L); + store.setStoreName("企业知识库"); + when(ragStoreService.getById(2001L)).thenReturn(store); + + ModelCallLog callLog = new ModelCallLog(); + callLog.setRequestId("req_simple_001"); + ChatResult chatResult = new ChatResult(); + chatResult.setContent("这是普通对话回答。"); + chatResult.setCallLog(callLog); + when(chatModelGateway.chat(any(ChatRequest.class))).thenReturn(chatResult); + + AgentChatRequest request = new AgentChatRequest(); + request.setRagEnabled(false); + AgentChatRequest.AgentMessage message = new AgentChatRequest.AgentMessage(); + message.setRole("user"); + message.setContent("直接聊聊今天安排"); + request.setMessages(List.of(message)); + + AgentChatResponse response = agentDefinitionService.chat(1001L, request); + + assertEquals("这是普通对话回答。", response.getAnswer()); + assertTrue(response.getReferences().isEmpty()); + verify(embeddingModelGateway, never()).embed(any(EmbeddingRequest.class)); + verify(ragChunkEmbeddingMapper, never()).queryTopKByStore(anyLong(), anyString(), anyInt()); + + ArgumentCaptor chatRequestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModelGateway).chat(chatRequestCaptor.capture()); + assertEquals("CHAT_SIMPLE", chatRequestCaptor.getValue().getTaskType()); + } +}