feat:【ai 大模型】支持 Chat Role 的 mcp 配置
This commit is contained in:
@@ -197,6 +197,7 @@
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
|
||||
<!-- TinyFlow:AI 工作流 -->
|
||||
|
||||
@@ -52,6 +52,9 @@ public class AiChatRoleRespVO implements VO {
|
||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||
private List<Long> toolIds;
|
||||
|
||||
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||
private List<String> mcpClientNames;
|
||||
|
||||
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Boolean publicStatus;
|
||||
|
||||
|
||||
@@ -37,4 +37,7 @@ public class AiChatRoleSaveMyReqVO {
|
||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||
private List<Long> toolIds;
|
||||
|
||||
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||
private List<String> mcpClientNames;
|
||||
|
||||
}
|
||||
@@ -50,6 +50,9 @@ public class AiChatRoleSaveReqVO {
|
||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||
private List<Long> toolIds;
|
||||
|
||||
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||
private List<String> mcpClientNames;
|
||||
|
||||
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
@NotNull(message = "是否公开不能为空")
|
||||
private Boolean publicStatus;
|
||||
|
||||
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableField;
|
||||
@@ -80,6 +81,13 @@ public class AiChatRoleDO extends BaseDO {
|
||||
*/
|
||||
@TableField(typeHandler = LongListTypeHandler.class)
|
||||
private List<Long> toolIds;
|
||||
/**
|
||||
* 引用的 MCP Client 名字列表
|
||||
*
|
||||
* 关联 spring.ai.mcp.client 下的名字
|
||||
*/
|
||||
@TableField(typeHandler = StringListTypeHandler.class)
|
||||
private List<String> mcpClientNames;
|
||||
|
||||
/**
|
||||
* 是否公开
|
||||
|
||||
@@ -36,6 +36,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiToolService;
|
||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils;
|
||||
import com.google.common.collect.Maps;
|
||||
import io.modelcontextprotocol.client.McpSyncClient;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@@ -47,6 +48,10 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
||||
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
@@ -119,6 +124,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||
@Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
|
||||
private AiWebSearchClient webSearchClient;
|
||||
|
||||
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
|
||||
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
|
||||
private List<McpSyncClient> mcpClients;
|
||||
|
||||
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
|
||||
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
|
||||
private McpClientCommonProperties mcpClientCommonProperties;
|
||||
|
||||
@Resource
|
||||
private ToolCallbackResolver toolCallbackResolver;
|
||||
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
||||
// 1.1 校验对话存在
|
||||
@@ -334,23 +350,54 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||
}
|
||||
|
||||
// 2.1 查询 tool 工具
|
||||
Set<String> toolNames = null;
|
||||
Map<String,Object> toolContext = Map.of();
|
||||
if (conversation.getRoleId() != null) {
|
||||
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
||||
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
||||
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
||||
toolContext = AiUtils.buildCommonToolContext();
|
||||
}
|
||||
}
|
||||
List<ToolCallback> toolCallbacks = getToolCallbackListByRoleId(conversation.getRoleId());
|
||||
Map<String,Object> toolContext = CollUtil.isNotEmpty(toolCallbacks) ? AiUtils.buildCommonToolContext()
|
||||
: Map.of();
|
||||
// 2.2 构建 ChatOptions 对象
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
||||
conversation.getTemperature(), conversation.getMaxTokens(),
|
||||
toolNames, toolContext);
|
||||
toolCallbacks, toolContext);
|
||||
return new Prompt(chatMessages, chatOptions);
|
||||
}
|
||||
|
||||
private List<ToolCallback> getToolCallbackListByRoleId(Long roleId) {
|
||||
if (roleId == null) {
|
||||
return null;
|
||||
}
|
||||
AiChatRoleDO chatRole = chatRoleService.getChatRole(roleId);
|
||||
if (chatRole == null) {
|
||||
return null;
|
||||
}
|
||||
List<ToolCallback> toolCallbacks = new ArrayList<>();
|
||||
// 1. 通过 toolIds
|
||||
if (CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
||||
Set<String> toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
||||
toolNames.forEach(toolName -> {
|
||||
ToolCallback toolCallback = toolCallbackResolver.resolve(toolName);
|
||||
if (toolCallback != null) {
|
||||
toolCallbacks.add(toolCallback);
|
||||
}
|
||||
});
|
||||
}
|
||||
// 2. 通过 mcpClients
|
||||
if (CollUtil.isNotEmpty(mcpClients) && CollUtil.isNotEmpty(chatRole.getMcpClientNames())) {
|
||||
chatRole.getMcpClientNames().forEach(mcpClientName -> {
|
||||
// 2.1 标准化名字,参考 McpClientAutoConfiguration 的 connectedClientName 方法
|
||||
String finalMcpClientName = mcpClientCommonProperties.getName() + " - " + mcpClientName;
|
||||
// 2.2 匹配对应的 McpSyncClient
|
||||
mcpClients.forEach(mcpClient -> {
|
||||
if (ObjUtil.notEqual(mcpClient.getClientInfo().name(), finalMcpClientName)) {
|
||||
return;
|
||||
}
|
||||
ToolCallback[] mcpToolCallBacks = new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks();
|
||||
CollUtil.addAll(toolCallbacks, mcpToolCallBacks);
|
||||
});
|
||||
});
|
||||
}
|
||||
return toolCallbacks;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||
* <p>
|
||||
|
||||
@@ -18,12 +18,10 @@ import org.springframework.ai.deepseek.DeepSeekChatOptions;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Spring AI 工具类
|
||||
@@ -40,15 +38,15 @@ public class AiUtils {
|
||||
}
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||
Set<String> toolNames, Map<String, Object> toolContext) {
|
||||
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
||||
List<ToolCallback> toolCallbacks, Map<String, Object> toolContext) {
|
||||
toolCallbacks = ObjUtil.defaultIfNull(toolCallbacks, Collections.emptyList());
|
||||
toolContext = ObjUtil.defaultIfNull(toolContext, Collections.emptyMap());
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
||||
.withEnableThinking(true) // TODO 芋艿:默认都开启 thinking 模式,后续可以让用户配置
|
||||
.withToolNames(toolNames).withToolContext(toolContext).build();
|
||||
.withToolCallbacks(toolCallbacks).withToolContext(toolContext).build();
|
||||
case YI_YAN:
|
||||
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||
case DEEP_SEEK:
|
||||
@@ -57,30 +55,30 @@ public class AiUtils {
|
||||
case SILICON_FLOW: // 复用 DeepSeek 客户端
|
||||
case XING_HUO: // 复用 DeepSeek 客户端
|
||||
return DeepSeekChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case ZHI_PU:
|
||||
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case MINI_MAX:
|
||||
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case MOONSHOT:
|
||||
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case OPENAI:
|
||||
case GEMINI: // 复用 OpenAI 客户端
|
||||
case BAI_CHUAN: // 复用 OpenAI 客户端
|
||||
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case AZURE_OPENAI:
|
||||
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case ANTHROPIC:
|
||||
return AnthropicChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
||||
.toolNames(toolNames).toolContext(toolContext).build();
|
||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
|
||||
@@ -36,44 +36,47 @@ public class DouBaoMcpTests {
|
||||
|
||||
@Test
|
||||
public void testMcpGetUserInfo() {
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("目前有哪些工具可以使用")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的年龄是多少")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("获取小新的基本信息")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新是什么职业的")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的教育背景")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
// 打印结果
|
||||
System.out.println(chatClient.prompt()
|
||||
.user("小新的兴趣爱好是什么")
|
||||
.call()
|
||||
.content());
|
||||
System.out.println("====================================");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user