@@ -1,8 +1,11 @@
package cn.iocoder.yudao.module.ai.service.chat ;
import cn.hutool.core.codec.Base64 ;
import cn.hutool.core.collection.CollUtil ;
import cn.hutool.core.io.file.FileNameUtil ;
import cn.hutool.core.util.ObjUtil ;
import cn.hutool.core.util.StrUtil ;
import cn.hutool.http.HttpUtil ;
import cn.iocoder.yudao.framework.common.pojo.CommonResult ;
import cn.iocoder.yudao.framework.common.pojo.PageResult ;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils ;
@@ -28,6 +31,8 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService ;
import cn.iocoder.yudao.module.ai.service.model.AiToolService ;
import cn.iocoder.yudao.module.ai.util.AiUtils ;
import cn.iocoder.yudao.module.infra.framework.file.core.utils.FileTypeUtils ;
import com.google.common.collect.Maps ;
import jakarta.annotation.Resource ;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.ai.chat.messages.Message ;
@@ -64,6 +69,8 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
// TODO @芋艿:后续优化下对话的 Prompt 整体结构
/**
* 知识库转 {@link UserMessage} 的内容模版
*/
@@ -71,6 +78,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
" %s \ n \ n " + // 多个 <Reference></Reference> 的拼接
" 回答要求: \ n- 避免提及你是从 <Reference></Reference> 获取的知识。 " ;
/**
* 附件转 ${@link UserMessage} 的内容模版
*/
@SuppressWarnings ( " TextBlockMigration " )
private static final String Attachment_USER_MESSAGE_TEMPLATE = " 使用 <Attachment></Attachment> 标记用户对话上传的附件内容: \ n \ n " +
" %s \ n \ n " + // 多个 <Attachment></Attachment> 的拼接
" 回答要求: \ n- 避免提及 <Attachment></Attachment> 附件的编码格式。 " ;
@Resource
private AiChatMessageMapper chatMessageMapper ;
@@ -101,17 +116,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
ChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 知识库找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment ( sendReqVO . getContent ( ) , conversation ) ;
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment (
sendReqVO . getContent ( ) , conversation ) ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null ) ;
null , sendReqVO . getAttachmentUrls ( ) );
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments ) ;
knowledgeSegments , null );
// 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
@@ -151,18 +167,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
StreamingChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 知识库找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment ( sendReqVO . getContent ( ) ,
conversation ) ;
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment (
sendReqVO . getContent ( ) , conversation) ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null ) ;
null , sendReqVO . getAttachmentUrls ( ) );
// 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments ) ;
knowledgeSegments , null );
// 4.2 构建 Prompt, 并进行调用
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
@@ -243,8 +259,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1.2 历史 history message 历史消息
List < AiChatMessageDO > contextMessages = filterContextMessages ( messages , conversation , sendReqVO ) ;
contextMessages
. forEach ( message - > chatMessages . add ( AiUtils . buildMessage ( message . getType ( ) , message . getContent ( ) ) ) ) ;
contextMessages . forEach ( message - > {
chatMessages . add ( AiUtils . buildMessage ( message . getType ( ) , message . getContent ( ) ) ) ;
UserMessage attachmentUserMessage = buildAttachmentUserMessage ( message . getAttachmentUrls ( ) ) ;
if ( attachmentUserMessage ! = null ) {
chatMessages . add ( attachmentUserMessage ) ;
}
} ) ;
// 1.3 当前 user message 新发送消息
chatMessages . add ( new UserMessage ( sendReqVO . getContent ( ) ) ) ;
@@ -257,6 +278,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessages . add ( new UserMessage ( String . format ( KNOWLEDGE_USER_MESSAGE_TEMPLATE , reference ) ) ) ;
}
// 1.5 附件,通过 UserMessage 实现
if ( CollUtil . isNotEmpty ( sendReqVO . getAttachmentUrls ( ) ) ) {
UserMessage attachmentUserMessage = buildAttachmentUserMessage ( sendReqVO . getAttachmentUrls ( ) ) ;
if ( attachmentUserMessage ! = null ) {
chatMessages . add ( attachmentUserMessage ) ;
}
}
// 2.1 查询 tool 工具
Set < String > toolNames = null ;
Map < String , Object > toolContext = Map . of ( ) ;
@@ -314,14 +343,52 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return contextMessages ;
}
private UserMessage buildAttachmentUserMessage ( List < String > attachmentUrls ) {
if ( CollUtil . isEmpty ( attachmentUrls ) ) {
return null ;
}
// 读取文件内容
Map < String , String > attachmentContents = Maps . newLinkedHashMapWithExpectedSize ( attachmentUrls . size ( ) ) ;
for ( String attachmentUrl : attachmentUrls ) {
try {
String name = FileNameUtil . getName ( attachmentUrl ) ;
String mineType = FileTypeUtils . getMineType ( name ) ;
String content ;
if ( FileTypeUtils . isImage ( mineType ) ) {
// 特殊:图片则转为 Base64
byte [ ] bytes = HttpUtil . downloadBytes ( attachmentUrl ) ;
content = Base64 . encode ( bytes ) ;
} else {
content = knowledgeDocumentService . readUrl ( attachmentUrl ) ;
}
if ( StrUtil . isNotEmpty ( content ) ) {
attachmentContents . put ( name , content ) ;
}
} catch ( Exception e ) {
log . error ( " [buildAttachmentUserMessage][读取附件({}) 发生异常] " , attachmentUrl , e ) ;
}
}
if ( CollUtil . isEmpty ( attachmentContents ) ) {
return null ;
}
// 拼接 UserMessage 消息
String attachment = attachmentContents . entrySet ( ) . stream ( )
. map ( entry - > " <Attachment name= \" " + entry . getKey ( ) + " \" > " + entry . getValue ( ) + " </Attachment> " )
. collect ( Collectors . joining ( " \ n \ n " ) ) ;
return new UserMessage ( String . format ( Attachment_USER_MESSAGE_TEMPLATE , attachment ) ) ;
}
private AiChatMessageDO createChatMessage ( Long conversationId , Long replyId ,
AiModelDO model , Long userId , Long roleId ,
MessageType messageType , String content , Boolean useContext ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ) {
AiModelDO model , Long userId , Long roleId ,
MessageType messageType , String content , Boolean useContext ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ,
List < String > attachmentUrls ) {
AiChatMessageDO message = new AiChatMessageDO ( ) . setConversationId ( conversationId ) . setReplyId ( replyId )
. setModel ( model . getModel ( ) ) . setModelId ( model . getId ( ) ) . setUserId ( userId ) . setRoleId ( roleId )
. setType ( messageType . getValue ( ) ) . setContent ( content ) . setUseContext ( useContext )
. setSegmentIds ( convertList ( knowledgeSegments , AiKnowledgeSegmentSearchRespBO : : getId ) ) ;
. setSegmentIds ( convertList ( knowledgeSegments , AiKnowledgeSegmentSearchRespBO : : getId ) )
. setAttachmentUrls ( attachmentUrls ) ;
message . setCreateTime ( LocalDateTime . now ( ) ) ;
chatMessageMapper . insert ( message ) ;
return message ;