feat:【ai 大模型】对话增加附件的支持

This commit is contained in:
YunaiV
2025-08-24 20:32:19 +08:00
parent 3578e0bb5d
commit a5a0383f10
6 changed files with 133 additions and 19 deletions

View File

@@ -20,6 +20,30 @@ tenant-id: {{adminTenantId}}
"content": "1+1=?" "content": "1+1=?"
} }
### 发送消息(流式)【带文件】
POST {{baseUrl}}/ai/chat/message/send-stream
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"conversationId": "1781604279872581797",
"content": "图片里有什么?",
"attachmentUrls": ["http://test.yudao.iocoder.cn/1755531278.jpeg"]
}
### 发送消息(流式)【追问带文件】
POST {{baseUrl}}/ai/chat/message/send-stream
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"conversationId": "1781604279872581797",
"content": "说下图片里,有哪些字?",
"useContext": true
}
### 获得指定对话的消息列表 ### 获得指定对话的消息列表
GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649 GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
Authorization: {{token}} Authorization: {{token}}

View File

@@ -49,6 +49,9 @@ public class AiChatMessageRespVO {
@Schema(description = "知识库段落数组") @Schema(description = "知识库段落数组")
private List<KnowledgeSegment> segments; private List<KnowledgeSegment> segments;
@Schema(description = "附件 URL 数组", example = "https://www.iocoder.cn/1.png")
private List<String> attachmentUrls;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
private LocalDateTime createTime; private LocalDateTime createTime;

View File

@@ -3,9 +3,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO") @Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
@Data @Data
@@ -22,4 +22,7 @@ public class AiChatMessageSendReqVO {
@Schema(description = "是否携带上下文", example = "true") @Schema(description = "是否携带上下文", example = "true")
private Boolean useContext; private Boolean useContext;
@Schema(description = "附件 URL 数组", example = "https://www.iocoder.cn/1.png")
private List<String> attachmentUrls;
} }

View File

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler; import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
@@ -105,4 +106,10 @@ public class AiChatMessageDO extends BaseDO {
@TableField(typeHandler = LongListTypeHandler.class) @TableField(typeHandler = LongListTypeHandler.class)
private List<Long> segmentIds; private List<Long> segmentIds;
/**
* 附件 URL 数组
*/
@TableField(typeHandler = StringListTypeHandler.class)
private List<String> attachmentUrls;
} }

View File

@@ -1,8 +1,11 @@
package cn.iocoder.yudao.module.ai.service.chat; package cn.iocoder.yudao.module.ai.service.chat;
import cn.hutool.core.codec.Base64;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.file.FileNameUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@@ -28,6 +31,8 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService; import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiToolService; import cn.iocoder.yudao.module.ai.service.model.AiToolService;
import cn.iocoder.yudao.module.ai.util.AiUtils; import cn.iocoder.yudao.module.ai.util.AiUtils;
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils;
import com.google.common.collect.Maps;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.Message;
@@ -64,6 +69,8 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j @Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService { public class AiChatMessageServiceImpl implements AiChatMessageService {
// TODO @芋艿:后续优化下对话的 Prompt 整体结构
/** /**
* 知识库转 {@link UserMessage} 的内容模版 * 知识库转 {@link UserMessage} 的内容模版
*/ */
@@ -71,6 +78,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
"%s\n\n" + // 多个 <Reference></Reference> 的拼接 "%s\n\n" + // 多个 <Reference></Reference> 的拼接
"回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。"; "回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。";
/**
* 附件转 ${@link UserMessage} 的内容模版
*/
@SuppressWarnings("TextBlockMigration")
private static final String Attachment_USER_MESSAGE_TEMPLATE = "使用 <Attachment></Attachment> 标记用户对话上传的附件内容:\n\n" +
"%s\n\n" + // 多个 <Attachment></Attachment> 的拼接
"回答要求:\n- 避免提及 <Attachment></Attachment> 附件的编码格式。";
@Resource @Resource
private AiChatMessageMapper chatMessageMapper; private AiChatMessageMapper chatMessageMapper;
@@ -101,17 +116,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
ChatModel chatModel = modalService.getChatModel(model.getId()); ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找回 // 2. 知识库找回
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation); List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
sendReqVO.getContent(), conversation);
// 3. 插入 user 发送消息 // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(), userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
null); null, sendReqVO.getAttachmentUrls());
// 3.1 插入 assistant 接收消息 // 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(), userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
knowledgeSegments); knowledgeSegments, null);
// 3.2 创建 chat 需要的 Prompt // 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
@@ -151,18 +167,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
StreamingChatModel chatModel = modalService.getChatModel(model.getId()); StreamingChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找回 // 2. 知识库找回
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
conversation); sendReqVO.getContent(), conversation);
// 3. 插入 user 发送消息 // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(), userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
null); null, sendReqVO.getAttachmentUrls());
// 4.1 插入 assistant 接收消息 // 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(), userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
knowledgeSegments); knowledgeSegments, null);
// 4.2 构建 Prompt并进行调用 // 4.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
@@ -243,8 +259,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1.2 历史 history message 历史消息 // 1.2 历史 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages contextMessages.forEach(message -> {
.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()));
UserMessage attachmentUserMessage = buildAttachmentUserMessage(message.getAttachmentUrls());
if (attachmentUserMessage != null) {
chatMessages.add(attachmentUserMessage);
}
});
// 1.3 当前 user message 新发送消息 // 1.3 当前 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
@@ -257,6 +278,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference))); chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
} }
// 1.5 附件,通过 UserMessage 实现
if (CollUtil.isNotEmpty(sendReqVO.getAttachmentUrls())) {
UserMessage attachmentUserMessage = buildAttachmentUserMessage(sendReqVO.getAttachmentUrls());
if (attachmentUserMessage != null) {
chatMessages.add(attachmentUserMessage);
}
}
// 2.1 查询 tool 工具 // 2.1 查询 tool 工具
Set<String> toolNames = null; Set<String> toolNames = null;
Map<String,Object> toolContext = Map.of(); Map<String,Object> toolContext = Map.of();
@@ -314,14 +343,52 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return contextMessages; return contextMessages;
} }
private UserMessage buildAttachmentUserMessage(List<String> attachmentUrls) {
if (CollUtil.isEmpty(attachmentUrls)) {
return null;
}
// 读取文件内容
Map<String, String> attachmentContents = Maps.newLinkedHashMapWithExpectedSize(attachmentUrls.size());
for (String attachmentUrl : attachmentUrls) {
try {
String name = FileNameUtil.getName(attachmentUrl);
String mineType = FileTypeUtils.getMineType(name);
String content;
if (FileTypeUtils.isImage(mineType)) {
// 特殊:图片则转为 Base64
byte[] bytes = HttpUtil.downloadBytes(attachmentUrl);
content = Base64.encode(bytes);
} else {
content = knowledgeDocumentService.readUrl(attachmentUrl);
}
if (StrUtil.isNotEmpty(content)) {
attachmentContents.put(name, content);
}
} catch (Exception e) {
log.error("[buildAttachmentUserMessage][读取附件({}) 发生异常]", attachmentUrl, e);
}
}
if (CollUtil.isEmpty(attachmentContents)) {
return null;
}
// 拼接 UserMessage 消息
String attachment = attachmentContents.entrySet().stream()
.map(entry -> "<Attachment name=\"" + entry.getKey() + "\">" + entry.getValue() + "</Attachment>")
.collect(Collectors.joining("\n\n"));
return new UserMessage(String.format(Attachment_USER_MESSAGE_TEMPLATE, attachment));
}
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId, private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
AiModelDO model, Long userId, Long roleId, AiModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext, MessageType messageType, String content, Boolean useContext,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) { List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
List<String> attachmentUrls) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId) AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId) .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content).setUseContext(useContext) .setType(messageType.getValue()).setContent(content).setUseContext(useContext)
.setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId)); .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId))
.setAttachmentUrls(attachmentUrls);
message.setCreateTime(LocalDateTime.now()); message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message); chatMessageMapper.insert(message);
return message; return message;

View File

@@ -80,17 +80,17 @@ public class FileTypeUtils {
*/ */
public static void writeAttachment(HttpServletResponse response, String filename, byte[] content) throws IOException { public static void writeAttachment(HttpServletResponse response, String filename, byte[] content) throws IOException {
// 设置 header 和 contentType // 设置 header 和 contentType
String contentType = getMineType(content, filename); String mineType = getMineType(content, filename);
response.setContentType(contentType); response.setContentType(mineType);
// 设置内容显示、下载文件名https://www.cnblogs.com/wq-9/articles/12165056.html // 设置内容显示、下载文件名https://www.cnblogs.com/wq-9/articles/12165056.html
if (StrUtil.containsIgnoreCase(contentType, "image/")) { if (isImage(mineType)) {
// 参见 https://github.com/YunaiV/ruoyi-vue-pro/issues/692 讨论 // 参见 https://github.com/YunaiV/ruoyi-vue-pro/issues/692 讨论
response.setHeader("Content-Disposition", "inline;filename=" + HttpUtils.encodeUtf8(filename)); response.setHeader("Content-Disposition", "inline;filename=" + HttpUtils.encodeUtf8(filename));
} else { } else {
response.setHeader("Content-Disposition", "attachment;filename=" + HttpUtils.encodeUtf8(filename)); response.setHeader("Content-Disposition", "attachment;filename=" + HttpUtils.encodeUtf8(filename));
} }
// 针对 video 的特殊处理,解决视频地址在移动端播放的兼容性问题 // 针对 video 的特殊处理,解决视频地址在移动端播放的兼容性问题
if (StrUtil.containsIgnoreCase(contentType, "video")) { if (StrUtil.containsIgnoreCase(mineType, "video")) {
response.setHeader("Content-Length", String.valueOf(content.length)); response.setHeader("Content-Length", String.valueOf(content.length));
response.setHeader("Content-Range", "bytes 0-" + (content.length - 1) + "/" + content.length); response.setHeader("Content-Range", "bytes 0-" + (content.length - 1) + "/" + content.length);
response.setHeader("Accept-Ranges", "bytes"); response.setHeader("Accept-Ranges", "bytes");
@@ -99,4 +99,14 @@ public class FileTypeUtils {
IoUtil.write(response.getOutputStream(), false, content); IoUtil.write(response.getOutputStream(), false, content);
} }
/**
* 判断是否是图片
*
* @param mineType 类型
* @return 是否是图片
*/
public static boolean isImage(String mineType) {
return StrUtil.startWith(mineType, "image/");
}
} }