From 15a748a8cb61d04a0dded5d611d5a3f3b1f3eadc Mon Sep 17 00:00:00 2001 From: Jacek Spalinski <69755075+jacspa96@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:21:28 +0100 Subject: [PATCH 1/3] feat(dataplex): add code samples for Search Entries (#9609) * feat(dataplex): add sample for search Entries * feat(dataplex): add integration tests for search Entries * feat(dataplex): make searchEntries return SearchEntriesResponse to support paging * feat(dataplex): adjust integration test based on code review * feat(dataplex): adjust comment regarding search scope * feat(dataplex): remove paging from Search --------- Co-authored-by: Jacek Spalinski --- .../src/main/java/dataplex/SearchEntries.java | 65 +++++++++++++++++ .../test/java/dataplex/SearchEntriesIT.java | 71 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 dataplex/snippets/src/main/java/dataplex/SearchEntries.java create mode 100644 dataplex/snippets/src/test/java/dataplex/SearchEntriesIT.java diff --git a/dataplex/snippets/src/main/java/dataplex/SearchEntries.java b/dataplex/snippets/src/main/java/dataplex/SearchEntries.java new file mode 100644 index 00000000000..25706176380 --- /dev/null +++ b/dataplex/snippets/src/main/java/dataplex/SearchEntries.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataplex; + +// [START dataplex_search_entries] +import com.google.cloud.dataplex.v1.CatalogServiceClient; +import com.google.cloud.dataplex.v1.Entry; +import com.google.cloud.dataplex.v1.SearchEntriesRequest; +import com.google.cloud.dataplex.v1.SearchEntriesResult; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +public class SearchEntries { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String projectId = "MY_PROJECT_ID"; + // How to write query for search: https://cloud.google.com/dataplex/docs/search-syntax + String query = "MY_QUERY"; + + List entries = searchEntries(projectId, query); + entries.forEach(entry -> System.out.println("Entry name found in search: " + entry.getName())); + } + + // Method to search Entries located in projectId and matching query + public static List searchEntries(String projectId, String query) throws IOException { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (CatalogServiceClient client = CatalogServiceClient.create()) { + SearchEntriesRequest searchEntriesRequest = + SearchEntriesRequest.newBuilder() + .setPageSize(100) + // Required field, will by default limit search scope to organization under which the + // project is located + .setName(String.format("projects/%s/locations/global", projectId)) + // Optional field, will further limit search scope only to specified project + .setScope(String.format("projects/%s", projectId)) + .setQuery(query) + .build(); + + CatalogServiceClient.SearchEntriesPagedResponse searchEntriesResponse = + client.searchEntries(searchEntriesRequest); + return searchEntriesResponse.getPage().getResponse().getResultsList().stream() + // Extract Entries nested inside search results + .map(SearchEntriesResult::getDataplexEntry) + .collect(Collectors.toList()); + } + } +} +// [END dataplex_search_entries] diff --git a/dataplex/snippets/src/test/java/dataplex/SearchEntriesIT.java b/dataplex/snippets/src/test/java/dataplex/SearchEntriesIT.java new file mode 100644 index 00000000000..2a1d7636dd5 --- /dev/null +++ b/dataplex/snippets/src/test/java/dataplex/SearchEntriesIT.java @@ -0,0 +1,71 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dataplex; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.dataplex.v1.Entry; +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class SearchEntriesIT { + private static final String ID = UUID.randomUUID().toString().substring(0, 8); + private static final String LOCATION = "us-central1"; + private static final String entryGroupId = "test-entry-group-" + ID; + private static final String entryId = "test-entry-" + ID; + private static final String expectedEntry = + String.format("locations/%s/entryGroups/%s/entries/%s", LOCATION, entryGroupId, entryId); + + private static final String PROJECT_ID = requireProjectIdEnvVar(); + + private static String requireProjectIdEnvVar() { + String value = System.getenv("GOOGLE_CLOUD_PROJECT"); + assertNotNull( + "Environment variable GOOGLE_CLOUD_PROJECT is required to perform these tests.", value); + return value; + } + + @BeforeClass + public static void setUp() throws Exception { + requireProjectIdEnvVar(); + CreateEntryGroup.createEntryGroup(PROJECT_ID, LOCATION, entryGroupId); + CreateEntry.createEntry(PROJECT_ID, LOCATION, entryGroupId, entryId); + Thread.sleep(30000); + } + + @Test + public void testSearchEntries() throws IOException { + String query = "name:test-entry- AND description:description AND aspect:generic"; + List entries = SearchEntries.searchEntries(PROJECT_ID, query); + assertThat( + entries.stream() + .map(Entry::getName) + .map(entryName -> entryName.substring(entryName.indexOf("location")))) + .contains(expectedEntry); + } + + @AfterClass + public static void tearDown() throws Exception { + // Entry inside this Entry Group will be deleted automatically + DeleteEntryGroup.deleteEntryGroup(PROJECT_ID, LOCATION, entryGroupId); + } +} From 7f84b302e5eccc354941318bf134fbfa6955bff4 Mon Sep 17 00:00:00 2001 From: James Ma Date: Tue, 19 Nov 2024 15:39:06 -0800 Subject: [PATCH 2/3] Update README.md to reflect CRf rebrand (#9685) --- functions/README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/functions/README.md b/functions/README.md index f093b1fa13a..5d5d6086a7b 100644 --- a/functions/README.md +++ b/functions/README.md @@ -2,25 +2,25 @@ # Google Cloud Functions Java Samples -[Cloud Functions][functions_docs] is a lightweight, event-based, asynchronous -compute solution that allows you to create small, single-purpose functions that -respond to Cloud events without the need to manage a server or a runtime -environment. +[Cloud Run functions](https://cloud.google.com/functions/docs/concepts/overview) is a lightweight, event-based, asynchronous compute solution that allows you to create small, single-purpose functions that respond to Cloud events without the need to manage a server or a runtime environment. -[functions_docs]: https://cloud.google.com/functions/docs/ +There are two versions of Cloud Run functions: + +* **Cloud Run functions**, formerly known as Cloud Functions (2nd gen), which deploys your function as services on Cloud Run, allowing you to trigger them using Eventarc and Pub/Sub. Cloud Run functions are created using `gcloud functions` or `gcloud run`. Samples for Cloud Run functions can be found in the [`functions/v2`](v2/) folder. +* **Cloud Run functions (1st gen)**, formerly known as Cloud Functions (1st gen), the original version of functions with limited event triggers and configurability. Cloud Run functions (1st gen) are created using `gcloud functions --no-gen2`. Samples for Cloud Run functions (1st generation) can be found in the current `functions/` folder. ## Samples * [Hello World](helloworld/) -* [Concepts](concepts/) +* [Concepts](v2/concepts/) * [Datastore](v2/datastore/) * [Firebase](firebase/) -* [Cloud Pub/Sub](pubsub/) +* [Cloud Pub/Sub](v2/pubsub/) * [HTTP](http/) * [Logging & Monitoring](logging/) * [Slack](slack/) -* [OCR tutorial](ocr/) -* [ImageMagick](imagemagick/) +* [OCR tutorial](v2/ocr/) +* [ImageMagick](v2/imagemagick/) * [CI/CD setup](ci_cd/) ## Running Functions Locally From bcecf8f697edc4f564636d92c2d6446897277828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A2=D0=B5=D1=82=D1=8F=D0=BD=D0=B0=20=D0=AF=D0=B3=D0=BE?= =?UTF-8?q?=D0=B4=D1=81=D1=8C=D0=BA=D0=B0?= <49729677+TetyanaYahodska@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:11:53 +0100 Subject: [PATCH 3/3] feat(tpu): add tpu vm create topology sample. (#9611) * Changed package, added information to CODEOWNERS * Added information to CODEOWNERS * Added timeout * Fixed parameters for test * Fixed DeleteTpuVm and naming * Added comment, created Util class * Fixed naming * Fixed whitespace * Split PR into smaller, deleted redundant code * Implemented tpu_vm_create_topology sample, created test * Changed zone * Fixed empty lines and tests, deleted cleanup method * Fixed tests * Fixed test * Fixed imports * Increased timeout to 10 sec * Fixed tests * Fixed tests * Deleted settings * Made ByteArrayOutputStream bout as local variable * Changed timeout to 10 sec --- .../java/tpu/CreateTpuWithTopologyFlag.java | 85 +++++++++++++++++++ tpu/src/main/java/tpu/GetQueuedResource.java | 1 - tpu/src/test/java/tpu/CreateTpuIT.java | 70 --------------- tpu/src/test/java/tpu/QueuedResourceIT.java | 23 +++-- tpu/src/test/java/tpu/TpuVmIT.java | 69 ++++++++++++--- 5 files changed, 155 insertions(+), 93 deletions(-) create mode 100644 tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java delete mode 100644 tpu/src/test/java/tpu/CreateTpuIT.java diff --git a/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java b/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java new file mode 100644 index 00000000000..86e7e28a007 --- /dev/null +++ b/tpu/src/main/java/tpu/CreateTpuWithTopologyFlag.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +//[START tpu_vm_create_topology] +import com.google.cloud.tpu.v2.AcceleratorConfig; +import com.google.cloud.tpu.v2.AcceleratorConfig.Type; +import com.google.cloud.tpu.v2.CreateNodeRequest; +import com.google.cloud.tpu.v2.Node; +import com.google.cloud.tpu.v2.TpuClient; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class CreateTpuWithTopologyFlag { + + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project you want to create a node. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which to create the TPU. + // For more information about supported TPU types for specific zones, + // see https://cloud.google.com/tpu/docs/regions-zones + String zone = "europe-west4-a"; + // The name for your TPU. + String nodeName = "YOUR_TPU_NAME"; + // The version of the Cloud TPU you want to create. + // Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7 + Type tpuVersion = AcceleratorConfig.Type.V2; + // Software version that specifies the version of the TPU runtime to install. + // For more information, see https://cloud.google.com/tpu/docs/runtimes + String tpuSoftwareVersion = "tpu-vm-tf-2.17.0-pod-pjrt"; + // The physical topology of your TPU slice. + // For more information about topology for each TPU version, + // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions. + String topology = "2x2"; + + createTpuWithTopologyFlag(projectId, zone, nodeName, tpuVersion, tpuSoftwareVersion, topology); + } + + // Creates a TPU VM with the specified name, zone, version and topology. + public static Node createTpuWithTopologyFlag(String projectId, String zone, String nodeName, + Type tpuVersion, String tpuSoftwareVersion, String topology) + throws IOException, ExecutionException, InterruptedException { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create()) { + String parent = String.format("projects/%s/locations/%s", projectId, zone); + Node tpuVm = + Node.newBuilder() + .setName(nodeName) + .setAcceleratorConfig(Node.newBuilder() + .getAcceleratorConfigBuilder() + .setType(tpuVersion) + .setTopology(topology) + .build()) + .setRuntimeVersion(tpuSoftwareVersion) + .build(); + + CreateNodeRequest request = + CreateNodeRequest.newBuilder() + .setParent(parent) + .setNodeId(nodeName) + .setNode(tpuVm) + .build(); + + return tpuClient.createNodeAsync(request).get(); + } + } +} +//[END tpu_vm_create_topology] \ No newline at end of file diff --git a/tpu/src/main/java/tpu/GetQueuedResource.java b/tpu/src/main/java/tpu/GetQueuedResource.java index 3a510e045fe..a17c2b41f79 100644 --- a/tpu/src/main/java/tpu/GetQueuedResource.java +++ b/tpu/src/main/java/tpu/GetQueuedResource.java @@ -17,7 +17,6 @@ package tpu; //[START tpu_queued_resources_get] - import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest; import com.google.cloud.tpu.v2alpha1.QueuedResource; import com.google.cloud.tpu.v2alpha1.TpuClient; diff --git a/tpu/src/test/java/tpu/CreateTpuIT.java b/tpu/src/test/java/tpu/CreateTpuIT.java deleted file mode 100644 index cdb11831364..00000000000 --- a/tpu/src/test/java/tpu/CreateTpuIT.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package tpu; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.api.gax.longrunning.OperationFuture; -import com.google.cloud.tpu.v2.CreateNodeRequest; -import com.google.cloud.tpu.v2.Node; -import com.google.cloud.tpu.v2.TpuClient; -import com.google.cloud.tpu.v2.TpuSettings; -import org.junit.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.MockedStatic; - -@RunWith(JUnit4.class) -@Timeout(value = 3) -public class CreateTpuIT { - private static final String PROJECT_ID = "project-id"; - private static final String ZONE = "asia-east1-c"; - private static final String NODE_NAME = "test-tpu"; - private static final String TPU_TYPE = "v2-8"; - private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1"; - - @Test - public void testCreateTpuVm() throws Exception { - try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { - Node mockNode = mock(Node.class); - TpuClient mockTpuClient = mock(TpuClient.class); - OperationFuture mockFuture = mock(OperationFuture.class); - - mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) - .thenReturn(mockTpuClient); - when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) - .thenReturn(mockFuture); - when(mockFuture.get()).thenReturn(mockNode); - - Node returnedNode = CreateTpuVm.createTpuVm( - PROJECT_ID, ZONE, NODE_NAME, - TPU_TYPE, TPU_SOFTWARE_VERSION); - - verify(mockTpuClient, times(1)) - .createNodeAsync(any(CreateNodeRequest.class)); - verify(mockFuture, times(1)).get(); - assertEquals(returnedNode, mockNode); - } - } -} diff --git a/tpu/src/test/java/tpu/QueuedResourceIT.java b/tpu/src/test/java/tpu/QueuedResourceIT.java index 906d0e270df..ec7d9512b92 100644 --- a/tpu/src/test/java/tpu/QueuedResourceIT.java +++ b/tpu/src/test/java/tpu/QueuedResourceIT.java @@ -35,15 +35,15 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 3) +@Timeout(value = 10) public class QueuedResourceIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "europe-west4-a"; @@ -52,10 +52,10 @@ public class QueuedResourceIT { private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1"; private static final String QUEUED_RESOURCE_NAME = "queued-resource"; private static final String NETWORK_NAME = "default"; - private ByteArrayOutputStream bout; + private static ByteArrayOutputStream bout; - @Before - public void setUp() { + @BeforeAll + public static void setUp() { bout = new ByteArrayOutputStream(); System.setOut(new PrintStream(bout)); } @@ -75,8 +75,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { QueuedResource returnedQueuedResource = CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork( - PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, - TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME); + PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME); verify(mockTpuClient, times(1)) .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); @@ -89,7 +89,6 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { public void testGetQueuedResource() throws IOException { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockClient = mock(TpuClient.class); - GetQueuedResource mockGetQueuedResource = mock(GetQueuedResource.class); QueuedResource mockQueuedResource = mock(QueuedResource.class); mockedTpuClient.when(TpuClient::create).thenReturn(mockClient); @@ -99,14 +98,14 @@ public void testGetQueuedResource() throws IOException { QueuedResource returnedQueuedResource = GetQueuedResource.getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); - verify(mockGetQueuedResource, times(1)) - .getQueuedResource(PROJECT_ID, ZONE, NODE_NAME); + verify(mockClient, times(1)) + .getQueuedResource(any(GetQueuedResourceRequest.class)); assertEquals(returnedQueuedResource, mockQueuedResource); } } @Test - public void testDeleteTpuVm() { + public void testDeleteForceQueuedResource() { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockTpuClient = mock(TpuClient.class); OperationFuture mockFuture = mock(OperationFuture.class); diff --git a/tpu/src/test/java/tpu/TpuVmIT.java b/tpu/src/test/java/tpu/TpuVmIT.java index 08dfeca8eb9..a640953c445 100644 --- a/tpu/src/test/java/tpu/TpuVmIT.java +++ b/tpu/src/test/java/tpu/TpuVmIT.java @@ -17,6 +17,7 @@ package tpu; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; @@ -25,6 +26,8 @@ import static org.mockito.Mockito.when; import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.tpu.v2.AcceleratorConfig; +import com.google.cloud.tpu.v2.CreateNodeRequest; import com.google.cloud.tpu.v2.DeleteNodeRequest; import com.google.cloud.tpu.v2.GetNodeRequest; import com.google.cloud.tpu.v2.Node; @@ -34,7 +37,6 @@ import java.io.IOException; import java.io.PrintStream; import java.util.concurrent.ExecutionException; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; @@ -42,17 +44,38 @@ import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 3) +@Timeout(value = 10) public class TpuVmIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "asia-east1-c"; private static final String NODE_NAME = "test-tpu"; - private static ByteArrayOutputStream bout; + private static final String TPU_TYPE = "v2-8"; + private static final AcceleratorConfig.Type ACCELERATOR_TYPE = AcceleratorConfig.Type.V2; + private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1"; + private static final String TOPOLOGY = "2x2"; - @BeforeAll - public static void setUp() { - bout = new ByteArrayOutputStream(); - System.setOut(new PrintStream(bout)); + @Test + public void testCreateTpuVm() throws Exception { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + Node mockNode = mock(Node.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class))) + .thenReturn(mockTpuClient); + when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(mockNode); + + Node returnedNode = CreateTpuVm.createTpuVm( + PROJECT_ID, ZONE, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION); + + verify(mockTpuClient, times(1)) + .createNodeAsync(any(CreateNodeRequest.class)); + verify(mockFuture, times(1)).get(); + assertEquals(returnedNode, mockNode); + } } @Test @@ -60,21 +83,22 @@ public void testGetTpuVm() throws IOException { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { Node mockNode = mock(Node.class); TpuClient mockClient = mock(TpuClient.class); - GetTpuVm mockGetTpuVm = mock(GetTpuVm.class); mockedTpuClient.when(TpuClient::create).thenReturn(mockClient); when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode); Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME); - verify(mockGetTpuVm, times(1)) - .getTpuVm(PROJECT_ID, ZONE, NODE_NAME); + verify(mockClient, times(1)) + .getNode(any(GetNodeRequest.class)); assertThat(returnedNode).isEqualTo(mockNode); } } @Test public void testDeleteTpuVm() throws IOException, ExecutionException, InterruptedException { + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + System.setOut(new PrintStream(bout)); try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { TpuClient mockTpuClient = mock(TpuClient.class); OperationFuture mockFuture = mock(OperationFuture.class); @@ -89,6 +113,31 @@ public void testDeleteTpuVm() throws IOException, ExecutionException, Interrupte assertThat(output).contains("TPU VM deleted"); verify(mockTpuClient, times(1)).deleteNodeAsync(any(DeleteNodeRequest.class)); + + bout.close(); + } + } + + @Test + public void testCreateTpuVmWithTopologyFlag() + throws IOException, ExecutionException, InterruptedException { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + Node mockNode = mock(Node.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(mockNode); + Node returnedNode = CreateTpuWithTopologyFlag.createTpuWithTopologyFlag( + PROJECT_ID, ZONE, NODE_NAME, ACCELERATOR_TYPE, + TPU_SOFTWARE_VERSION, TOPOLOGY); + + verify(mockTpuClient, times(1)) + .createNodeAsync(any(CreateNodeRequest.class)); + verify(mockFuture, times(1)).get(); + assertEquals(returnedNode, mockNode); } } } \ No newline at end of file