feat: 热点改进
This commit is contained in:
@@ -84,6 +84,55 @@ public class DifyClient {
|
||||
.doOnError(e -> log.error("[chatStream] Dify 流式响应错误", e));
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用 Dify 聊天流式 API(支持自定义 inputs)
|
||||
*
|
||||
* @param apiKey Dify API Key
|
||||
* @param inputs 自定义输入参数
|
||||
* @param content 用户输入
|
||||
* @param conversationId 会话ID(可选)
|
||||
* @return 流式响应
|
||||
*/
|
||||
public Flux<DifyChatRespVO> chatStreamWithInputs(String apiKey, Map<String, Object> inputs, String content, String conversationId) {
|
||||
String apiUrl = difyProperties.getApiUrl() + "/v1/chat-messages";
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("inputs", inputs);
|
||||
requestBody.put("query", content);
|
||||
requestBody.put("response_mode", "streaming");
|
||||
requestBody.put("conversation_id", conversationId != null ? conversationId : "");
|
||||
requestBody.put("user", "user-" + System.currentTimeMillis());
|
||||
|
||||
AtomicReference<String> responseConversationId = new AtomicReference<>(conversationId);
|
||||
StringBuilder fullContent = new StringBuilder();
|
||||
|
||||
WebClient webClient = WebClient.builder()
|
||||
.baseUrl(apiUrl)
|
||||
.defaultHeader("Authorization", "Bearer " + apiKey)
|
||||
.defaultHeader("Content-Type", "application/json")
|
||||
.build();
|
||||
|
||||
return webClient.post()
|
||||
.bodyValue(requestBody)
|
||||
.accept(MediaType.TEXT_EVENT_STREAM)
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class)
|
||||
.flatMap(event -> Mono.justOrEmpty(parseSSEEvent(event)))
|
||||
.doOnNext(resp -> {
|
||||
if (resp.getConversationId() != null) {
|
||||
responseConversationId.set(resp.getConversationId());
|
||||
}
|
||||
if (resp.getContent() != null) {
|
||||
fullContent.append(resp.getContent());
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
log.info("[chatStreamWithInputs] Dify 流式响应完成,会话ID: {}, 内容长度: {}",
|
||||
responseConversationId.get(), fullContent.length());
|
||||
})
|
||||
.doOnError(e -> log.error("[chatStreamWithInputs] Dify 流式响应错误", e));
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析 SSE 事件
|
||||
*/
|
||||
|
||||
@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.tik.dify.service.DifyService;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
@@ -39,4 +40,14 @@ public class AppDifyController {
|
||||
.map(CommonResult::success);
|
||||
}
|
||||
|
||||
@PostMapping(value = "/forecast/rewrite", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
@Operation(summary = "Forecast 文案改写(流式)")
|
||||
public Flux<CommonResult<DifyChatRespVO>> rewriteStream(@Valid @RequestBody ForecastRewriteReqVO reqVO) {
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
String userId = loginUserId != null ? loginUserId.toString() : "1"; // 默认用户ID
|
||||
|
||||
return difyService.rewriteStream(reqVO, userId)
|
||||
.map(CommonResult::success);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.tik.dify.service;
|
||||
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
@@ -20,4 +21,13 @@ public interface DifyService {
|
||||
*/
|
||||
Flux<DifyChatRespVO> chatStream(DifyChatReqVO reqVO, String userId);
|
||||
|
||||
/**
|
||||
* Forecast 文案改写(流式)
|
||||
*
|
||||
* @param reqVO 请求参数
|
||||
* @param userId 用户ID
|
||||
* @return 流式响应
|
||||
*/
|
||||
Flux<DifyChatRespVO> rewriteStream(ForecastRewriteReqVO reqVO, String userId);
|
||||
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.tik.dify.service;
|
||||
import cn.iocoder.yudao.module.tik.dify.client.DifyClient;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatReqVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.DifyChatRespVO;
|
||||
import cn.iocoder.yudao.module.tik.dify.vo.ForecastRewriteReqVO;
|
||||
import cn.iocoder.yudao.module.tik.enums.AiModelTypeEnum;
|
||||
import cn.iocoder.yudao.module.tik.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.module.tik.muye.aiagent.dal.AiAgentDO;
|
||||
@@ -16,6 +17,8 @@ import org.springframework.validation.annotation.Validated;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
@@ -133,9 +136,114 @@ public class DifyServiceImpl implements DifyService {
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<DifyChatRespVO> rewriteStream(ForecastRewriteReqVO reqVO, String userId) {
|
||||
// 用于存储预扣记录ID
|
||||
AtomicLong pendingRecordId = new AtomicLong();
|
||||
// 用于存储会话ID
|
||||
AtomicReference<String> conversationIdRef = new AtomicReference<>("");
|
||||
|
||||
return Mono.fromCallable(() -> {
|
||||
// 1. 获取智能体配置(通过 agentId 获取 systemPrompt)
|
||||
AiAgentDO agent = aiAgentService.getAiAgent(reqVO.getAgentId());
|
||||
if (agent == null) {
|
||||
throw new RuntimeException("智能体不存在");
|
||||
}
|
||||
|
||||
// 2. 根据 modelType 获取对应的积分配置
|
||||
AiModelTypeEnum modelTypeEnum = "forecast_meiju".equals(reqVO.getModelType())
|
||||
? AiModelTypeEnum.FORECAST_MEIJU
|
||||
: AiModelTypeEnum.FORECAST_STANDARD;
|
||||
AiModelConfigDO config = pointsService.getConfig(
|
||||
AiPlatformEnum.DIFY.getPlatform(),
|
||||
modelTypeEnum.getModelCode());
|
||||
|
||||
// 3. 预检积分
|
||||
pointsService.checkPoints(userId, config.getConsumePoints());
|
||||
|
||||
// 4. 创建预扣记录
|
||||
Long recordId = pointsService.createPendingDeduct(
|
||||
userId,
|
||||
config.getConsumePoints(),
|
||||
"forecast_rewrite",
|
||||
reqVO.getModelType()
|
||||
);
|
||||
pendingRecordId.set(recordId);
|
||||
|
||||
// 5. 构建 inputs 参数(使用 agent 的 systemPrompt)
|
||||
Map<String, Object> inputs = new HashMap<>();
|
||||
inputs.put("sysPrompt", agent.getSystemPrompt());
|
||||
inputs.put("userText", reqVO.getUserText());
|
||||
inputs.put("level", reqVO.getLevel());
|
||||
|
||||
// 6. 返回调用参数
|
||||
return new ForecastRewriteContext(inputs, config.getApiKey(), config.getConsumePoints());
|
||||
})
|
||||
.flatMapMany(context -> {
|
||||
// 7. 调用 Dify 流式 API
|
||||
return difyClient.chatStreamWithInputs(
|
||||
context.apiKey(),
|
||||
context.inputs(),
|
||||
reqVO.getUserText(),
|
||||
null
|
||||
)
|
||||
.doOnNext(resp -> {
|
||||
if (resp.getConversationId() != null) {
|
||||
conversationIdRef.set(resp.getConversationId());
|
||||
}
|
||||
})
|
||||
// 8. 流结束时确认扣费
|
||||
.doOnComplete(() -> {
|
||||
if (pendingRecordId.get() > 0) {
|
||||
try {
|
||||
pointsService.confirmPendingDeduct(pendingRecordId.get());
|
||||
log.info("[rewriteStream] 流结束,确认扣费,记录ID: {}", pendingRecordId.get());
|
||||
} catch (Exception e) {
|
||||
log.error("[rewriteStream] 确认扣费失败", e);
|
||||
}
|
||||
}
|
||||
})
|
||||
// 9. 流出错时取消预扣
|
||||
.doOnError(e -> {
|
||||
if (pendingRecordId.get() > 0) {
|
||||
try {
|
||||
pointsService.cancelPendingDeduct(pendingRecordId.get());
|
||||
log.info("[rewriteStream] 流出错,取消预扣,记录ID: {}", pendingRecordId.get());
|
||||
} catch (Exception ex) {
|
||||
log.error("[rewriteStream] 取消预扣失败", ex);
|
||||
}
|
||||
}
|
||||
})
|
||||
// 10. 用户取消时确认扣费(已消费的部分)
|
||||
.doOnCancel(() -> {
|
||||
if (pendingRecordId.get() > 0) {
|
||||
try {
|
||||
pointsService.confirmPendingDeduct(pendingRecordId.get());
|
||||
log.info("[rewriteStream] 用户取消,确认扣费,记录ID: {}", pendingRecordId.get());
|
||||
} catch (Exception e) {
|
||||
log.error("[rewriteStream] 用户取消后扣费失败", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
})
|
||||
// 11. 在最后添加 done 事件
|
||||
.concatWith(Mono.defer(() -> {
|
||||
return Mono.just(DifyChatRespVO.done(conversationIdRef.get(), null));
|
||||
}))
|
||||
.onErrorResume(e -> {
|
||||
log.error("[rewriteStream] Forecast 文案改写异常", e);
|
||||
return Flux.just(DifyChatRespVO.error(e.getMessage()));
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Dify 聊天上下文
|
||||
*/
|
||||
private record DifyChatContext(String systemPrompt, String apiKey, Integer consumePoints) {}
|
||||
|
||||
/**
|
||||
* Forecast 改写上下文
|
||||
*/
|
||||
private record ForecastRewriteContext(Map<String, Object> inputs, String apiKey, Integer consumePoints) {}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package cn.iocoder.yudao.module.tik.dify.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Forecast 文案改写请求 VO
|
||||
*/
|
||||
@Schema(description = "Forecast 文案改写请求")
|
||||
@Data
|
||||
public class ForecastRewriteReqVO {
|
||||
|
||||
@Schema(description = "智能体ID", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@NotNull(message = "智能体ID不能为空")
|
||||
private Long agentId;
|
||||
|
||||
@Schema(description = "用户输入文案", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@NotEmpty(message = "用户文案不能为空")
|
||||
private String userText;
|
||||
|
||||
@Schema(description = "改写级别/强度", example = "50")
|
||||
private Integer level = 50;
|
||||
|
||||
@Schema(description = "模型类型:forecast_standard-标准版 forecast_meiju-美剧版", example = "forecast_standard")
|
||||
private String modelType = "forecast_standard";
|
||||
|
||||
}
|
||||
@@ -20,6 +20,10 @@ public enum AiModelTypeEnum implements ArrayValuable<String> {
|
||||
DIFY_WRITING_PRO("writing_pro", "Pro深度版", AiPlatformEnum.DIFY, "text"),
|
||||
DIFY_WRITING_STANDARD("writing_standard", "标准版", AiPlatformEnum.DIFY, "text"),
|
||||
|
||||
// ========== Forecast 文案改写 ==========
|
||||
FORECAST_STANDARD("forecast_standard", "文案改写-标准版", AiPlatformEnum.DIFY, "text"),
|
||||
FORECAST_MEIJU("forecast_meiju", "文案改写-Pro版", AiPlatformEnum.DIFY, "text"),
|
||||
|
||||
// ========== 数字人模型 ==========
|
||||
DIGITAL_HUMAN_LATENTSYNC("latentsync", "LatentSync", AiPlatformEnum.DIGITAL_HUMAN, "video"),
|
||||
DIGITAL_HUMAN_KLING("kling", "可灵", AiPlatformEnum.DIGITAL_HUMAN, "video"),
|
||||
|
||||
Reference in New Issue
Block a user