feat:【ai 大模型】支持 Chat Role 的 mcp 配置
This commit is contained in:
@@ -197,6 +197,7 @@
|
|||||||
<groupId>org.springframework.ai</groupId>
|
<groupId>org.springframework.ai</groupId>
|
||||||
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
|
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
|
||||||
<version>${spring-ai.version}</version>
|
<version>${spring-ai.version}</version>
|
||||||
|
<optional>true</optional>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<!-- TinyFlow:AI 工作流 -->
|
<!-- TinyFlow:AI 工作流 -->
|
||||||
|
|||||||
@@ -52,6 +52,9 @@ public class AiChatRoleRespVO implements VO {
|
|||||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||||
private List<Long> toolIds;
|
private List<Long> toolIds;
|
||||||
|
|
||||||
|
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||||
|
private List<String> mcpClientNames;
|
||||||
|
|
||||||
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||||
private Boolean publicStatus;
|
private Boolean publicStatus;
|
||||||
|
|
||||||
|
|||||||
@@ -37,4 +37,7 @@ public class AiChatRoleSaveMyReqVO {
|
|||||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||||
private List<Long> toolIds;
|
private List<Long> toolIds;
|
||||||
|
|
||||||
|
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||||
|
private List<String> mcpClientNames;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -50,6 +50,9 @@ public class AiChatRoleSaveReqVO {
|
|||||||
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
@Schema(description = "引用的工具编号列表", example = "1,2,3")
|
||||||
private List<Long> toolIds;
|
private List<Long> toolIds;
|
||||||
|
|
||||||
|
@Schema(description = "引用的 MCP Client 名字列表", example = "filesystem")
|
||||||
|
private List<String> mcpClientNames;
|
||||||
|
|
||||||
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
@Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||||
@NotNull(message = "是否公开不能为空")
|
@NotNull(message = "是否公开不能为空")
|
||||||
private Boolean publicStatus;
|
private Boolean publicStatus;
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
|||||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||||
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
|
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
|
||||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||||
import com.baomidou.mybatisplus.annotation.TableField;
|
import com.baomidou.mybatisplus.annotation.TableField;
|
||||||
@@ -80,6 +81,13 @@ public class AiChatRoleDO extends BaseDO {
|
|||||||
*/
|
*/
|
||||||
@TableField(typeHandler = LongListTypeHandler.class)
|
@TableField(typeHandler = LongListTypeHandler.class)
|
||||||
private List<Long> toolIds;
|
private List<Long> toolIds;
|
||||||
|
/**
|
||||||
|
* 引用的 MCP Client 名字列表
|
||||||
|
*
|
||||||
|
* 关联 spring.ai.mcp.client 下的名字
|
||||||
|
*/
|
||||||
|
@TableField(typeHandler = StringListTypeHandler.class)
|
||||||
|
private List<String> mcpClientNames;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 是否公开
|
* 是否公开
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiToolService;
|
|||||||
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
import cn.iocoder.yudao.module.ai.util.AiUtils;
|
||||||
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils;
|
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils;
|
||||||
import com.google.common.collect.Maps;
|
import com.google.common.collect.Maps;
|
||||||
|
import io.modelcontextprotocol.client.McpSyncClient;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.messages.Message;
|
import org.springframework.ai.chat.messages.Message;
|
||||||
@@ -47,6 +48,10 @@ import org.springframework.ai.chat.model.ChatResponse;
|
|||||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
||||||
|
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
|
||||||
|
import org.springframework.ai.tool.ToolCallback;
|
||||||
|
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
@@ -119,6 +124,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
@Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
|
@Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
|
||||||
private AiWebSearchClient webSearchClient;
|
private AiWebSearchClient webSearchClient;
|
||||||
|
|
||||||
|
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
|
||||||
|
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
|
||||||
|
private List<McpSyncClient> mcpClients;
|
||||||
|
|
||||||
|
@SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
|
||||||
|
@Autowired(required = false) // 由于 yudao.ai.mcp.client.enable 配置项,可以关闭 McpSyncClient 的功能,所以这里只能不强制注入
|
||||||
|
private McpClientCommonProperties mcpClientCommonProperties;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private ToolCallbackResolver toolCallbackResolver;
|
||||||
|
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
||||||
// 1.1 校验对话存在
|
// 1.1 校验对话存在
|
||||||
@@ -334,23 +350,54 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2.1 查询 tool 工具
|
// 2.1 查询 tool 工具
|
||||||
Set<String> toolNames = null;
|
List<ToolCallback> toolCallbacks = getToolCallbackListByRoleId(conversation.getRoleId());
|
||||||
Map<String,Object> toolContext = Map.of();
|
Map<String,Object> toolContext = CollUtil.isNotEmpty(toolCallbacks) ? AiUtils.buildCommonToolContext()
|
||||||
if (conversation.getRoleId() != null) {
|
: Map.of();
|
||||||
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
|
|
||||||
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
|
||||||
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
|
||||||
toolContext = AiUtils.buildCommonToolContext();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 2.2 构建 ChatOptions 对象
|
// 2.2 构建 ChatOptions 对象
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
||||||
conversation.getTemperature(), conversation.getMaxTokens(),
|
conversation.getTemperature(), conversation.getMaxTokens(),
|
||||||
toolNames, toolContext);
|
toolCallbacks, toolContext);
|
||||||
return new Prompt(chatMessages, chatOptions);
|
return new Prompt(chatMessages, chatOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<ToolCallback> getToolCallbackListByRoleId(Long roleId) {
|
||||||
|
if (roleId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
AiChatRoleDO chatRole = chatRoleService.getChatRole(roleId);
|
||||||
|
if (chatRole == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
List<ToolCallback> toolCallbacks = new ArrayList<>();
|
||||||
|
// 1. 通过 toolIds
|
||||||
|
if (CollUtil.isNotEmpty(chatRole.getToolIds())) {
|
||||||
|
Set<String> toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
|
||||||
|
toolNames.forEach(toolName -> {
|
||||||
|
ToolCallback toolCallback = toolCallbackResolver.resolve(toolName);
|
||||||
|
if (toolCallback != null) {
|
||||||
|
toolCallbacks.add(toolCallback);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// 2. 通过 mcpClients
|
||||||
|
if (CollUtil.isNotEmpty(mcpClients) && CollUtil.isNotEmpty(chatRole.getMcpClientNames())) {
|
||||||
|
chatRole.getMcpClientNames().forEach(mcpClientName -> {
|
||||||
|
// 2.1 标准化名字,参考 McpClientAutoConfiguration 的 connectedClientName 方法
|
||||||
|
String finalMcpClientName = mcpClientCommonProperties.getName() + " - " + mcpClientName;
|
||||||
|
// 2.2 匹配对应的 McpSyncClient
|
||||||
|
mcpClients.forEach(mcpClient -> {
|
||||||
|
if (ObjUtil.notEqual(mcpClient.getClientInfo().name(), finalMcpClientName)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ToolCallback[] mcpToolCallBacks = new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks();
|
||||||
|
CollUtil.addAll(toolCallbacks, mcpToolCallBacks);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return toolCallbacks;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||||
* <p>
|
* <p>
|
||||||
|
|||||||
@@ -18,12 +18,10 @@ import org.springframework.ai.deepseek.DeepSeekChatOptions;
|
|||||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.ai.tool.ToolCallback;
|
||||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.*;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Spring AI 工具类
|
* Spring AI 工具类
|
||||||
@@ -40,15 +38,15 @@ public class AiUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||||
Set<String> toolNames, Map<String, Object> toolContext) {
|
List<ToolCallback> toolCallbacks, Map<String, Object> toolContext) {
|
||||||
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
toolCallbacks = ObjUtil.defaultIfNull(toolCallbacks, Collections.emptyList());
|
||||||
toolContext = ObjUtil.defaultIfNull(toolContext, Collections.emptyMap());
|
toolContext = ObjUtil.defaultIfNull(toolContext, Collections.emptyMap());
|
||||||
// noinspection EnhancedSwitchMigration
|
// noinspection EnhancedSwitchMigration
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case TONG_YI:
|
case TONG_YI:
|
||||||
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
|
||||||
.withEnableThinking(true) // TODO 芋艿:默认都开启 thinking 模式,后续可以让用户配置
|
.withEnableThinking(true) // TODO 芋艿:默认都开启 thinking 模式,后续可以让用户配置
|
||||||
.withToolNames(toolNames).withToolContext(toolContext).build();
|
.withToolCallbacks(toolCallbacks).withToolContext(toolContext).build();
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
|
||||||
case DEEP_SEEK:
|
case DEEP_SEEK:
|
||||||
@@ -57,30 +55,30 @@ public class AiUtils {
|
|||||||
case SILICON_FLOW: // 复用 DeepSeek 客户端
|
case SILICON_FLOW: // 复用 DeepSeek 客户端
|
||||||
case XING_HUO: // 复用 DeepSeek 客户端
|
case XING_HUO: // 复用 DeepSeek 客户端
|
||||||
return DeepSeekChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return DeepSeekChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case ZHI_PU:
|
case ZHI_PU:
|
||||||
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case MINI_MAX:
|
case MINI_MAX:
|
||||||
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case MOONSHOT:
|
case MOONSHOT:
|
||||||
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case OPENAI:
|
case OPENAI:
|
||||||
case GEMINI: // 复用 OpenAI 客户端
|
case GEMINI: // 复用 OpenAI 客户端
|
||||||
case BAI_CHUAN: // 复用 OpenAI 客户端
|
case BAI_CHUAN: // 复用 OpenAI 客户端
|
||||||
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case AZURE_OPENAI:
|
case AZURE_OPENAI:
|
||||||
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case ANTHROPIC:
|
case ANTHROPIC:
|
||||||
return AnthropicChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
return AnthropicChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
||||||
.toolNames(toolNames).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,44 +36,47 @@ public class DouBaoMcpTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMcpGetUserInfo() {
|
public void testMcpGetUserInfo() {
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("目前有哪些工具可以使用")
|
.user("目前有哪些工具可以使用")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("小新的年龄是多少")
|
.user("小新的年龄是多少")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("获取小新的基本信息")
|
.user("获取小新的基本信息")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("小新是什么职业的")
|
.user("小新是什么职业的")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("小新的教育背景")
|
.user("小新的教育背景")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
// 打印结果
|
// 打印结果
|
||||||
System.out.println(chatClient.prompt()
|
System.out.println(chatClient.prompt()
|
||||||
.user("小新的兴趣爱好是什么")
|
.user("小新的兴趣爱好是什么")
|
||||||
.call()
|
.call()
|
||||||
.content());
|
.content());
|
||||||
System.out.println("====================================");
|
System.out.println("====================================");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user