Skip to content

Commit

Permalink
feat(tpu): add tpu vm create topology sample. (#9611)
Browse files Browse the repository at this point in the history
* Changed package, added information to CODEOWNERS

* Added information to CODEOWNERS

* Added timeout

* Fixed parameters for test

* Fixed DeleteTpuVm and naming

* Added comment, created Util class

* Fixed naming

* Fixed whitespace

* Split PR into smaller, deleted redundant code

* Implemented tpu_vm_create_topology sample, created test

* Changed zone

* Fixed empty lines and tests, deleted cleanup method

* Fixed tests

* Fixed test

* Fixed imports

* Increased timeout to 10 sec

* Fixed tests

* Fixed tests

* Deleted settings

* Made ByteArrayOutputStream bout as local variable

* Changed timeout to 10 sec
  • Loading branch information
TetyanaYahodska authored Nov 20, 2024
1 parent 7f84b30 commit bcecf8f
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 93 deletions.
85 changes: 85 additions & 0 deletions tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package tpu;

//[START tpu_vm_create_topology]
import com.google.cloud.tpu.v2.AcceleratorConfig;
import com.google.cloud.tpu.v2.AcceleratorConfig.Type;
import com.google.cloud.tpu.v2.CreateNodeRequest;
import com.google.cloud.tpu.v2.Node;
import com.google.cloud.tpu.v2.TpuClient;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

public class CreateTpuWithTopologyFlag {

public static void main(String[] args)
throws IOException, ExecutionException, InterruptedException {
// TODO(developer): Replace these variables before running the sample.
// Project ID or project number of the Google Cloud project you want to create a node.
String projectId = "YOUR_PROJECT_ID";
// The zone in which to create the TPU.
// For more information about supported TPU types for specific zones,
// see https://cloud.google.com/tpu/docs/regions-zones
String zone = "europe-west4-a";
// The name for your TPU.
String nodeName = "YOUR_TPU_NAME";
// The version of the Cloud TPU you want to create.
// Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
Type tpuVersion = AcceleratorConfig.Type.V2;
// Software version that specifies the version of the TPU runtime to install.
// For more information, see https://cloud.google.com/tpu/docs/runtimes
String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt";
// The physical topology of your TPU slice.
// For more information about topology for each TPU version,
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
String topology = "2x2";

createTpuWithTopologyFlag(projectId, zone, nodeName, tpuVersion, tpuSoftwareVersion, topology);
}

// Creates a TPU VM with the specified name, zone, version and topology.
public static Node createTpuWithTopologyFlag(String projectId, String zone, String nodeName,
Type tpuVersion, String tpuSoftwareVersion, String topology)
throws IOException, ExecutionException, InterruptedException {
// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests.
try (TpuClient tpuClient = TpuClient.create()) {
String parent = String.format("projects/%s/locations/%s", projectId, zone);
Node tpuVm =
Node.newBuilder()
.setName(nodeName)
.setAcceleratorConfig(Node.newBuilder()
.getAcceleratorConfigBuilder()
.setType(tpuVersion)
.setTopology(topology)
.build())
.setRuntimeVersion(tpuSoftwareVersion)
.build();

CreateNodeRequest request =
CreateNodeRequest.newBuilder()
.setParent(parent)
.setNodeId(nodeName)
.setNode(tpuVm)
.build();

return tpuClient.createNodeAsync(request).get();
}
}
}
//[END tpu_vm_create_topology]
1 change: 0 additions & 1 deletion tpu/src/main/java/tpu/GetQueuedResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package tpu;

//[START tpu_queued_resources_get]

import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest;
import com.google.cloud.tpu.v2alpha1.QueuedResource;
import com.google.cloud.tpu.v2alpha1.TpuClient;
Expand Down
70 changes: 0 additions & 70 deletions tpu/src/test/java/tpu/CreateTpuIT.java

This file was deleted.

23 changes: 11 additions & 12 deletions tpu/src/test/java/tpu/QueuedResourceIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.MockedStatic;

@RunWith(JUnit4.class)
@Timeout(value = 3)
@Timeout(value = 10)
public class QueuedResourceIT {
private static final String PROJECT_ID = "project-id";
private static final String ZONE = "europe-west4-a";
Expand All @@ -52,10 +52,10 @@ public class QueuedResourceIT {
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
private static final String NETWORK_NAME = "default";
private ByteArrayOutputStream bout;
private static ByteArrayOutputStream bout;

@Before
public void setUp() {
@BeforeAll
public static void setUp() {
bout = new ByteArrayOutputStream();
System.setOut(new PrintStream(bout));
}
Expand All @@ -75,8 +75,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {

QueuedResource returnedQueuedResource =
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);

verify(mockTpuClient, times(1))
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
Expand All @@ -89,7 +89,6 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
public void testGetQueuedResource() throws IOException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
TpuClient mockClient = mock(TpuClient.class);
GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class);
QueuedResource mockQueuedResource = mock(QueuedResource.class);

mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
Expand All @@ -99,14 +98,14 @@ public void testGetQueuedResource() throws IOException {
QueuedResource returnedQueuedResource =
GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);

verify(mockGetQueuedResource, times(1))
.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME);
verify(mockClient, times(1))
.getQueuedResource(any(GetQueuedResourceRequest.class));
assertEquals(returnedQueuedResource, mockQueuedResource);
}
}

@Test
public void testDeleteTpuVm() {
public void testDeleteForceQueuedResource() {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);
Expand Down
69 changes: 59 additions & 10 deletions tpu/src/test/java/tpu/TpuVmIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package tpu;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
Expand All @@ -25,6 +26,8 @@
import static org.mockito.Mockito.when;

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.tpu.v2.AcceleratorConfig;
import com.google.cloud.tpu.v2.CreateNodeRequest;
import com.google.cloud.tpu.v2.DeleteNodeRequest;
import com.google.cloud.tpu.v2.GetNodeRequest;
import com.google.cloud.tpu.v2.Node;
Expand All @@ -34,47 +37,68 @@
import java.io.IOException;
import java.io.PrintStream;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.MockedStatic;

@RunWith(JUnit4.class)
@Timeout(value = 3)
@Timeout(value = 10)
public class TpuVmIT {
private static final String PROJECT_ID = "project-id";
private static final String ZONE = "asia-east1-c";
private static final String NODE_NAME = "test-tpu";
private static ByteArrayOutputStream bout;
private static final String TPU_TYPE = "v2-8";
private static final AcceleratorConfig.Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2;
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
private static final String TOPOLOGY = "2x2";

@BeforeAll
public static void setUp() {
bout = new ByteArrayOutputStream();
System.setOut(new PrintStream(bout));
@Test
public void testCreateTpuVm() throws Exception {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
Node mockNode = mock(Node.class);
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);

mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
.thenReturn(mockTpuClient);
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
.thenReturn(mockFuture);
when(mockFuture.get()).thenReturn(mockNode);

Node returnedNode = CreateTpuVm.createTpuVm(
PROJECT_ID, ZONE, NODE_NAME,
TPU_TYPE, TPU_SOFTWARE_VERSION);

verify(mockTpuClient, times(1))
.createNodeAsync(any(CreateNodeRequest.class));
verify(mockFuture, times(1)).get();
assertEquals(returnedNode, mockNode);
}
}

@Test
public void testGetTpuVm() throws IOException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
Node mockNode = mock(Node.class);
TpuClient mockClient = mock(TpuClient.class);
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);

mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);

Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);

verify(mockGetTpuVm, times(1))
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
verify(mockClient, times(1))
.getNode(any(GetNodeRequest.class));
assertThat(returnedNode).isEqualTo(mockNode);
}
}

@Test
public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException {
ByteArrayOutputStream bout = new ByteArrayOutputStream();
System.setOut(new PrintStream(bout));
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);
Expand All @@ -89,6 +113,31 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte

assertThat(output).contains("TPU VM deleted");
verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class));

bout.close();
}
}

@Test
public void testCreateTpuVmWithTopologyFlag()
throws IOException, ExecutionException, InterruptedException {
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
Node mockNode = mock(Node.class);
TpuClient mockTpuClient = mock(TpuClient.class);
OperationFuture mockFuture = mock(OperationFuture.class);

mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
.thenReturn(mockFuture);
when(mockFuture.get()).thenReturn(mockNode);
Node returnedNode = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag(
PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE,
TPU_SOFTWARE_VERSION, TOPOLOGY);

verify(mockTpuClient, times(1))
.createNodeAsync(any(CreateNodeRequest.class));
verify(mockFuture, times(1)).get();
assertEquals(returnedNode, mockNode);
}
}
}

0 comments on commit bcecf8f

Please sign in to comment.