From c31b66b6cc675c9edb26ea03af0dbecb27edf792 Mon Sep 17 00:00:00 2001 From: YunaiV Date: Sun, 24 Aug 2025 09:35:03 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E3=80=90ai=20=E5=A4=A7=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E3=80=91RAG=20=E5=A2=9E=E5=8A=A0=20rerank=20=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../AiKnowledgeSegmentServiceImpl.java | 72 ++++++++++++++----- .../core/model/chat/TongYiChatModelTests.java | 37 ++++++++++ .../src/main/resources/application.yaml | 4 +- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java index e3a6f08a11..dd0f91315b 100644 --- a/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java +++ b/yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java @@ -18,6 +18,10 @@ import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO; import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO; import cn.iocoder.yudao.module.ai.service.model.AiModelService; +import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions; +import com.alibaba.cloud.ai.model.RerankModel; +import com.alibaba.cloud.ai.model.RerankRequest; +import com.alibaba.cloud.ai.model.RerankResponse; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.document.Document; @@ -27,6 +31,7 @@ import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; @@ -36,6 +41,7 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG; import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS; +import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL; /** * AI 知识库分片 Service 实现类 @@ -55,6 +61,11 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService VECTOR_STORE_METADATA_DOCUMENT_ID, String.class, VECTOR_STORE_METADATA_SEGMENT_ID, String.class); + /** + * Rerank 在向量检索时,检索数量 * 该系数,目的是为了提升 Rerank 的效果 + */ + private static final Integer RERANK_RETRIEVAL_FACTOR = 4; + @Resource private AiKnowledgeSegmentMapper segmentMapper; @@ -69,6 +80,9 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService @Resource private TokenCountEstimator tokenCountEstimator; + @Autowired(required = false) // 由于 spring.ai.model.rerank 配置项,可以关闭 RerankModel 的功能,所以这里只能不强制注入 + private RerankModel rerankModel; + @Override public PageResult getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) { return segmentMapper.selectPage(pageReqVO); @@ -211,28 +225,16 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService // 1. 校验 AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId()); - // 2.1 向量检索 - VectorStore vectorStore = getVectorStoreById(knowledge); - List documents = vectorStore.similaritySearch(SearchRequest.builder() - .query(reqBO.getContent()) - .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK())) - .similarityThreshold( - ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold())) - .filterExpression(new FilterExpressionBuilder() - .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString()) - .build()) - .build()); - if (CollUtil.isEmpty(documents)) { - return ListUtil.empty(); - } - // 2.2 段落召回 + // 2. 检索 + List documents = searchDocument(knowledge, reqBO); + + // 3.1 段落召回 List segments = segmentMapper .selectListByVectorIds(convertList(documents, Document::getId)); if (CollUtil.isEmpty(segments)) { return ListUtil.empty(); } - - // 3. 增加召回次数 + // 3.2 增加召回次数 segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId)); // 4. 构建结果 @@ -249,6 +251,42 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService return result; } + /** + * 基于 Embedding + Rerank Model,检索知识库中的文档 + * + * @param knowledge 知识库 + * @param reqBO 检索请求 + * @return 文档列表 + */ + private List searchDocument(AiKnowledgeDO knowledge, AiKnowledgeSegmentSearchReqBO reqBO) { + VectorStore vectorStore = getVectorStoreById(knowledge); + Integer topK = ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()); + Double similarityThreshold = ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold()); + + // 1. 向量检索 + int searchTopK = rerankModel != null ? topK * RERANK_RETRIEVAL_FACTOR : topK; + double searchSimilarityThreshold = rerankModel != null ? SIMILARITY_THRESHOLD_ACCEPT_ALL : similarityThreshold; + SearchRequest.Builder searchRequestBuilder = SearchRequest.builder() + .query(reqBO.getContent()) + .topK(searchTopK).similarityThreshold(searchSimilarityThreshold) + .filterExpression(new FilterExpressionBuilder() + .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString()).build()); + List documents = vectorStore.similaritySearch(searchRequestBuilder.build()); + if (CollUtil.isEmpty(documents)) { + return documents; + } + + // 2. Rerank 重排序 + if (rerankModel != null) { + RerankResponse rerankResponse = rerankModel.call(new RerankRequest(reqBO.getContent(), documents, + DashScopeRerankOptions.builder().withTopN(topK).build())); + documents = convertList(rerankResponse.getResults(), + documentWithScore -> documentWithScore.getScore() >= similarityThreshold + ? documentWithScore.getOutput() : null); + } + return documents; + } + @Override public List splitContent(String url, Integer segmentMaxTokens) { // 1. 读取 URL 内容 diff --git a/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/chat/TongYiChatModelTests.java b/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/chat/TongYiChatModelTests.java index 8a4544967f..7c62ec71b6 100644 --- a/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/chat/TongYiChatModelTests.java +++ b/yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/chat/TongYiChatModelTests.java @@ -1,8 +1,15 @@ package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat; +import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import com.alibaba.cloud.ai.dashscope.api.DashScopeApi; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions; +import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankModel; +import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions; +import com.alibaba.cloud.ai.model.RerankModel; +import com.alibaba.cloud.ai.model.RerankOptions; +import com.alibaba.cloud.ai.model.RerankRequest; +import com.alibaba.cloud.ai.model.RerankResponse; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.Message; @@ -10,11 +17,14 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import static java.util.Arrays.asList; + /** * {@link DashScopeChatModel} 集成测试类 * @@ -89,4 +99,31 @@ public class TongYiChatModelTests { }).then().block(); } + @Test + @Disabled + public void testRerank() { + // 准备环境 + RerankModel rerankModel = new DashScopeRerankModel( + DashScopeApi.builder() + .apiKey("sk-47aa124781be4bfb95244cc62f63f7d0") + .build()); + // 准备参数 + String query = "spring"; + Document document01 = new Document("abc"); + Document document02 = new Document("sapring"); + RerankOptions options = DashScopeRerankOptions.builder() + .withTopN(1) + .withModel("gte-rerank-v2") + .build(); + RerankRequest rerankRequest = new RerankRequest( + query, + asList(document01, document02), + options); + + // 调用 + RerankResponse call = rerankModel.call(rerankRequest); + // 打印结果 + System.out.println(JsonUtils.toJsonPrettyString(call)); + } + } diff --git a/yudao-server/src/main/resources/application.yaml b/yudao-server/src/main/resources/application.yaml index ca43def16c..2286b73159 100644 --- a/yudao-server/src/main/resources/application.yaml +++ b/yudao-server/src/main/resources/application.yaml @@ -184,7 +184,7 @@ spring: stabilityai: api-key: sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx dashscope: # 通义千问 - api-key: sk-71800982914041848008480000000000 + api-key: sk-47aa124781be4bfb95244cc62f6xxxx minimax: # Minimax:https://www.minimaxi.com/ api-key: xxxx moonshot: # 月之暗灭(KIMI) @@ -194,6 +194,8 @@ spring: chat: options: model: deepseek-chat + model: + rerank: dashscope # 是否开启“通义千问”的 Rerank 模型,填写 dashscope 开启 yudao: ai: