@@ -23,6 +23,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper ;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants ;
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum ;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient ;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchRequest ;
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse ;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService ;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService ;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO ;
@@ -44,6 +47,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel ;
import org.springframework.ai.chat.prompt.ChatOptions ;
import org.springframework.ai.chat.prompt.Prompt ;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.stereotype.Service ;
import org.springframework.transaction.annotation.Transactional ;
import reactor.core.publisher.Flux ;
@@ -69,6 +73,11 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
/**
* 联网搜索的结束数
*/
private static final Integer WEB_SEARCH_COUNT = 10 ;
// TODO @芋艿:后续优化下对话的 Prompt 整体结构
/**
@@ -78,6 +87,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
" %s \ n \ n " + // 多个 <Reference></Reference> 的拼接
" 回答要求: \ n- 避免提及你是从 <Reference></Reference> 获取的知识。 " ;
private static final String WEB_SEARCH_USER_MESSAGE_TEMPLATE = " 使用 <WebSearch></WebSearch> 标记中的内容作为本次对话的参考: \ n \ n " +
" %s \ n \ n " + // 多个 <WebSearch></WebSearch> 的拼接
" 回答要求: \ n- 避免提及你是从 <WebSearch></WebSearch> 获取的知识。 " ;
/**
* 附件转 ${@link UserMessage} 的内容模版
*/
@@ -102,6 +115,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Resource
private AiToolService toolService ;
@SuppressWarnings ( " SpringJavaAutowiredFieldsWarningInspection " )
@Autowired ( required = false ) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
private AiWebSearchClient webSearchClient ;
@Transactional ( rollbackFor = Exception . class )
public AiChatMessageSendRespVO sendMessage ( AiChatMessageSendReqVO sendReqVO , Long userId ) {
// 1.1 校验对话存在
@@ -115,30 +132,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model = modalService . validateModel ( conversation . getModelId ( ) ) ;
ChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 知识库找 回
// 2.1 知识库召 回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment (
sendReqVO . getContent ( ) , conversation ) ;
// 2.2 联网搜索
AiWebSearchResponse webSearchResponse = Boolean . TRUE . equals ( sendReqVO . getUseSearch ( ) ) & & webSearchClient ! = null ?
webSearchClient . search ( new AiWebSearchRequest ( ) . setQuery ( sendReqVO . getContent ( ) )
. setSummary ( true ) . setCount ( WEB_SEARCH_COUNT ) ) : null ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null , sendReqVO . getAttachmentUrls ( ) ) ;
null , sendReqVO . getAttachmentUrls ( ) , null );
// 3 .1 插入 assistant 接收消息
// 4 .1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments , null ) ;
knowledgeSegments , null , webSearchResponse );
// 3 .2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
// 4 .2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , webSearchResponse , model , sendReqVO ) ;
ChatResponse chatResponse = chatModel . call ( prompt ) ;
// 3 .3 更新响应内容
// 4 .3 更新响应内容
String newContent = AiUtils . getChatResponseContent ( chatResponse ) ;
String newReasoningContent = AiUtils . getChatResponseReasoningContent ( chatResponse ) ;
chatMessageMapper . updateById ( new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) )
. setContent ( newContent ) . setReasoningContent ( newReasoningContent ) ) ;
// 3 .4 响应结果
// 4 .4 响应结果
Map < Long , AiKnowledgeDocumentDO > documentMap = knowledgeDocumentService . getKnowledgeDocumentMap (
convertSet ( knowledgeSegments , AiKnowledgeSegmentSearchRespBO : : getDocumentId ) ) ;
List < AiChatMessageRespVO . KnowledgeSegment > segments = BeanUtils . toBean ( knowledgeSegments ,
@@ -149,7 +171,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return new AiChatMessageSendRespVO ( )
. setSend ( BeanUtils . toBean ( userMessage , AiChatMessageSendRespVO . Message . class ) )
. setReceive ( BeanUtils . toBean ( assistantMessage , AiChatMessageSendRespVO . Message . class )
. setContent ( newContent ) . setSegments ( segments ) ) ;
. setContent ( newContent ) . setSegments ( segments )
. setWebSearchPages ( webSearchResponse ! = null ? webSearchResponse . getLists ( ) : null ) ) ;
}
@Override
@@ -166,30 +189,36 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model = modalService . validateModel ( conversation . getModelId ( ) ) ;
StreamingChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 知识库找回
// 2.1 知识库找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment (
sendReqVO . getContent ( ) , conversation ) ;
// 2.2 联网搜索
AiWebSearchResponse webSearchResponse = Boolean . TRUE . equals ( sendReqVO . getUseSearch ( ) ) & & webSearchClient ! = null ?
webSearchClient . search ( new AiWebSearchRequest ( ) . setQuery ( sendReqVO . getContent ( ) )
. setSummary ( true ) . setCount ( WEB_SEARCH_COUNT ) ) : null ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null , sendReqVO . getAttachmentUrls ( ) ) ;
null , sendReqVO . getAttachmentUrls ( ) , null );
// 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments , null ) ;
knowledgeSegments , null , webSearchResponse );
// 4.2 构建 Prompt, 并进行调用
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , webSearchResponse , model , sendReqVO ) ;
Flux < ChatResponse > streamResponse = chatModel . stream ( prompt ) ;
// 4.3 流式返回
StringBuffer contentBuffer = new StringBuffer ( ) ;
StringBuffer reasoningContentBuffer = new StringBuffer ( ) ;
return streamResponse . map ( chunk - > {
// 处理知识库的返回,只有首次才有
// 仅首次:返回知识库、联网搜索
List < AiChatMessageRespVO . KnowledgeSegment > segments = null ;
List < AiWebSearchResponse . WebPage > webSearchPages = null ;
if ( StrUtil . isEmpty ( contentBuffer ) ) {
Map < Long , AiKnowledgeDocumentDO > documentMap = TenantUtils . executeIgnore ( ( ) - >
knowledgeDocumentService . getKnowledgeDocumentMap (
@@ -198,6 +227,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiKnowledgeDocumentDO document = documentMap . get ( segment . getDocumentId ( ) ) ;
segment . setDocumentName ( document ! = null ? document . getName ( ) : null ) ;
} ) ;
if ( webSearchResponse ! = null ) {
webSearchPages = webSearchResponse . getLists ( ) ;
}
}
// 响应结果
String newContent = AiUtils . getChatResponseContent ( chunk ) ;
@@ -213,7 +245,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
. setReceive ( BeanUtils . toBean ( assistantMessage , AiChatMessageSendRespVO . Message . class )
. setContent ( StrUtil . nullToDefault ( newContent , " " ) ) // 避免 null 的 情况
. setReasoningContent ( StrUtil . nullToDefault ( newReasoningContent , " " ) ) // 避免 null 的 情况
. setSegments ( segments ) ) ) ; // 知识库返回
. setSegments ( segments ) . setWebSearchPages ( webSearchPages ) )) ; // 知识库 + 联网搜索
} ) . doOnComplete ( ( ) - > {
// 忽略租户,因为 Flux 异步无法透传租户
TenantUtils . executeIgnore ( ( ) - > chatMessageMapper . updateById (
@@ -239,7 +271,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
return Collections . emptyList ( ) ;
}
// 2. 遍历找 回
// 2. 遍历召 回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = new ArrayList < > ( ) ;
for ( Long knowledgeId : role . getKnowledgeIds ( ) ) {
knowledgeSegments . addAll ( knowledgeSegmentService . searchKnowledgeSegment ( new AiKnowledgeSegmentSearchReqBO ( )
@@ -250,6 +282,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private Prompt buildPrompt ( AiChatConversationDO conversation , List < AiChatMessageDO > messages ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ,
AiWebSearchResponse webSearchResponse ,
AiModelDO model , AiChatMessageSendReqVO sendReqVO ) {
List < Message > chatMessages = new ArrayList < > ( ) ;
// 1.1 System Context 角色设定
@@ -265,6 +298,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
if ( attachmentUserMessage ! = null ) {
chatMessages . add ( attachmentUserMessage ) ;
}
// TODO @芋艿:历史的知识库;历史的搜索,要不要拼接?
} ) ;
// 1.3 当前 user message 新发送消息
@@ -278,7 +312,20 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
chatMessages . add ( new UserMessage ( String . format ( KNOWLEDGE_USER_MESSAGE_TEMPLATE , reference ) ) ) ;
}
// 1.5 附件 ,通过 UserMessage 实现
// 1.5 联网搜索 ,通过 UserMessage 实现
if ( webSearchResponse ! = null & & CollUtil . isNotEmpty ( webSearchResponse . getLists ( ) ) ) {
String webSearch = webSearchResponse . getLists ( ) . stream ( )
. map ( page - > {
String summary = StrUtil . isNotEmpty ( page . getSummary ( ) ) ?
" \ nSummary: " + page . getSummary ( ) : " " ;
return " <WebSearch title= \" " + page . getTitle ( ) + " \" url= \" " + page . getUrl ( ) + " \" > "
+ StrUtil . blankToDefault ( page . getSummary ( ) , page . getSnippet ( ) ) + " </WebSearch> " ;
} )
. collect ( Collectors . joining ( " \ n \ n " ) ) ;
chatMessages . add ( new UserMessage ( String . format ( WEB_SEARCH_USER_MESSAGE_TEMPLATE , webSearch ) ) ) ;
}
// 1.6 附件,通过 UserMessage 实现
if ( CollUtil . isNotEmpty ( sendReqVO . getAttachmentUrls ( ) ) ) {
UserMessage attachmentUserMessage = buildAttachmentUserMessage ( sendReqVO . getAttachmentUrls ( ) ) ;
if ( attachmentUserMessage ! = null ) {
@@ -383,12 +430,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiModelDO model , Long userId , Long roleId ,
MessageType messageType , String content , Boolean useContext ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ,
List < String > attachmentUrls ) {
List < String > attachmentUrls ,
AiWebSearchResponse webSearchResponse ) {
AiChatMessageDO message = new AiChatMessageDO ( ) . setConversationId ( conversationId ) . setReplyId ( replyId )
. setModel ( model . getModel ( ) ) . setModelId ( model . getId ( ) ) . setUserId ( userId ) . setRoleId ( roleId )
. setType ( messageType . getValue ( ) ) . setContent ( content ) . setUseContext ( useContext )
. setSegmentIds ( convertList ( knowledgeSegments , AiKnowledgeSegmentSearchRespBO : : getId ) )
. setAttachmentUrls ( attachmentUrls ) ;
if ( webSearchResponse ! = null ) {
message . setWebSearchPages ( webSearchResponse . getLists ( ) ) ;
}
message . setCreateTime ( LocalDateTime . now ( ) ) ;
chatMessageMapper . insert ( message ) ;
return message ;