Skip to content

Commit

Permalink
Fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
TetyanaYahodska committed Nov 18, 2024
1 parent 456fd9d commit 68aa1e0
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 188 deletions.
6 changes: 3 additions & 3 deletions tpu/src/main/java/tpu/ListQueuedResources.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ public static void main(String[] args) throws IOException {
// Project ID or project number of the Google Cloud project.
String projectId = "YOUR_PROJECT_ID";
// The zone in which the TPU was created.
String zone = "europe-west4-a";
String zone = "us-central1-a";

listQueuedResources(projectId, zone);
}

// List Queued Resources.
public static TpuClient.ListQueuedResourcesPagedResponse listQueuedResources(
public static TpuClient.ListQueuedResourcesPage listQueuedResources(
String projectId, String zone) throws IOException {
String parent = String.format("projects/%s/locations/%s", projectId, zone);
// Initialize client that will be used to send requests. This client only needs to be created
Expand All @@ -42,7 +42,7 @@ public static TpuClient.ListQueuedResourcesPagedResponse listQueuedResources(
ListQueuedResourcesRequest request =
ListQueuedResourcesRequest.newBuilder().setParent(parent).build();

return tpuClient.listQueuedResources(request);
return tpuClient.listQueuedResources(request).getPage();
}
}
}
Expand Down
70 changes: 0 additions & 70 deletions tpu/src/test/java/tpu/CreateTpuIT.java

This file was deleted.

53 changes: 41 additions & 12 deletions tpu/src/test/java/tpu/QueuedResourceIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,24 @@
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest;
import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
import com.google.cloud.tpu.v2alpha1.ListQueuedResourcesRequest;
import com.google.cloud.tpu.v2alpha1.QueuedResource;
import com.google.cloud.tpu.v2alpha1.TpuClient;
import com.google.cloud.tpu.v2alpha1.TpuSettings;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import org.junit.Before;
import org.junit.Test;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.MockedStatic;

@RunWith(JUnit4.class)
@Timeout(value = 3)
@Timeout(value = 30)
public class QueuedResourceIT {
private static final String PROJECT_ID = "project-id";
private static final String ZONE = "europe-west4-a";
Expand All @@ -52,10 +55,10 @@ public class QueuedResourceIT {
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
private static final String NETWORK_NAME = "default";
private ByteArrayOutputStream bout;
private static ByteArrayOutputStream bout;

@Before
public void setUp() {
@BeforeAll
public static void setUp() {
bout = new ByteArrayOutputStream();
System.setOut(new PrintStream(bout));
}
Expand All @@ -75,8 +78,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {

QueuedResource returnedQueuedResource =
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);

verify(mockTpuClient, times(1))
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
Expand All @@ -89,7 +92,6 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
public void testGetQueuedResource() throws IOException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
TpuClient mockClient = mock(TpuClient.class);
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
QueuedResource mockQueuedResource = mock(QueuedResource.class);

mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
Expand All @@ -99,14 +101,41 @@ public void testGetQueuedResource() throws IOException {
QueuedResource returnedQueuedResource =
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);

verify(mockGetQueuedResource, times(1))
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
verify(mockClient, times(1))
.getQueuedResource(any(GetQueuedResourceRequest.class));
assertEquals(returnedQueuedResource, mockQueuedResource);
}
}

@Test
public void testDeleteTpuVm() {
public void testListTpuVm() throws IOException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
QueuedResource queuedResource1 = mock(QueuedResource.class);
QueuedResource queuedResource2 = mock(QueuedResource.class);
List<QueuedResource> mockListQueuedResources =
Arrays.asList(queuedResource1, queuedResource2);

TpuClient mockClient = mock(TpuClient.class);
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
TpuClient.ListQueuedResourcesPagedResponse mockListQueuedResourcesResponse =
mock(TpuClient.ListQueuedResourcesPagedResponse.class);
when(mockClient.listQueuedResources(any(ListQueuedResourcesRequest.class)))
.thenReturn(mockListQueuedResourcesResponse);
TpuClient.ListQueuedResourcesPage mockQueuedResourcesPage =
mock(TpuClient.ListQueuedResourcesPage.class);
when(mockListQueuedResourcesResponse.getPage()).thenReturn(mockQueuedResourcesPage);
when(mockQueuedResourcesPage.getValues()).thenReturn(mockListQueuedResources);

TpuClient.ListQueuedResourcesPage returnedList =
ListQueuedResources.listQueuedResources(PROJECT_ID, ZONE);

assertThat(returnedList.getValues()).isEqualTo(mockListQueuedResources);
verify(mockClient, times(1)).listQueuedResources(any(ListQueuedResourcesRequest.class));
}
}

@Test
public void testDeleteForceQueuedResource() {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);
Expand Down
99 changes: 0 additions & 99 deletions tpu/src/test/java/tpu/QueuedResourcesIT.java

This file was deleted.

35 changes: 31 additions & 4 deletions tpu/src/test/java/tpu/TpuVmIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package tpu;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
Expand All @@ -25,6 +26,7 @@
import static org.mockito.Mockito.when;

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.tpu.v2.CreateNodeRequest;
import com.google.cloud.tpu.v2.DeleteNodeRequest;
import com.google.cloud.tpu.v2.GetNodeRequest;
import com.google.cloud.tpu.v2.Node;
Expand All @@ -42,11 +44,13 @@
import org.mockito.MockedStatic;

@RunWith(JUnit4.class)
@Timeout(value = 3)
@Timeout(value = 30)
public class TpuVmIT {
private static final String PROJECT_ID = "project-id";
private static final String ZONE = "asia-east1-c";
private static final String NODE_NAME = "test-tpu";
private static final String TPU_TYPE = "v2-8";
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
private static ByteArrayOutputStream bout;

@BeforeAll
Expand All @@ -55,20 +59,43 @@ public static void setUp() {
System.setOut(new PrintStream(bout));
}

@Test
public void testCreateTpuVm() throws Exception {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
Node mockNode = mock(Node.class);
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);

mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
.thenReturn(mockTpuClient);
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
.thenReturn(mockFuture);
when(mockFuture.get()).thenReturn(mockNode);

Node returnedNode = CreateTpuVm.createTpuVm(
PROJECT_ID, ZONE, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION);

verify(mockTpuClient, times(1))
.createNodeAsync(any(CreateNodeRequest.class));
verify(mockFuture, times(1)).get();
assertEquals(returnedNode, mockNode);
}
}

@Test
public void testGetTpuVm() throws IOException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
Node mockNode = mock(Node.class);
TpuClient mockClient = mock(TpuClient.class);
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);

mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);

Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);

verify(mockGetTpuVm, times(1))
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
verify(mockClient, times(1))
.getNode(any(GetNodeRequest.class));
assertThat(returnedNode).isEqualTo(mockNode);
}
}
Expand Down

0 comments on commit 68aa1e0

Please sign in to comment.