diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgent.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgent.java
index bc852556..b89e7428 100644
--- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgent.java
+++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgent.java
@@ -19,13 +19,17 @@
 import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi;
 import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi.DashScopeAgentRequest;
 import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi.DashScopeAgentResponse;
+import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi.DashScopeAgentRequest.DashScopeAgentRequestInput.DashScopeAgentRequestMessage;
+import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi.DashScopeAgentRequest.DashScopeAgentRequestParameters.DashScopeAgentRequestRagOptions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.http.ResponseEntity;
 import reactor.core.publisher.Flux;
 import reactor.core.scheduler.Schedulers;
@@ -63,6 +67,7 @@ public DashScopeAgent(DashScopeAgentApi dashScopeAgentApi) {
 			.withMemoryId(null)
 			.withIncrementalOutput(false)
 			.withHasThoughts(false)
+			.withImages(null)
 			.withBizParams(null)
 			.build();
 	}
@@ -98,24 +103,32 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 	}
 
 	private DashScopeAgentRequest toRequest(Prompt prompt, Boolean stream) {
-		if (prompt == null || prompt.getOptions() == null) {
-			throw new IllegalArgumentException("option is null");
+		if (prompt == null) {
+			throw new IllegalArgumentException("prompt is null");
 		}
 
-		String appId = null;
-		if (prompt.getOptions() instanceof DashScopeAgentOptions options) {
-			appId = options.getAppId();
-		}
+		DashScopeAgentOptions runtimeOptions = mergeOptions(prompt.getOptions());
+		String appId = runtimeOptions.getAppId();
 
 		if (appId == null) {
 			throw new IllegalArgumentException("appId must be set");
 		}
 
+		DashScopeAgentRagOptions ragOptions = runtimeOptions.getRagOptions();
 		return new DashScopeAgentRequest(appId,
-				new DashScopeAgentRequest.DashScopeAgentRequestInput(prompt.getContents(), this.options.getSessionId(),
-						this.options.getMemoryId(), this.options.getBizParams()),
-				new DashScopeAgentRequest.DashScopeAgentRequestParameters(this.options.getHasThoughts(),
-						stream && this.options.getIncrementalOutput()));
+				new DashScopeAgentRequest.DashScopeAgentRequestInput(null, prompt.getInstructions()
+					.stream()
+					.map(msg -> new DashScopeAgentRequestMessage(msg.getMessageType().getValue(), msg.getText()))
+					.toList(), runtimeOptions.getSessionId(), runtimeOptions.getMemoryId(), runtimeOptions.getImages(),
+						runtimeOptions.getBizParams()),
+				new DashScopeAgentRequest.DashScopeAgentRequestParameters(runtimeOptions.getHasThoughts(),
+						// Boxed flags may be null: compare against TRUE instead of auto-unboxing,
+						// which would throw a NullPointerException for an unset option.
+						Boolean.TRUE.equals(stream) && Boolean.TRUE.equals(runtimeOptions.getIncrementalOutput()),
+						ragOptions == null ? null
+								: new DashScopeAgentRequestRagOptions(ragOptions.getPipelineIds(),
+										ragOptions.getFileIds(), ragOptions.getMetadataFilter(), ragOptions.getTags(),
+										ragOptions.getStructuredFilter(), ragOptions.getSessionFileIds())));
 	}
 
 	private ChatResponse toChatResponse(DashScopeAgentResponse response) {
@@ -139,4 +152,16 @@ private ChatResponse toChatResponse(DashScopeAgentResponse response) {
 		return new ChatResponse(List.of(generation));
 	}
 
+	/**
+	 * Builds the options for a single request by merging the options carried by the
+	 * prompt with the options this agent was created with.
+	 * @param chatOptions options taken from the prompt
+	 * @return the merged options used to build the request
+	 */
+	private DashScopeAgentOptions mergeOptions(ChatOptions chatOptions) {
+		DashScopeAgentOptions agentOptions = ModelOptionsUtils.copyToTarget(chatOptions, ChatOptions.class,
+				DashScopeAgentOptions.class);
+		return ModelOptionsUtils.merge(agentOptions, this.options, DashScopeAgentOptions.class);
+	}
+
 }
diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentOptions.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentOptions.java
index 5b627197..7921a282 100644
--- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentOptions.java
+++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentOptions.java
@@ -17,6 +17,7 @@
 
 import java.util.List;
 
+import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.databind.JsonNode;
 import org.springframework.ai.chat.prompt.ChatOptions;
 
@@ -29,18 +30,30 @@
 
 public class DashScopeAgentOptions implements ChatOptions {
 
+	@JsonProperty("app_id")
 	private String appId;
 
+	@JsonProperty("session_id")
 	private String sessionId;
 
+	@JsonProperty("memory_id")
 	private String memoryId;
 
+	@JsonProperty("incremental_output")
 	private Boolean incrementalOutput;
 
+	@JsonProperty("has_thoughts")
 	private Boolean hasThoughts;
 
+	@JsonProperty("images")
+	private List<String> images;
+
+	@JsonProperty("biz_params")
 	private JsonNode bizParams;
 
+	@JsonProperty("rag_options")
+	private DashScopeAgentRagOptions ragOptions;
+
 	@Override
 	public String getModel() {
 		return null;
@@ -121,6 +134,14 @@ public void setHasThoughts(Boolean hasThoughts) {
 		this.hasThoughts = hasThoughts;
 	}
 
+	public List<String> getImages() {
+		return images;
+	}
+
+	public void setImages(List<String> images) {
+		this.images = images;
+	}
+
 	public JsonNode getBizParams() {
 		return bizParams;
 	}
@@ -129,6 +150,14 @@ public void setBizParams(JsonNode bizParams) {
 		this.bizParams = bizParams;
 	}
 
+	public DashScopeAgentRagOptions getRagOptions() {
+		return ragOptions;
+	}
+
+	public void setRagOptions(DashScopeAgentRagOptions ragOptions) {
+		this.ragOptions = ragOptions;
+	}
+
 	@Override
 	public ChatOptions copy() {
 		return DashScopeAgentOptions.fromOptions(this);
@@ -186,11 +215,21 @@ public Builder withHasThoughts(Boolean hasThoughts) {
 			return this;
 		}
 
+		public Builder withImages(List<String> images) {
+			this.options.images = images;
+			return this;
+		}
+
 		public Builder withBizParams(JsonNode bizParams) {
 			this.options.bizParams = bizParams;
 			return this;
 		}
 
+		public Builder withRagOptions(DashScopeAgentRagOptions ragOptions) {
+			this.options.ragOptions = ragOptions;
+			return this;
+		}
+
 		public DashScopeAgentOptions build() {
 			return this.options;
 		}
@@ -205,6 +244,8 @@ public String toString() {
 		sb.append(", memoryId='").append(memoryId).append('\'');
 		sb.append(", incrementalOutput=").append(incrementalOutput);
 		sb.append(", hasThoughts=").append(hasThoughts);
+		sb.append(", images=").append(images);
 		sb.append(", bizParams=").append(bizParams);
+		sb.append(", ragOptions=").append(ragOptions);
 		sb.append('}');
 		return sb.toString();
diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentRagOptions.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentRagOptions.java
new file mode 100644
index 00000000..01ed038a
--- /dev/null
+++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentRagOptions.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.alibaba.cloud.ai.dashscope.agent;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.JsonNode;
+
+import java.util.List;
+
+/**
+ * Knowledge-base retrieval options of a DashScope agent call, bound to the
+ * {@code rag_options} request property.
+ *
+ * @author kevinlin09
+ */
+public class DashScopeAgentRagOptions {
+
+	/** knowledge base ids */
+	@JsonProperty("pipeline_ids")
+	private List<String> pipelineIds;
+
+	/** file ids of knowledge base */
+	@JsonProperty("file_ids")
+	private List<String> fileIds;
+
+	/** tags of knowledge base */
+	@JsonProperty("tags")
+	private List<String> tags;
+
+	/** metadata filter of knowledge base query */
+	@JsonProperty("metadata_filter")
+	private JsonNode metadataFilter;
+
+	/** structured filter of knowledge base query */
+	@JsonProperty("structured_filter")
+	private JsonNode structuredFilter;
+
+	/** file ID is a temporary file associated with the current session */
+	@JsonProperty("session_file_ids")
+	private List<String> sessionFileIds;
+
+	public List<String> getPipelineIds() {
+		return pipelineIds;
+	}
+
+	public void setPipelineIds(List<String> pipelineIds) {
+		this.pipelineIds = pipelineIds;
+	}
+
+	public List<String> getFileIds() {
+		return fileIds;
+	}
+
+	public void setFileIds(List<String> fileIds) {
+		this.fileIds = fileIds;
+	}
+
+	public List<String> getTags() {
+		return tags;
+	}
+
+	public void setTags(List<String> tags) {
+		this.tags = tags;
+	}
+
+	public JsonNode getMetadataFilter() {
+		return metadataFilter;
+	}
+
+	public void setMetadataFilter(JsonNode metadataFilter) {
+		this.metadataFilter = metadataFilter;
+	}
+
+	public JsonNode getStructuredFilter() {
+		return structuredFilter;
+	}
+
+	public void setStructuredFilter(JsonNode structuredFilter) {
+		this.structuredFilter = structuredFilter;
+	}
+
+	public List<String> getSessionFileIds() {
+		return sessionFileIds;
+	}
+
+	public void setSessionFileIds(List<String> sessionFileIds) {
+		this.sessionFileIds = sessionFileIds;
+	}
+
+	@Override
+	public String toString() {
+		return "DashScopeAgentRagOptions{" + "pipelineIds=" + pipelineIds + ", fileIds=" + fileIds + ", tags=" + tags
+				+ ", metadataFilter=" + metadataFilter + ", structuredFilter=" + structuredFilter
+				+ ", sessionFileIds=" + sessionFileIds + '}';
+	}
+
+	public static DashScopeAgentRagOptions.Builder builder() {
+
+		return new DashScopeAgentRagOptions.Builder();
+	}
+
+	public static class Builder {
+
+		protected DashScopeAgentRagOptions options;
+
+		public Builder() {
+			this.options = new DashScopeAgentRagOptions();
+		}
+
+		public Builder(DashScopeAgentRagOptions options) {
+			this.options = options;
+		}
+
+		public Builder withPipelineIds(List<String> pipelineIds) {
+			this.options.pipelineIds = pipelineIds;
+			return this;
+		}
+
+		public Builder withFileIds(List<String> fileIds) {
+			this.options.fileIds = fileIds;
+			return this;
+		}
+
+		public Builder withTags(List<String> tags) {
+			this.options.tags = tags;
+			return this;
+		}
+
+		public Builder withMetadataFilter(JsonNode metadataFilter) {
+			this.options.metadataFilter = metadataFilter;
+			return this;
+		}
+
+		public Builder withStructuredFilter(JsonNode structuredFilter) {
+			this.options.structuredFilter = structuredFilter;
+			return this;
+		}
+
+		public Builder withSessionFileIds(List<String> sessionFileIds) {
+			this.options.sessionFileIds = sessionFileIds;
+			return this;
+		}
+
+		public DashScopeAgentRagOptions build() {
+			return this.options;
+		}
+
+	}
+
+}
diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/api/DashScopeAgentApi.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/api/DashScopeAgentApi.java
index a7f24ce4..d9305e21 100644
--- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/api/DashScopeAgentApi.java
+++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/api/DashScopeAgentApi.java
@@ -80,12 +80,12 @@ public DashScopeAgentApi(String baseUrl, String apiKey, String workSpaceId, Rest
 	}
 
 	public ResponseEntity<DashScopeAgentResponse> call(DashScopeAgentRequest request) {
-		String uri = "/api/v1/apps/" + request.app_id() + "/completion";
+		String uri = "/api/v1/apps/" + request.appId() + "/completion";
 		return restClient.post().uri(uri).body(request).retrieve().toEntity(DashScopeAgentResponse.class);
 	}
 
 	public Flux<DashScopeAgentResponse> stream(DashScopeAgentRequest request) {
-		String uri = "/api/v1/apps/" + request.app_id() + "/completion";
+		String uri = "/api/v1/apps/" + request.appId() + "/completion";
 		return webClient.post()
 			.uri(uri)
 			.body(Mono.just(request), DashScopeAgentResponse.class)
@@ -96,47 +96,82 @@ public Flux<DashScopeAgentResponse> stream(DashScopeAgentRequest request) {
 			});
 	}
 
+	// @formatter:off
 	@JsonInclude(JsonInclude.Include.NON_NULL)
-	public record DashScopeAgentRequest(String app_id, @JsonProperty("input") DashScopeAgentRequestInput input,
+	public record DashScopeAgentRequest(
+			@JsonProperty("app_id") String appId,
+			@JsonProperty("input") DashScopeAgentRequestInput input,
 			@JsonProperty("parameters") DashScopeAgentRequestParameters parameters) {
 		@JsonInclude(JsonInclude.Include.NON_NULL)
-		public record DashScopeAgentRequestInput(@JsonProperty("prompt") String prompt,
-				@JsonProperty("session_id") String sessionId, @JsonProperty("memory_id") String memoryId,
+		public record DashScopeAgentRequestInput(
+				@JsonProperty("prompt") String prompt,
+				@JsonProperty("messages") List<DashScopeAgentRequestMessage> messages,
+				@JsonProperty("session_id") String sessionId,
+				@JsonProperty("memory_id") String memoryId,
+				@JsonProperty("image_list") List<String> images,
 				@JsonProperty("biz_params") JsonNode bizParams) {
+			@JsonInclude(JsonInclude.Include.NON_NULL)
+			public record DashScopeAgentRequestMessage(
+					@JsonProperty("role") String role,
+					@JsonProperty("content") String content) {
+			}
 		}
 
 		@JsonInclude(JsonInclude.Include.NON_NULL)
-		public record DashScopeAgentRequestParameters(@JsonProperty("has_thoughts") Boolean hasThoughts,
-				@JsonProperty("incremental_output") Boolean incrementalOutput
-
+		public record DashScopeAgentRequestParameters(
+				@JsonProperty("has_thoughts") Boolean hasThoughts,
+				@JsonProperty("incremental_output") Boolean incrementalOutput,
+				@JsonProperty("rag_options") DashScopeAgentRequestRagOptions ragOptions
 		) {
+			@JsonInclude(JsonInclude.Include.NON_NULL)
+			public record DashScopeAgentRequestRagOptions(
+					@JsonProperty("pipeline_ids") List<String> pipelineIds,
+					@JsonProperty("file_ids") List<String> fileIds,
+					@JsonProperty("metadata_filter") JsonNode metadataFilter,
+					@JsonProperty("tags") List<String> tags,
+					@JsonProperty("structured_filter") JsonNode structuredFilter,
+					@JsonProperty("session_file_ids") List<String> sessionFileIds) {
+			}
 		}
 	}
 
 	@JsonInclude(JsonInclude.Include.NON_NULL)
-	public record DashScopeAgentResponse(@JsonProperty("status_code") Integer statusCode,
-			@JsonProperty("request_id") String requestId, @JsonProperty("code") String code,
-			@JsonProperty("message") String message, @JsonProperty("output") DashScopeAgentResponseOutput output,
+	public record DashScopeAgentResponse(
+			@JsonProperty("status_code") Integer statusCode,
+			@JsonProperty("request_id") String requestId,
+			@JsonProperty("code") String code,
+			@JsonProperty("message") String message,
+			@JsonProperty("output") DashScopeAgentResponseOutput output,
 			@JsonProperty("usage") DashScopeAgentResponseUsage usage) {
 		@JsonInclude(JsonInclude.Include.NON_NULL)
-		public record DashScopeAgentResponseOutput(@JsonProperty("text") String text,
-				@JsonProperty("finish_reason") String finishReason, @JsonProperty("session_id") String sessionId,
+		public record DashScopeAgentResponseOutput(
+				@JsonProperty("text") String text,
+				@JsonProperty("finish_reason") String finishReason,
+				@JsonProperty("session_id") String sessionId,
 				@JsonProperty("thoughts") List<DashScopeAgentResponseOutputThoughts> thoughts,
 				@JsonProperty("doc_references") List<DashScopeAgentResponseOutputDocReference> docReferences) {
 			@JsonInclude(JsonInclude.Include.NON_NULL)
-			public record DashScopeAgentResponseOutputThoughts(@JsonProperty("thought") String thought,
-					@JsonProperty("action_type") String actionType, @JsonProperty("action_name") String actionName,
+			public record DashScopeAgentResponseOutputThoughts(
+					@JsonProperty("thought") String thought,
+					@JsonProperty("action_type") String actionType,
+					@JsonProperty("action_name") String actionName,
 					@JsonProperty("action") String action,
 					@JsonProperty("action_input_stream") String actionInputStream,
-					@JsonProperty("action_input") String actionInput, @JsonProperty("response") String response,
-					@JsonProperty("observation") String observation) {
+					@JsonProperty("action_input") String actionInput,
+					@JsonProperty("response") String response,
+					@JsonProperty("observation") String observation,
+					@JsonProperty("reasoning_content") String reasoningContent) {
 			}
 
 			@JsonInclude(JsonInclude.Include.NON_NULL)
-			public record DashScopeAgentResponseOutputDocReference(@JsonProperty("index_id") String indexId,
-					@JsonProperty("title") String title, @JsonProperty("doc_id") String docId,
-					@JsonProperty("doc_name") String docName, @JsonProperty("text") String text,
-					@JsonProperty("images") List<String> images) {
+			public record DashScopeAgentResponseOutputDocReference(
+					@JsonProperty("index_id") String indexId,
+					@JsonProperty("title") String title,
+					@JsonProperty("doc_id") String docId,
+					@JsonProperty("doc_name") String docName,
+					@JsonProperty("text") String text,
+					@JsonProperty("images") List<String> images,
+					@JsonProperty("page_number") List<Integer> pageNumber) {
 			}
 		}
 
@@ -144,7 +179,8 @@ public record DashScopeAgentResponseOutputDocReference(@JsonProperty("index_id")
 		public record DashScopeAgentResponseUsage(
 				@JsonProperty("models") List<DashScopeAgentResponseUsageModels> models) {
 			@JsonInclude(JsonInclude.Include.NON_NULL)
-			public record DashScopeAgentResponseUsageModels(@JsonProperty("model_id") String modelId,
+			public record DashScopeAgentResponseUsageModels(
+					@JsonProperty("model_id") String modelId,
 					@JsonProperty("input_tokens") Integer inputTokens,
 					@JsonProperty("output_tokens") Integer outputTokens) {
 			}
diff --git a/spring-ai-alibaba-core/src/test/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentTests.java b/spring-ai-alibaba-core/src/test/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentTests.java
index dc41adf4..3380455b 100644
--- a/spring-ai-alibaba-core/src/test/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentTests.java
+++ b/spring-ai-alibaba-core/src/test/java/com/alibaba/cloud/ai/dashscope/agent/DashScopeAgentTests.java
@@ -15,6 +15,20 @@
  */
 package com.alibaba.cloud.ai.dashscope.agent;
 
+import com.alibaba.cloud.ai.dashscope.common.DashScopeException;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import com.alibaba.cloud.ai.dashscope.api.DashScopeAgentApi;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import reactor.core.publisher.Flux;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
 /**
  * Title Dashscope Agent test cases.<br>
  * Description Dashscope Agent test cases.<br>
@@ -26,4 +40,96 @@
 
 class DashScopeAgentTests {
 
+	private static final Logger logger = LoggerFactory.getLogger(DashScopeAgentTests.class);
+
+	private static final String TEST_API_KEY = System.getenv("DASHSCOPE_API_KEY");
+
+	private static final String TEST_APP_ID = System.getenv("APP_ID");
+
+	private static final String TEST_FILE_ID = System.getenv("FILE_ID");
+
+	private final DashScopeAgentApi dashscopeAgentApi = new DashScopeAgentApi(TEST_API_KEY);
+
+	@Test
+	void callWithRagOptionsFileIds() {
+		DashScopeAgent dashScopeAgent = new DashScopeAgent(dashscopeAgentApi);
+
+		Flux<ChatResponse> response = dashScopeAgent.stream(new Prompt("梁随板失败怎么办?",
+				DashScopeAgentOptions.builder()
+					.withAppId(TEST_APP_ID)
+					.withIncrementalOutput(true)
+					.withRagOptions(DashScopeAgentRagOptions.builder().withFileIds(List.of(TEST_FILE_ID)).build())
+					.build()));
+
+		printResponse(response);
+	}
+
+	@Test
+	void callWithImageList() {
+		DashScopeAgent dashScopeAgent = new DashScopeAgent(dashscopeAgentApi);
+
+		Flux<ChatResponse> response = dashScopeAgent.stream(new Prompt("图中描绘的是什么景象?", DashScopeAgentOptions.builder()
+			.withAppId(TEST_APP_ID)
+			.withIncrementalOutput(true)
+			.withImages(List
+				.of("https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"))
+			.build()));
+
+		printResponse(response);
+	}
+
+	@Test
+	void callWithSystemMessage() {
+		DashScopeAgent dashScopeAgent = new DashScopeAgent(dashscopeAgentApi);
+
+		Flux<ChatResponse> response = dashScopeAgent
+			.stream(new Prompt(List.of(new SystemMessage("你是一个新闻记者,请记住你的角色。"), new UserMessage("你是谁?")),
+					DashScopeAgentOptions.builder().withAppId(TEST_APP_ID).withIncrementalOutput(true).build()));
+
+		printResponse(response);
+	}
+
+	@Test
+	void callWithDeepSeek() {
+		DashScopeAgent dashScopeAgent = new DashScopeAgent(dashscopeAgentApi,
+				DashScopeAgentOptions.builder()
+					.withAppId(TEST_APP_ID)
+					.withIncrementalOutput(true)
+					.withHasThoughts(true)
+					.build());
+
+		Flux<ChatResponse> response = dashScopeAgent.stream(new Prompt("x的平方等于4,x等于多少?"));
+
+		printResponse(response);
+	}
+
+	/**
+	 * Subscribes to the stream, prints each chunk and blocks until the stream
+	 * completes or fails.
+	 * @param response stream returned by the agent
+	 */
+	static void printResponse(Flux<ChatResponse> response) {
+		CountDownLatch cdl = new CountDownLatch(1);
+		response.subscribe(data -> {
+			System.out.printf("%s%n", data.getResult().getOutput());
+		}, err -> {
+			logger.error("err: {}", err.getMessage(), err);
+			// Release the latch on failure as well; otherwise await() below would block forever.
+			cdl.countDown();
+		}, () -> {
+			System.out.println("\n");
+			logger.info("done");
+			cdl.countDown();
+		});
+
+		try {
+			cdl.await();
+		}
+		catch (InterruptedException e) {
+			// Restore the interrupt flag before translating the exception.
+			Thread.currentThread().interrupt();
+			throw new DashScopeException(e.getMessage());
+		}
+	}
+
 }
\ No newline at end of file