Skip to content

Commit

Permalink
Merge pull request #427 from brianxiadong/feat-brianxiadong-testcase
Browse files Browse the repository at this point in the history
Add Some Test Cases for DashScope Components
  • Loading branch information
yuluo-yx authored Feb 26, 2025
2 parents 8fda877 + 21e3662 commit 8deed8e
Show file tree
Hide file tree
Showing 21 changed files with 1,766 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void delete(Filter.Expression filterExpression) {
@Override
public List<Document> similaritySearch(String query) {

return similaritySearch(SearchRequest.builder().query(query).toString());
return similaritySearch(SearchRequest.builder().query(query).build());

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void testBasicChatCompletion() {
assertThat(response).isNotNull();
assertThat(response.getResult()).isNotNull();
assertThat(response.getResult().getOutput()).isInstanceOf(AssistantMessage.class);
assertThat(response.getResult().getOutput().getContent()).isEqualTo(TEST_RESPONSE);
assertThat(response.getResult().getOutput().getText()).isEqualTo(TEST_RESPONSE);
assertThat(response.getMetadata().getId()).isEqualTo(TEST_REQUEST_ID);
}

Expand Down Expand Up @@ -151,11 +151,11 @@ void testStreamChatCompletion() {

// Verify results
StepVerifier.create(responseFlux).assertNext(response -> {
assertThat(response.getResult().getOutput().getContent()).isEqualTo("I'm ");
assertThat(response.getResult().getOutput().getText()).isEqualTo("I'm ");
}).assertNext(response -> {
assertThat(response.getResult().getOutput().getContent()).isEqualTo("doing ");
assertThat(response.getResult().getOutput().getText()).isEqualTo("doing ");
}).assertNext(response -> {
assertThat(response.getResult().getOutput().getContent()).isEqualTo("well!");
assertThat(response.getResult().getOutput().getText()).isEqualTo("well!");
assertThat(response.getMetadata().getUsage()).isNotNull();
}).verifyComplete();
}
Expand Down Expand Up @@ -185,7 +185,7 @@ void testSystemMessage() {
ChatResponse chatResponse = chatModel.call(prompt);

assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getResults().get(0).getOutput().getContent()).isEqualTo(response);
assertThat(chatResponse.getResults().get(0).getOutput().getText()).isEqualTo(response);
}

@Test
Expand Down Expand Up @@ -223,7 +223,7 @@ void testToolCalls() {
ChatResponse response = toolChatModel.call(prompt);

assertThat(response).isNotNull();
assertThat(response.getResults().get(0).getOutput().getContent()).contains("get_weather");
assertThat(response.getResults().get(0).getOutput().getText()).contains("get_weather");
}

@Test
Expand Down Expand Up @@ -270,9 +270,9 @@ void testStreamToolCalls() {

assertThat(responses).isNotNull();
assertThat(responses).hasSize(3);
assertThat(responses.get(0).getResults().get(0).getOutput().getContent()).isEqualTo(chunk1);
assertThat(responses.get(1).getResults().get(0).getOutput().getContent()).isEqualTo(chunk2);
assertThat(responses.get(2).getResults().get(0).getOutput().getContent()).isEqualTo(chunk3);
assertThat(responses.get(0).getResults().get(0).getOutput().getText()).isEqualTo(chunk1);
assertThat(responses.get(1).getResults().get(0).getOutput().getText()).isEqualTo(chunk2);
assertThat(responses.get(2).getResults().get(0).getOutput().getText()).isEqualTo(chunk3);
}

@Test
Expand Down Expand Up @@ -395,7 +395,7 @@ void testMultipleMessagesInPrompt() {
ChatResponse response = chatModel.call(prompt);

assertThat(response).isNotNull();
assertThat(response.getResult().getOutput().getContent()).isEqualTo("It's sunny today!");
assertThat(response.getResult().getOutput().getText()).isEqualTo("It's sunny today!");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.lang.reflect.Field;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.assertThat;

/**
* Test cases for DashScopeWebSocketClient. Tests cover WebSocket connection, message
Expand All @@ -43,4 +47,166 @@
*/
class DashScopeWebSocketClientTests {

private static final String TEST_API_KEY = "test-api-key";

private static final String TEST_WORKSPACE_ID = "test-workspace";

private static final String TEST_MESSAGE = "Hello, WebSocket!";

private DashScopeWebSocketClient client;

private WebSocket mockWebSocket;

private Response mockResponse;

@BeforeEach
void setUp() {
// Initialize mocks
mockWebSocket = mock(WebSocket.class);
mockResponse = mock(Response.class);

// Set up basic mock behavior
when(mockWebSocket.send(any(String.class))).thenReturn(true);
when(mockWebSocket.send(any(ByteString.class))).thenReturn(true);

// Configure client options
DashScopeWebSocketClientOptions options = DashScopeWebSocketClientOptions.builder()
.withApiKey(TEST_API_KEY)
.withWorkSpaceId(TEST_WORKSPACE_ID)
.build();

// Initialize client
client = new DashScopeWebSocketClient(options);

// Set webSocketClient using reflection
try {
Field webSocketClientField = DashScopeWebSocketClient.class.getDeclaredField("webSocketClient");
webSocketClientField.setAccessible(true);
webSocketClientField.set(client, mockWebSocket);

// Set isOpen to true
Field isOpenField = DashScopeWebSocketClient.class.getDeclaredField("isOpen");
isOpenField.setAccessible(true);
isOpenField.set(client, new AtomicBoolean(true));
}
catch (Exception e) {
throw new RuntimeException("Failed to set fields via reflection", e);
}
}

@Test
void testWebSocketEvents() {
// Test sending text message
client.sendText(TEST_MESSAGE);
verify(mockWebSocket).send(TEST_MESSAGE);

// Test receiving task started event
client.onMessage(mockWebSocket, createTaskStartedMessage());

// Test receiving result generated event
client.onMessage(mockWebSocket, createResultGeneratedMessage());

// Test receiving task finished event
client.onMessage(mockWebSocket, createTaskFinishedMessage());
}

@Test
void testStreamBinaryOut() {
// Test binary streaming
String testText = "Test binary streaming";
Flux<ByteBuffer> result = client.streamBinaryOut(testText);

StepVerifier.create(result).expectSubscription().then(() -> {
// Simulate binary message
ByteString testBinary = ByteString.of(ByteBuffer.wrap("test data".getBytes()));
client.onMessage(mockWebSocket, testBinary);
})
.expectNextMatches(buffer -> buffer.hasRemaining())
.then(() -> client.onMessage(mockWebSocket, createTaskFinishedMessage()))
.verifyComplete();
}

@Test
void testStreamTextOut() {
// Test text streaming
ByteBuffer testBuffer = ByteBuffer.wrap("Test text streaming".getBytes());
Flux<String> result = client.streamTextOut(Flux.just(testBuffer));

StepVerifier.create(result)
.expectSubscription()
.then(() -> client.onMessage(mockWebSocket, createResultGeneratedMessage()))
.expectNextMatches(text -> text.contains("result"))
.then(() -> client.onMessage(mockWebSocket, createTaskFinishedMessage()))
.verifyComplete();
}

@Test
void testErrorHandling() {
// Test error handling
Exception testException = new Exception("Test error");
client.onFailure(mockWebSocket, testException, mockResponse);

// Verify error is propagated to emitters
StepVerifier.create(client.streamBinaryOut(TEST_MESSAGE)).expectError().verify();
}

@Test
void testTaskFailedEvent() {
// Test task failed event
client.onMessage(mockWebSocket, createTaskFailedMessage());

// Verify error is propagated to emitters
StepVerifier.create(client.streamTextOut(Flux.just(ByteBuffer.wrap("test".getBytes())))).expectError().verify();
}

private String createTaskStartedMessage() {
return """
{
"header": {
"task_id": "test-task-id",
"event": "task-started"
},
"payload": {}
}""";
}

private String createResultGeneratedMessage() {
return """
{
"header": {
"task_id": "test-task-id",
"event": "result-generated"
},
"payload": {
"output": {
"text": "test result"
}
}
}""";
}

private String createTaskFinishedMessage() {
return """
{
"header": {
"task_id": "test-task-id",
"event": "task-finished"
},
"payload": {}
}""";
}

private String createTaskFailedMessage() {
return """
{
"header": {
"task_id": "test-task-id",
"event": "task-failed",
"error_code": "500",
"error_message": "Test error"
},
"payload": {}
}""";
}

}

This file was deleted.

This file was deleted.

Loading

0 comments on commit 8deed8e

Please sign in to comment.