diff --git a/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java b/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java new file mode 100644 index 00000000000..86e7e28a007 --- /dev/null +++ b/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +//[START tpu_vm_create_topology] +import com.google.cloud.tpu.v2.AcceleratorConfig; +import com.google.cloud.tpu.v2.AcceleratorConfig.Type; +import com.google.cloud.tpu.v2.CreateNodeRequest; +import com.google.cloud.tpu.v2.Node; +import com.google.cloud.tpu.v2.TpuClient; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class CreateTpuWithTopologyFlag { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project you want to create a node. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which to create the TPU. + // For more information about supported TPU types for specific zones, + // see https://cloud.google.com/tpu/docs/regions-zones + String zone = "europe-west4-a"; + // The name for your TPU. + String nodeName = "YOUR_TPU_NAME"; + // The version of the Cloud TPU you want to create. + // Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7 + Type tpuVersion = AcceleratorConfig.Type.V2; + // Software version that specifies the version of the TPU runtime to install. + // For more information, see https://cloud.google.com/tpu/docs/runtimes + String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt"; + // The physical topology of your TPU slice. + // For more information about topology for each TPU version, + // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions. + String topology = "2x2"; + + createTpuWithTopologyFlag(projectId, zone, nodeName, tpuVersion, tpuSoftwareVersion, topology); + } + + // Creates a TPU VM with the specified name, zone, version and topology. + public static Node createTpuWithTopologyFlag(String projectId, String zone, String nodeName, + Type tpuVersion, String tpuSoftwareVersion, String topology) + throws IOException, ExecutionException, InterruptedException { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create()) { + String parent = String.format("projects/%s/locations/%s", projectId, zone); + Node tpuVm = + Node.newBuilder() + .setName(nodeName) + .setAcceleratorConfig(Node.newBuilder() + .getAcceleratorConfigBuilder() + .setType(tpuVersion) + .setTopology(topology) + .build()) + .setRuntimeVersion(tpuSoftwareVersion) + .build(); + + CreateNodeRequest request = + CreateNodeRequest.newBuilder() + .setParent(parent) + .setNodeId(nodeName) + .setNode(tpuVm) + .build(); + + return tpuClient.createNodeAsync(request).get(); + } + } +} +//[END tpu_vm_create_topology] \ No newline at end of file diff --git a/tpu/src/main/java/tpu/GetQueuedResource.java b/tpu/src/main/java/tpu/GetQueuedResource.java index 3a510e045fe..a17c2b41f79 100644 --- a/tpu/src/main/java/tpu/GetQueuedResource.java +++ b/tpu/src/main/java/tpu/GetQueuedResource.java @@ -17,7 +17,6 @@ package tpu; //[START tpu_queued_resources_get] - import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest; import com.google.cloud.tpu.v2alpha1.QueuedResource; import com.google.cloud.tpu.v2alpha1.TpuClient; diff --git a/tpu/src/test/java/tpu/CreateTpuIT.java b/tpu/src/test/java/tpu/CreateTpuIT.java deleted file mode 100644 index cdb11831364..00000000000 --- a/tpu/src/test/java/tpu/CreateTpuIT.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package tpu; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.api.gax.longrunning.OperationFuture; -import com.google.cloud.tpu.v2.CreateNodeRequest; -import com.google.cloud.tpu.v2.Node; -import com.google.cloud.tpu.v2.TpuClient; -import com.google.cloud.tpu.v2.TpuSettings; -import org.junit.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.MockedStatic; - -@RunWith(JUnit4.class) -@Timeout(value = 3) -public class CreateTpuIT { - private static final String PROJECT_ID = "project-id"; - private static final String ZONE = "asia-east1-c"; - private static final String NODE_NAME = "test-tpu"; - private static final String TPU_TYPE = "v2-8"; - private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1"; - - @Test - public void testCreateTpuVm() throws Exception { - try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { - Node mockNode = mock(Node.class); - TpuClient mockTpuClient = mock(TpuClient.class); - OperationFuture mockFuture = mock(OperationFuture.class); - - mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) - .thenReturn(mockTpuClient); - when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) - .thenReturn(mockFuture); - when(mockFuture.get()).thenReturn(mockNode); - - Node returnedNode = CreateTpuVm.createTpuVm( - PROJECT_ID, ZONE, NODE_NAME, - TPU_TYPE, TPU_SOFTWARE_VERSION); - - verify(mockTpuClient, times(1)) - .createNodeAsync(any(CreateNodeRequest.class)); - verify(mockFuture, times(1)).get(); - assertEquals(returnedNode, mockNode); - } - } -} diff --git a/tpu/src/test/java/tpu/QueuedResourceIT.java b/tpu/src/test/java/tpu/QueuedResourceIT.java index 906d0e270df..ec7d9512b92 100644 --- a/tpu/src/test/java/tpu/QueuedResourceIT.java +++ b/tpu/src/test/java/tpu/QueuedResourceIT.java @@ -35,15 +35,15 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 3) +@Timeout(value = 10) public class QueuedResourceIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "europe-west4-a"; @@ -52,10 +52,10 @@ public class QueuedResourceIT { private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1"; private static final String QUEUED_RESOURCE_NAME = "queued-resource"; private static final String NETWORK_NAME = "default"; - private ByteArrayOutputStream bout; + private static ByteArrayOutputStream bout; - @Before - public void setUp() { + @BeforeAll + public static void setUp() { bout = new ByteArrayOutputStream(); System.setOut(new PrintStream(bout)); } @@ -75,8 +75,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { QueuedResource returnedQueuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork( - PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, - TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME); + PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME); verify(mockTpuClient, times(1)) .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); @@ -89,7 +89,6 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { public void testGetQueuedResource() throws IOException { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockClient = mock(TpuClient.class); - GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class); QueuedResource mockQueuedResource = mock(QueuedResource.class); mockedTpuClient.when(TpuClient::create).thenReturn(mockClient); @@ -99,14 +98,14 @@ public void testGetQueuedResource() throws IOException { QueuedResource returnedQueuedResource = GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); - verify(mockGetQueuedResource, times(1)) - .getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); + verify(mockClient, times(1)) + .getQueuedResource(any(GetQueuedResourceRequest.class)); assertEquals(returnedQueuedResource, mockQueuedResource); } } @Test - public void testDeleteTpuVm() { + public void testDeleteForceQueuedResource() { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockTpuClient = mock(TpuClient.class); OperationFuture mockFuture = mock(OperationFuture.class); diff --git a/tpu/src/test/java/tpu/TpuVmIT.java b/tpu/src/test/java/tpu/TpuVmIT.java index 08dfeca8eb9..a640953c445 100644 --- a/tpu/src/test/java/tpu/TpuVmIT.java +++ b/tpu/src/test/java/tpu/TpuVmIT.java @@ -17,6 +17,7 @@ package tpu; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; @@ -25,6 +26,8 @@ import static org.mockito.Mockito.when; import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.tpu.v2.AcceleratorConfig; +import com.google.cloud.tpu.v2.CreateNodeRequest; import com.google.cloud.tpu.v2.DeleteNodeRequest; import com.google.cloud.tpu.v2.GetNodeRequest; import com.google.cloud.tpu.v2.Node; @@ -34,7 +37,6 @@ import java.io.IOException; import java.io.PrintStream; import java.util.concurrent.ExecutionException; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; @@ -42,17 +44,38 @@ import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 3) +@Timeout(value = 10) public class TpuVmIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "asia-east1-c"; private static final String NODE_NAME = "test-tpu"; - private static ByteArrayOutputStream bout; + private static final String TPU_TYPE = "v2-8"; + private static final AcceleratorConfig.Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2; + private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1"; + private static final String TOPOLOGY = "2x2"; - @BeforeAll - public static void setUp() { - bout = new ByteArrayOutputStream(); - System.setOut(new PrintStream(bout)); + @Test + public void testCreateTpuVm() throws Exception { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + Node mockNode = mock(Node.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) + .thenReturn(mockTpuClient); + when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(mockNode); + + Node returnedNode = CreateTpuVm.createTpuVm( + PROJECT_ID, ZONE, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION); + + verify(mockTpuClient, times(1)) + .createNodeAsync(any(CreateNodeRequest.class)); + verify(mockFuture, times(1)).get(); + assertEquals(returnedNode, mockNode); + } } @Test @@ -60,21 +83,22 @@ public void testGetTpuVm() throws IOException { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { Node mockNode = mock(Node.class); TpuClient mockClient = mock(TpuClient.class); - GetTpuVm mockGetTpuVm = mock(GetTpuVm.class); mockedTpuClient.when(TpuClient::create).thenReturn(mockClient); when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode); Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME); - verify(mockGetTpuVm, times(1)) - .getTpuVm(PROJECT_ID, ZONE, NODE_NAME); + verify(mockClient, times(1)) + .getNode(any(GetNodeRequest.class)); assertThat(returnedNode).isEqualTo(mockNode); } } @Test public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException { + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + System.setOut(new PrintStream(bout)); try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockTpuClient = mock(TpuClient.class); OperationFuture mockFuture = mock(OperationFuture.class); @@ -89,6 +113,31 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte assertThat(output).contains("TPU VM deleted"); verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class)); + + bout.close(); + } + } + + @Test + public void testCreateTpuVmWithTopologyFlag() + throws IOException, ExecutionException, InterruptedException { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + Node mockNode = mock(Node.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(mockNode); + Node returnedNode = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag( + PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE, + TPU_SOFTWARE_VERSION, TOPOLOGY); + + verify(mockTpuClient, times(1)) + .createNodeAsync(any(CreateNodeRequest.class)); + verify(mockFuture, times(1)).get(); + assertEquals(returnedNode, mockNode); } } } \ No newline at end of file