feat:【ai 大模型】支持 Chat Role 的 mcp 配置

This commit is contained in:
YunaiV
2025-08-28 21:51:59 +08:00
parent 369ca68a35
commit 50b5aeb442
8 changed files with 93 additions and 27 deletions

View File

@@ -197,6 +197,7 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
<version>${spring-ai.version}</version>
<optional>true</optional>
</dependency>
<!-- TinyFlowAI 工作流 -->

View File

@@ -52,6 +52,9 @@ public class AiChatRoleRespVO implements VO {
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
private List<String> mcpClientNames;
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Boolean publicStatus;

View File

@@ -37,4 +37,7 @@ public class AiChatRoleSaveMyReqVO {
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
private List<String> mcpClientNames;
}

View File

@@ -50,6 +50,9 @@ public class AiChatRoleSaveReqVO {
@Schema(description = "引用的工具编号列表", example = "1,2,3")
private List<Long> toolIds;
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
private List<String> mcpClientNames;
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "是否公开不能为空")
private Boolean publicStatus;

View File

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
@@ -80,6 +81,13 @@ public class AiChatRoleDO extends BaseDO {
*/
@TableField(typeHandler = LongListTypeHandler.class)
private List<Long> toolIds;
/**
* 引用的 MCP Client 名字列表
*
* 关联 spring.ai.mcp.client 下的名字
*/
@TableField(typeHandler = StringListTypeHandler.class)
private List<String> mcpClientNames;
/**
* 是否公开

View File

@@ -36,6 +36,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiToolService;
import cn.iocoder.yudao.module.ai.util.AiUtils;
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils;
import com.google.common.collect.Maps;
import io.modelcontextprotocol.client.McpSyncClient;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@@ -47,6 +48,10 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -119,6 +124,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
private AiWebSearchClient webSearchClient;
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
private List<McpSyncClient> mcpClients;
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
private McpClientCommonProperties mcpClientCommonProperties;
@Resource
private ToolCallbackResolver toolCallbackResolver;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
// 1.1 校验对话存在
@@ -334,23 +350,54 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
// 2.1 查询 tool 工具
Set<String> toolNames = null;
Map<String,Object> toolContext = Map.of();
if (conversation.getRoleId() != null) {
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
toolContext = AiUtils.buildCommonToolContext();
}
}
List<ToolCallback> toolCallbacks = getToolCallbackListByRoleId(conversation.getRoleId());
Map<String,Object> toolContext = CollUtil.isNotEmpty(toolCallbacks) ? AiUtils.buildCommonToolContext()
: Map.of();
// 2.2 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens(),
toolNames, toolContext);
toolCallbacks, toolContext);
return new Prompt(chatMessages, chatOptions);
}
private List<ToolCallback> getToolCallbackListByRoleId(Long roleId) {
if (roleId == null) {
return null;
}
AiChatRoleDO chatRole = chatRoleService.getChatRole(roleId);
if (chatRole == null) {
return null;
}
List<ToolCallback> toolCallbacks = new ArrayList<>();
// 1. 通过 toolIds
if (CollUtil.isNotEmpty(chatRole.getToolIds())) {
Set<String> toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
toolNames.forEach(toolName -> {
ToolCallback toolCallback = toolCallbackResolver.resolve(toolName);
if (toolCallback != null) {
toolCallbacks.add(toolCallback);
}
});
}
// 2. 通过 mcpClients
if (CollUtil.isNotEmpty(mcpClients) && CollUtil.isNotEmpty(chatRole.getMcpClientNames())) {
chatRole.getMcpClientNames().forEach(mcpClientName -> {
// 2.1 标准化名字,参考 McpClientAutoConfiguration 的 connectedClientName 方法
String finalMcpClientName = mcpClientCommonProperties.getName() + " - " + mcpClientName;
// 2.2 匹配对应的 McpSyncClient
mcpClients.forEach(mcpClient -> {
if (ObjUtil.notEqual(mcpClient.getClientInfo().name(), finalMcpClientName)) {
return;
}
ToolCallback[] mcpToolCallBacks = new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks();
CollUtil.addAll(toolCallbacks, mcpToolCallBacks);
});
});
}
return toolCallbacks;
}
/**
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
* <p>

View File

@@ -18,12 +18,10 @@ import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.*;
/**
* Spring AI 工具类
@@ -40,15 +38,15 @@ public class AiUtils {
}
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
Set<String> toolNames, Map<String, Object> toolContext) {
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
List<ToolCallback> toolCallbacks, Map<String, Object> toolContext) {
toolCallbacks = ObjUtil.defaultIfNull(toolCallbacks, Collections.emptyList());
toolContext = ObjUtil.defaultIfNull(toolContext, Collections.emptyMap());
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
.withEnableThinking(true) // TODO 芋艿:默认都开启 thinking 模式,后续可以让用户配置
.withToolNames(toolNames).withToolContext(toolContext).build();
.withToolCallbacks(toolCallbacks).withToolContext(toolContext).build();
case YI_YAN:
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
case DEEP_SEEK:
@@ -57,30 +55,30 @@ public class AiUtils {
case SILICON_FLOW: // 复用 DeepSeek 客户端
case XING_HUO: // 复用 DeepSeek 客户端
return DeepSeekChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case ZHI_PU:
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case MINI_MAX:
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case MOONSHOT:
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case OPENAI:
case GEMINI: // 复用 OpenAI 客户端
case BAI_CHUAN: // 复用 OpenAI 客户端
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case AZURE_OPENAI:
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case ANTHROPIC:
return AnthropicChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
case OLLAMA:
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
.toolNames(toolNames).toolContext(toolContext).build();
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}

View File

@@ -36,44 +36,47 @@ public class DouBaoMcpTests {
@Test
public void testMcpGetUserInfo() {
// 打印结果
System.out.println(chatClient.prompt()
.user("目前有哪些工具可以使用")
.call()
.content());
System.out.println("====================================");
// 打印结果
System.out.println(chatClient.prompt()
.user("小新的年龄是多少")
.call()
.content());
System.out.println("====================================");
// 打印结果
System.out.println(chatClient.prompt()
.user("获取小新的基本信息")
.call()
.content());
System.out.println("====================================");
// 打印结果
System.out.println(chatClient.prompt()
.user("小新是什么职业的")
.call()
.content());
System.out.println("====================================");
// 打印结果
System.out.println(chatClient.prompt()
.user("小新的教育背景")
.call()
.content());
System.out.println("====================================");
// 打印结果
System.out.println(chatClient.prompt()
.user("小新的兴趣爱好是什么")
.call()
.content());
System.out.println("====================================");
}