feat: 热点改进

This commit is contained in:
2026-02-25 01:24:57 +08:00
parent cd5a2f0c7d
commit 7285534405
8 changed files with 537 additions and 122 deletions

View File

@@ -84,6 +84,55 @@ public class DifyClient {
.doOnError(e -> log.error("[chatStream] Dify 流式响应错误", e));
}
/**
* 调用 Dify 聊天流式 API支持自定义 inputs
*
* @param apiKey Dify API Key
* @param inputs 自定义输入参数
* @param content 用户输入
* @param conversationId 会话ID可选
* @return 流式响应
*/
public Flux<DifyChatRespVO> chatStreamWithInputs(String apiKey, Map<String, Object> inputs, String content, String conversationId) {
String apiUrl = difyProperties.getApiUrl() + "/v1/chat-messages";
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("inputs", inputs);
requestBody.put("query", content);
requestBody.put("response_mode", "streaming");
requestBody.put("conversation_id", conversationId != null ? conversationId : "");
requestBody.put("user", "user-" + System.currentTimeMillis());
AtomicReference<String> responseConversationId = new AtomicReference<>(conversationId);
StringBuilder fullContent = new StringBuilder();
WebClient webClient = WebClient.builder()
.baseUrl(apiUrl)
.defaultHeader("Authorization", "Bearer " + apiKey)
.defaultHeader("Content-Type", "application/json")
.build();
return webClient.post()
.bodyValue(requestBody)
.accept(MediaType.TEXT_EVENT_STREAM)
.retrieve()
.bodyToFlux(String.class)
.flatMap(event -> Mono.justOrEmpty(parseSSEEvent(event)))
.doOnNext(resp -> {
if (resp.getConversationId() != null) {
responseConversationId.set(resp.getConversationId());
}
if (resp.getContent() != null) {
fullContent.append(resp.getContent());
}
})
.doOnComplete(() -> {
log.info("[chatStreamWithInputs] Dify 流式响应完成会话ID: {}, 内容长度: {}",
responseConversationId.get(), fullContent.length());
})
.doOnError(e -> log.error("[chatStreamWithInputs] Dify 流式响应错误", e));
}
/**
* 解析 SSE 事件
*/

View File

@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.tik.dify.service.DifyService;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
@@ -39,4 +40,14 @@ public class AppDifyController {
.map(CommonResult::success);
}
@PostMapping(value = "/forecast/rewrite", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@Operation(summary = "Forecast 文案改写(流式)")
public Flux<CommonResult<DifyChatRespVO>> rewriteStream(@Valid @RequestBody ForecastRewriteReqVO reqVO) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
String userId = loginUserId != null ? loginUserId.toString() : "1"; // 默认用户ID
return difyService.rewriteStream(reqVO, userId)
.map(CommonResult::success);
}
}

View File

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.tik.dify.service;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
import reactor.core.publisher.Flux;
/**
@@ -20,4 +21,13 @@ public interface DifyService {
*/
Flux<DifyChatRespVO> chatStream(DifyChatReqVO reqVO, String userId);
/**
* Forecast 文案改写(流式)
*
* @param reqVO 请求参数
* @param userId 用户ID
* @return 流式响应
*/
Flux<DifyChatRespVO> rewriteStream(ForecastRewriteReqVO reqVO, String userId);
}

View File

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.tik.dify.service;
import cn.iocoder.yudao.module.tik.dify.client.DifyClient;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
import cn.iocoder.yudao.module.tik.enums.AiModelTypeEnum;
import cn.iocoder.yudao.module.tik.enums.AiPlatformEnum;
import cn.iocoder.yudao.module.tik.muye.aiagent.dal.AiAgentDO;
@@ -16,6 +17,8 @@ import org.springframework.validation.annotation.Validated;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
@@ -133,9 +136,114 @@ public class DifyServiceImpl implements DifyService {
});
}
@Override
public Flux<DifyChatRespVO> rewriteStream(ForecastRewriteReqVO reqVO, String userId) {
// 用于存储预扣记录ID
AtomicLong pendingRecordId = new AtomicLong();
// 用于存储会话ID
AtomicReference<String> conversationIdRef = new AtomicReference<>("");
return Mono.fromCallable(() -> {
// 1. 获取智能体配置(通过 agentId 获取 systemPrompt
AiAgentDO agent = aiAgentService.getAiAgent(reqVO.getAgentId());
if (agent == null) {
throw new RuntimeException("智能体不存在");
}
// 2. 根据 modelType 获取对应的积分配置
AiModelTypeEnum modelTypeEnum = "forecast_meiju".equals(reqVO.getModelType())
? AiModelTypeEnum.FORECAST_MEIJU
: AiModelTypeEnum.FORECAST_STANDARD;
AiModelConfigDO config = pointsService.getConfig(
AiPlatformEnum.DIFY.getPlatform(),
modelTypeEnum.getModelCode());
// 3. 预检积分
pointsService.checkPoints(userId, config.getConsumePoints());
// 4. 创建预扣记录
Long recordId = pointsService.createPendingDeduct(
userId,
config.getConsumePoints(),
"forecast_rewrite",
reqVO.getModelType()
);
pendingRecordId.set(recordId);
// 5. 构建 inputs 参数(使用 agent 的 systemPrompt
Map<String, Object> inputs = new HashMap<>();
inputs.put("sysPrompt", agent.getSystemPrompt());
inputs.put("userText", reqVO.getUserText());
inputs.put("level", reqVO.getLevel());
// 6. 返回调用参数
return new ForecastRewriteContext(inputs, config.getApiKey(), config.getConsumePoints());
})
.flatMapMany(context -> {
// 7. 调用 Dify 流式 API
return difyClient.chatStreamWithInputs(
context.apiKey(),
context.inputs(),
reqVO.getUserText(),
null
)
.doOnNext(resp -> {
if (resp.getConversationId() != null) {
conversationIdRef.set(resp.getConversationId());
}
})
// 8. 流结束时确认扣费
.doOnComplete(() -> {
if (pendingRecordId.get() > 0) {
try {
pointsService.confirmPendingDeduct(pendingRecordId.get());
log.info("[rewriteStream] 流结束确认扣费记录ID: {}", pendingRecordId.get());
} catch (Exception e) {
log.error("[rewriteStream] 确认扣费失败", e);
}
}
})
// 9. 流出错时取消预扣
.doOnError(e -> {
if (pendingRecordId.get() > 0) {
try {
pointsService.cancelPendingDeduct(pendingRecordId.get());
log.info("[rewriteStream] 流出错取消预扣记录ID: {}", pendingRecordId.get());
} catch (Exception ex) {
log.error("[rewriteStream] 取消预扣失败", ex);
}
}
})
// 10. 用户取消时确认扣费(已消费的部分)
.doOnCancel(() -> {
if (pendingRecordId.get() > 0) {
try {
pointsService.confirmPendingDeduct(pendingRecordId.get());
log.info("[rewriteStream] 用户取消确认扣费记录ID: {}", pendingRecordId.get());
} catch (Exception e) {
log.error("[rewriteStream] 用户取消后扣费失败", e);
}
}
});
})
// 11. 在最后添加 done 事件
.concatWith(Mono.defer(() -> {
return Mono.just(DifyChatRespVO.done(conversationIdRef.get(), null));
}))
.onErrorResume(e -> {
log.error("[rewriteStream] Forecast 文案改写异常", e);
return Flux.just(DifyChatRespVO.error(e.getMessage()));
});
}
/**
* Dify 聊天上下文
*/
private record DifyChatContext(String systemPrompt, String apiKey, Integer consumePoints) {}
/**
* Forecast 改写上下文
*/
private record ForecastRewriteContext(Map<String, Object> inputs, String apiKey, Integer consumePoints) {}
}

View File

@@ -0,0 +1,29 @@
package cn.iocoder.yudao.module.tik.dify.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
/**
* Forecast 文案改写请求 VO
*/
@Schema(description = "Forecast 文案改写请求")
@Data
public class ForecastRewriteReqVO {
@Schema(description = "智能体ID", requiredMode = Schema.RequiredMode.REQUIRED)
@NotNull(message = "智能体ID不能为空")
private Long agentId;
@Schema(description = "用户输入文案", requiredMode = Schema.RequiredMode.REQUIRED)
@NotEmpty(message = "用户文案不能为空")
private String userText;
@Schema(description = "改写级别/强度", example = "50")
private Integer level = 50;
@Schema(description = "模型类型forecast_standard-标准版 forecast_meiju-美剧版", example = "forecast_standard")
private String modelType = "forecast_standard";
}

View File

@@ -20,6 +20,10 @@ public enum AiModelTypeEnum implements ArrayValuable<String> {
DIFY_WRITING_PRO("writing_pro", "Pro深度版", AiPlatformEnum.DIFY, "text"),
DIFY_WRITING_STANDARD("writing_standard", "标准版", AiPlatformEnum.DIFY, "text"),
// ========== Forecast 文案改写 ==========
FORECAST_STANDARD("forecast_standard", "文案改写-标准版", AiPlatformEnum.DIFY, "text"),
FORECAST_MEIJU("forecast_meiju", "文案改写-Pro版", AiPlatformEnum.DIFY, "text"),
// ========== 数字人模型 ==========
DIGITAL_HUMAN_LATENTSYNC("latentsync", "LatentSync", AiPlatformEnum.DIGITAL_HUMAN, "video"),
DIGITAL_HUMAN_KLING("kling", "可灵", AiPlatformEnum.DIGITAL_HUMAN, "video"),