Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Some Test Cases for DashScope Components #424

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ public DashScopeChatModel(DashScopeApi dashscopeApi, DashScopeChatOptions option

@Override
public ChatResponse call(Prompt prompt) {
Assert.notNull(prompt, "Prompt must not be null");
Assert.isTrue(!CollectionUtils.isEmpty(prompt.getInstructions()), "Prompt messages must not be empty");

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
Expand Down Expand Up @@ -219,6 +221,8 @@ public ChatOptions getDefaultOptions() {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
Assert.notNull(prompt, "Prompt must not be null");
Assert.isTrue(!CollectionUtils.isEmpty(prompt.getInstructions()), "Prompt messages must not be empty");

return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ public String getName() {
}

// Request
@Override
protected KeyValues requestStopSequences(KeyValues keyValues, ChatModelObservationContext context) {
if (context.getRequest().getOptions() instanceof DashScopeChatOptions) {
List<Object> stop = ((DashScopeChatOptions) context.getRequest().getOptions()).getStop();
if (CollectionUtils.isEmpty(stop)) {
return keyValues;
}
KeyValue.of(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES, stop,
Objects::nonNull);

String stopSequences;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
* @author nuocheng.lxm
Expand Down Expand Up @@ -83,6 +84,8 @@ public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOp

@Override
public ImageResponse call(ImagePrompt request) {
Assert.notNull(request, "Prompt must not be null");
Assert.isTrue(!CollectionUtils.isEmpty(request.getInstructions()), "Prompt messages must not be empty");

String taskId = submitImageGenTask(request);
if (taskId == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class DashScopeAudioTranscriptionResponseMetadata extends AudioTranscript

};

protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %2$s }";

@Nullable
private RateLimit rateLimit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,113 @@
*/
package com.alibaba.cloud.ai.dashscope.api;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static com.alibaba.cloud.ai.dashscope.common.DashScopeApiConstants.*;
import static org.assertj.core.api.Assertions.assertThat;

/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* Test cases for ApiUtils
*
* @author brianxiadong
* @since 2025-02-24
*/
class ApiUtilsTests {

private static final String TEST_API_KEY = "test-api-key";

private static final String TEST_WORKSPACE_ID = "test-workspace";

@Test
void testGetJsonContentHeadersWithApiKeyOnly() {
// Test getting JSON content headers with API key only
HttpHeaders headers = new HttpHeaders();
ApiUtils.getJsonContentHeaders(TEST_API_KEY).accept(headers);

assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + TEST_API_KEY);
assertThat(headers.getFirst(HEADER_OPENAPI_SOURCE)).isEqualTo(SOURCE_FLAG);
assertThat(headers.getContentType()).isEqualTo(MediaType.APPLICATION_JSON);
assertThat(headers.getFirst("user-agent")).contains(SDK_FLAG);
}

@Test
void testGetJsonContentHeadersWithWorkspaceId() {
// Test getting JSON content headers with workspace ID
HttpHeaders headers = new HttpHeaders();
ApiUtils.getJsonContentHeaders(TEST_API_KEY, TEST_WORKSPACE_ID).accept(headers);

assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + TEST_API_KEY);
assertThat(headers.getFirst(HEADER_OPENAPI_SOURCE)).isEqualTo(SOURCE_FLAG);
assertThat(headers.getFirst(HEADER_WORK_SPACE_ID)).isEqualTo(TEST_WORKSPACE_ID);
assertThat(headers.getContentType()).isEqualTo(MediaType.APPLICATION_JSON);
assertThat(headers.getFirst("user-agent")).contains(SDK_FLAG);
}

@Test
void testGetJsonContentHeadersWithStream() {
// Test getting JSON content headers with stream enabled
HttpHeaders headers = new HttpHeaders();
ApiUtils.getJsonContentHeaders(TEST_API_KEY, TEST_WORKSPACE_ID, true).accept(headers);

assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + TEST_API_KEY);
assertThat(headers.getFirst(HEADER_OPENAPI_SOURCE)).isEqualTo(SOURCE_FLAG);
assertThat(headers.getFirst(HEADER_WORK_SPACE_ID)).isEqualTo(TEST_WORKSPACE_ID);
assertThat(headers.getContentType()).isEqualTo(MediaType.APPLICATION_JSON);
assertThat(headers.getFirst("X-DashScope-SSE")).isEqualTo("enable");
assertThat(headers.getFirst("user-agent")).contains(SDK_FLAG);
}

@Test
void testGetMapContentHeaders() {
// Test getting map content headers
Map<String, String> customHeaders = new HashMap<>();
customHeaders.put("Custom-Header", "custom-value");

Map<String, String> headers = ApiUtils.getMapContentHeaders(TEST_API_KEY, true, TEST_WORKSPACE_ID,
customHeaders);

assertThat(headers.get("Authorization")).isEqualTo("bearer " + TEST_API_KEY);
assertThat(headers.get("X-DashScope-WorkSpace")).isEqualTo(TEST_WORKSPACE_ID);
assertThat(headers.get("X-DashScope-DataInspection")).isEqualTo("enable");
assertThat(headers.get("Custom-Header")).isEqualTo("custom-value");
assertThat(headers.get("user-agent")).contains(SDK_FLAG);
}

@Test
void testGetAudioTranscriptionHeaders() {
// Test getting audio transcription headers
HttpHeaders headers = new HttpHeaders();
ApiUtils.getAudioTranscriptionHeaders(TEST_API_KEY, TEST_WORKSPACE_ID, true, true, true).accept(headers);

assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + TEST_API_KEY);
assertThat(headers.getFirst("X-DashScope-WorkSpace")).isEqualTo(TEST_WORKSPACE_ID);
assertThat(headers.getFirst("X-DashScope-DataInspection")).isEqualTo("enable");
assertThat(headers.getFirst("X-DashScope-Async")).isEqualTo("enable");
assertThat(headers.getFirst("X-DashScope-SSE")).isEqualTo("enable");
assertThat(headers.getFirst("Cache-Control")).isEqualTo("no-cache");
assertThat(headers.getFirst("X-Accel-Buffering")).isEqualTo("no");
assertThat(headers.getFirst(HttpHeaders.ACCEPT)).isEqualTo("text/event-stream");
assertThat(headers.getContentType()).isEqualTo(MediaType.APPLICATION_JSON);
}

@Test
void testGetFileUploadHeaders() {
// Test getting file upload headers
Map<String, String> input = new HashMap<>();
input.put("Content-Type", "multipart/form-data");
input.put("Custom-Header", "custom-value");

HttpHeaders headers = new HttpHeaders();
ApiUtils.getFileUploadHeaders(input).accept(headers);

assertThat(Objects.requireNonNull(headers.getContentType()).toString()).isEqualTo("multipart/form-data");
assertThat(headers.getFirst("Custom-Header")).isEqualTo("custom-value");
}

}
Loading
Loading