From 50b5aeb4428643f543573083ccfe7a8cf7942dfa Mon Sep 17 00:00:00 2001 From: YunaiV Date: Thu, 28 Aug 2025 21:51:59 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E3=80=90ai=20=E5=A4=A7=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E3=80=91=E6=94=AF=E6=8C=81=20Chat=20Role=20=E7=9A=84?= =?UTF-8?q?=20mcp=20=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yudao-module-ai/pom.xml | 1 + .../model/vo/chatRole/AiChatRoleRespVO.java | 3 + .../vo/chatRole/AiChatRoleSaveMyReqVO.java | 3 + .../vo/chatRole/AiChatRoleSaveReqVO.java | 3 + .../ai/dal/dataobject/model/AiChatRoleDO.java | 8 +++ .../chat/AiChatMessageServiceImpl.java | 67 ++++++++++++++++--- .../iocoder/yudao/module/ai/util/AiUtils.java | 28 ++++---- .../ai/core/model/mcp/DouBaoMcpTests.java | 7 +- 8 files changed, 93 insertions(+), 27 deletions(-) diff --git a/yudao-module-ai/pom.xml b/yudao-module-ai/pom.xml index 5d112e8c7c..6bad77e24b 100644 --- a/yudao-module-ai/pom.xml +++ b/yudao-module-ai/pom.xml @@ -197,6 +197,7 @@ org.springframework.ai spring-ai-starter-mcp-client-webflux ${spring-ai.version} + true diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java index 51e44ed760..2ef9565cc2 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java @@ -52,6 +52,9 @@ public class AiChatRoleRespVO implements VO { @Schema(description = "引用的工具编号列表", example = "1,2,3") private List toolIds; + @Schema(description = "引用的 MCP Client 名字列表", example = "filesystem") + private List mcpClientNames; + @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") private Boolean publicStatus; diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveMyReqVO.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveMyReqVO.java index 009e8d8afb..bd4a05723c 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveMyReqVO.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveMyReqVO.java @@ -37,4 +37,7 @@ public class AiChatRoleSaveMyReqVO { @Schema(description = "引用的工具编号列表", example = "1,2,3") private List toolIds; + @Schema(description = "引用的 MCP Client 名字列表", example = "filesystem") + private List mcpClientNames; + } \ No newline at end of file diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveReqVO.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveReqVO.java index 3c72cf9834..8f2913dd52 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveReqVO.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveReqVO.java @@ -50,6 +50,9 @@ public class AiChatRoleSaveReqVO { @Schema(description = "引用的工具编号列表", example = "1,2,3") private List toolIds; + @Schema(description = "引用的 MCP Client 名字列表", example = "filesystem") + private List mcpClientNames; + @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @NotNull(message = "是否公开不能为空") private Boolean publicStatus; diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java index bb6a3ca48d..d20b25e884 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler; +import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.TableField; @@ -80,6 +81,13 @@ public class AiChatRoleDO extends BaseDO { */ @TableField(typeHandler = LongListTypeHandler.class) private List toolIds; + /** + * 引用的 MCP Client 名字列表 + * + * 关联 spring.ai.mcp.client 下的名字 + */ + @TableField(typeHandler = StringListTypeHandler.class) + private List mcpClientNames; /** * 是否公开 diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 541e860f07..0ce0d15bda 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -36,6 +36,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiToolService; import cn.iocoder.yudao.module.ai.util.AiUtils; import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils; import com.google.common.collect.Maps; +import io.modelcontextprotocol.client.McpSyncClient; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; @@ -47,6 +48,10 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -119,6 +124,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入 private AiWebSearchClient webSearchClient; + @SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection") + @Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入 + private List mcpClients; + + @SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection") + @Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入 + private McpClientCommonProperties mcpClientCommonProperties; + + @Resource + private ToolCallbackResolver toolCallbackResolver; + @Transactional(rollbackFor = Exception.class) public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { // 1.1 校验对话存在 @@ -334,23 +350,54 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { } // 2.1 查询 tool 工具 - Set toolNames = null; - Map toolContext = Map.of(); - if (conversation.getRoleId() != null) { - AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId()); - if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) { - toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName); - toolContext = AiUtils.buildCommonToolContext(); - } - } + List toolCallbacks = getToolCallbackListByRoleId(conversation.getRoleId()); + Map toolContext = CollUtil.isNotEmpty(toolCallbacks) ? AiUtils.buildCommonToolContext() + : Map.of(); // 2.2 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), conversation.getTemperature(), conversation.getMaxTokens(), - toolNames, toolContext); + toolCallbacks, toolContext); return new Prompt(chatMessages, chatOptions); } + private List getToolCallbackListByRoleId(Long roleId) { + if (roleId == null) { + return null; + } + AiChatRoleDO chatRole = chatRoleService.getChatRole(roleId); + if (chatRole == null) { + return null; + } + List toolCallbacks = new ArrayList<>(); + // 1. 通过 toolIds + if (CollUtil.isNotEmpty(chatRole.getToolIds())) { + Set toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName); + toolNames.forEach(toolName -> { + ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); + if (toolCallback != null) { + toolCallbacks.add(toolCallback); + } + }); + } + // 2. 通过 mcpClients + if (CollUtil.isNotEmpty(mcpClients) && CollUtil.isNotEmpty(chatRole.getMcpClientNames())) { + chatRole.getMcpClientNames().forEach(mcpClientName -> { + // 2.1 标准化名字,参考 McpClientAutoConfiguration 的 connectedClientName 方法 + String finalMcpClientName = mcpClientCommonProperties.getName() + " - " + mcpClientName; + // 2.2 匹配对应的 McpSyncClient + mcpClients.forEach(mcpClient -> { + if (ObjUtil.notEqual(mcpClient.getClientInfo().name(), finalMcpClientName)) { + return; + } + ToolCallback[] mcpToolCallBacks = new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks(); + CollUtil.addAll(toolCallbacks, mcpToolCallBacks); + }); + }); + } + return toolCallbacks; + } + /** * 从历史消息中,获得倒序的 n 组消息作为消息上下文 *

diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java index 35fb26d2cf..d209c62d44 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/util/AiUtils.java @@ -18,12 +18,10 @@ import org.springframework.ai.deepseek.DeepSeekChatOptions; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; +import java.util.*; /** * Spring AI 工具类 @@ -40,15 +38,15 @@ public class AiUtils { } public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens, - Set toolNames, Map toolContext) { - toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet()); + List toolCallbacks, Map toolContext) { + toolCallbacks = ObjUtil.defaultIfNull(toolCallbacks, Collections.emptyList()); toolContext = ObjUtil.defaultIfNull(toolContext, Collections.emptyMap()); // noinspection EnhancedSwitchMigration switch (platform) { case TONG_YI: return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens) .withEnableThinking(true) // TODO 芋艿:默认都开启 thinking 模式,后续可以让用户配置 - .withToolNames(toolNames).withToolContext(toolContext).build(); + .withToolCallbacks(toolCallbacks).withToolContext(toolContext).build(); case YI_YAN: return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build(); case DEEP_SEEK: @@ -57,30 +55,30 @@ public class AiUtils { case SILICON_FLOW: // 复用 DeepSeek 客户端 case XING_HUO: // 复用 DeepSeek 客户端 return DeepSeekChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case ZHI_PU: return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case MINI_MAX: return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case MOONSHOT: return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case OPENAI: case GEMINI: // 复用 OpenAI 客户端 case BAI_CHUAN: // 复用 OpenAI 客户端 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case AZURE_OPENAI: return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case ANTHROPIC: return AnthropicChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); case OLLAMA: return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens) - .toolNames(toolNames).toolContext(toolContext).build(); + .toolCallbacks(toolCallbacks).toolContext(toolContext).build(); default: throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); } diff --git a/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/mcp/DouBaoMcpTests.java b/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/mcp/DouBaoMcpTests.java index b50caef5e2..674553ee1e 100644 --- a/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/mcp/DouBaoMcpTests.java +++ b/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/mcp/DouBaoMcpTests.java @@ -36,44 +36,47 @@ public class DouBaoMcpTests { @Test public void testMcpGetUserInfo() { - // 打印结果 System.out.println(chatClient.prompt() .user("目前有哪些工具可以使用") .call() .content()); System.out.println("===================================="); + // 打印结果 System.out.println(chatClient.prompt() .user("小新的年龄是多少") .call() .content()); System.out.println("===================================="); + // 打印结果 System.out.println(chatClient.prompt() .user("获取小新的基本信息") .call() .content()); System.out.println("===================================="); + // 打印结果 System.out.println(chatClient.prompt() .user("小新是什么职业的") .call() .content()); System.out.println("===================================="); + // 打印结果 System.out.println(chatClient.prompt() .user("小新的教育背景") .call() .content()); System.out.println("===================================="); + // 打印结果 System.out.println(chatClient.prompt() .user("小新的兴趣爱好是什么") .call() .content()); System.out.println("===================================="); - }