feat:【ai 大模型】RAG 增加 rerank 模型
This commit is contained in:
@@ -18,6 +18,10 @@ import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
|
|||||||
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
|
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
|
||||||
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
|
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankModel;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankRequest;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankResponse;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.document.Document;
|
import org.springframework.ai.document.Document;
|
||||||
@@ -27,6 +31,7 @@ import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
|||||||
import org.springframework.ai.vectorstore.SearchRequest;
|
import org.springframework.ai.vectorstore.SearchRequest;
|
||||||
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
|
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -36,6 +41,7 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
|
|||||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
|
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
|
||||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
|
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
|
||||||
|
import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 知识库分片 Service 实现类
|
* AI 知识库分片 Service 实现类
|
||||||
@@ -55,6 +61,11 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||||||
VECTOR_STORE_METADATA_DOCUMENT_ID, String.class,
|
VECTOR_STORE_METADATA_DOCUMENT_ID, String.class,
|
||||||
VECTOR_STORE_METADATA_SEGMENT_ID, String.class);
|
VECTOR_STORE_METADATA_SEGMENT_ID, String.class);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* When reranking, the vector search retrieval count is multiplied by this factor, in order to improve the Rerank result
|
||||||
|
*/
|
||||||
|
private static final Integer RERANK_RETRIEVAL_FACTOR = 4;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiKnowledgeSegmentMapper segmentMapper;
|
private AiKnowledgeSegmentMapper segmentMapper;
|
||||||
|
|
||||||
@@ -69,6 +80,9 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||||||
@Resource
|
@Resource
|
||||||
private TokenCountEstimator tokenCountEstimator;
|
private TokenCountEstimator tokenCountEstimator;
|
||||||
|
|
||||||
|
@Autowired(required = false) // Injection is optional because the spring.ai.model.rerank configuration property can turn the RerankModel feature off
|
||||||
|
private RerankModel rerankModel;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
|
public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
|
||||||
return segmentMapper.selectPage(pageReqVO);
|
return segmentMapper.selectPage(pageReqVO);
|
||||||
@@ -211,28 +225,16 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||||||
// 1. 校验
|
// 1. 校验
|
||||||
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
|
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
|
||||||
|
|
||||||
// 2.1 向量检索
|
// 2. 检索
|
||||||
VectorStore vectorStore = getVectorStoreById(knowledge);
|
List<Document> documents = searchDocument(knowledge, reqBO);
|
||||||
List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
|
|
||||||
.query(reqBO.getContent())
|
// 3.1 段落召回
|
||||||
.topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
|
|
||||||
.similarityThreshold(
|
|
||||||
ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold()))
|
|
||||||
.filterExpression(new FilterExpressionBuilder()
|
|
||||||
.eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString())
|
|
||||||
.build())
|
|
||||||
.build());
|
|
||||||
if (CollUtil.isEmpty(documents)) {
|
|
||||||
return ListUtil.empty();
|
|
||||||
}
|
|
||||||
// 2.2 段落召回
|
|
||||||
List<AiKnowledgeSegmentDO> segments = segmentMapper
|
List<AiKnowledgeSegmentDO> segments = segmentMapper
|
||||||
.selectListByVectorIds(convertList(documents, Document::getId));
|
.selectListByVectorIds(convertList(documents, Document::getId));
|
||||||
if (CollUtil.isEmpty(segments)) {
|
if (CollUtil.isEmpty(segments)) {
|
||||||
return ListUtil.empty();
|
return ListUtil.empty();
|
||||||
}
|
}
|
||||||
|
// 3.2 增加召回次数
|
||||||
// 3. 增加召回次数
|
|
||||||
segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
|
segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
|
||||||
|
|
||||||
// 4. 构建结果
|
// 4. 构建结果
|
||||||
@@ -249,6 +251,42 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基于 Embedding + Rerank Model,检索知识库中的文档
|
||||||
|
*
|
||||||
|
* @param knowledge 知识库
|
||||||
|
* @param reqBO 检索请求
|
||||||
|
* @return 文档列表
|
||||||
|
*/
|
||||||
|
private List<Document> searchDocument(AiKnowledgeDO knowledge, AiKnowledgeSegmentSearchReqBO reqBO) {
|
||||||
|
VectorStore vectorStore = getVectorStoreById(knowledge);
|
||||||
|
Integer topK = ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK());
|
||||||
|
Double similarityThreshold = ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold());
|
||||||
|
|
||||||
|
// 1. 向量检索
|
||||||
|
int searchTopK = rerankModel != null ? topK * RERANK_RETRIEVAL_FACTOR : topK;
|
||||||
|
double searchSimilarityThreshold = rerankModel != null ? SIMILARITY_THRESHOLD_ACCEPT_ALL : similarityThreshold;
|
||||||
|
SearchRequest.Builder searchRequestBuilder = SearchRequest.builder()
|
||||||
|
.query(reqBO.getContent())
|
||||||
|
.topK(searchTopK).similarityThreshold(searchSimilarityThreshold)
|
||||||
|
.filterExpression(new FilterExpressionBuilder()
|
||||||
|
.eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString()).build());
|
||||||
|
List<Document> documents = vectorStore.similaritySearch(searchRequestBuilder.build());
|
||||||
|
if (CollUtil.isEmpty(documents)) {
|
||||||
|
return documents;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Rerank 重排序
|
||||||
|
if (rerankModel != null) {
|
||||||
|
RerankResponse rerankResponse = rerankModel.call(new RerankRequest(reqBO.getContent(), documents,
|
||||||
|
DashScopeRerankOptions.builder().withTopN(topK).build()));
|
||||||
|
documents = convertList(rerankResponse.getResults(),
|
||||||
|
documentWithScore -> documentWithScore.getScore() >= similarityThreshold
|
||||||
|
? documentWithScore.getOutput() : null);
|
||||||
|
}
|
||||||
|
return documents;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<AiKnowledgeSegmentDO> splitContent(String url, Integer segmentMaxTokens) {
|
public List<AiKnowledgeSegmentDO> splitContent(String url, Integer segmentMaxTokens) {
|
||||||
// 1. 读取 URL 内容
|
// 1. 读取 URL 内容
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||||
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
||||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankModel;
|
||||||
|
import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankModel;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankOptions;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankRequest;
|
||||||
|
import com.alibaba.cloud.ai.model.RerankResponse;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.springframework.ai.chat.messages.Message;
|
import org.springframework.ai.chat.messages.Message;
|
||||||
@@ -10,11 +17,14 @@ import org.springframework.ai.chat.messages.SystemMessage;
|
|||||||
import org.springframework.ai.chat.messages.UserMessage;
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import org.springframework.ai.document.Document;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.util.Arrays.asList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link DashScopeChatModel} 集成测试类
|
* {@link DashScopeChatModel} 集成测试类
|
||||||
*
|
*
|
||||||
@@ -89,4 +99,31 @@ public class TongYiChatModelTests {
|
|||||||
}).then().block();
|
}).then().block();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Disabled
|
||||||
|
public void testRerank() {
|
||||||
|
// 准备环境
|
||||||
|
RerankModel rerankModel = new DashScopeRerankModel(
|
||||||
|
DashScopeApi.builder()
|
||||||
|
.apiKey("sk-47aa124781be4bfb95244cc62f63f7d0")
|
||||||
|
.build());
|
||||||
|
// 准备参数
|
||||||
|
String query = "spring";
|
||||||
|
Document document01 = new Document("abc");
|
||||||
|
Document document02 = new Document("sapring");
|
||||||
|
RerankOptions options = DashScopeRerankOptions.builder()
|
||||||
|
.withTopN(1)
|
||||||
|
.withModel("gte-rerank-v2")
|
||||||
|
.build();
|
||||||
|
RerankRequest rerankRequest = new RerankRequest(
|
||||||
|
query,
|
||||||
|
asList(document01, document02),
|
||||||
|
options);
|
||||||
|
|
||||||
|
// 调用
|
||||||
|
RerankResponse call = rerankModel.call(rerankRequest);
|
||||||
|
// 打印结果
|
||||||
|
System.out.println(JsonUtils.toJsonPrettyString(call));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ spring:
|
|||||||
stabilityai:
|
stabilityai:
|
||||||
api-key: sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx
|
api-key: sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx
|
||||||
dashscope: # 通义千问
|
dashscope: # 通义千问
|
||||||
api-key: sk-71800982914041848008480000000000
|
api-key: sk-47aa124781be4bfb95244cc62f6xxxx
|
||||||
minimax: # Minimax:https://www.minimaxi.com/
|
minimax: # Minimax:https://www.minimaxi.com/
|
||||||
api-key: xxxx
|
api-key: xxxx
|
||||||
moonshot: # 月之暗面(KIMI)
|
moonshot: # 月之暗面(KIMI)
|
||||||
@@ -194,6 +194,8 @@ spring:
|
|||||||
chat:
|
chat:
|
||||||
options:
|
options:
|
||||||
model: deepseek-chat
|
model: deepseek-chat
|
||||||
|
model:
|
||||||
|
rerank: dashscope # 是否开启“通义千问”的 Rerank 模型,填写 dashscope 开启
|
||||||
|
|
||||||
yudao:
|
yudao:
|
||||||
ai:
|
ai:
|
||||||
|
|||||||
Reference in New Issue
Block a user