diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a16c4cc85a5..1d68a9f342c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -79,7 +79,7 @@ /speech @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers /talent @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers /texttospeech @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers -/translate @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers +/translate @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/cloud-ml-translate-dev /video @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers /vision @GoogleCloudPlatform/java-samples-reviewers @yoshi-approver @GoogleCloudPlatform/cloud-samples-reviewers diff --git a/security-command-center/snippets/src/main/java/management/api/CreateEventThreatDetectionCustomModule.java b/security-command-center/snippets/src/main/java/management/api/CreateEventThreatDetectionCustomModule.java new file mode 100644 index 00000000000..3e0fb3125b4 --- /dev/null +++ b/security-command-center/snippets/src/main/java/management/api/CreateEventThreatDetectionCustomModule.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package management.api; + +// [START securitycenter_create_event_threat_detection_custom_module] +import com.google.cloud.securitycentermanagement.v1.CreateEventThreatDetectionCustomModuleRequest; +import com.google.cloud.securitycentermanagement.v1.EventThreatDetectionCustomModule; +import com.google.cloud.securitycentermanagement.v1.EventThreatDetectionCustomModule.EnablementState; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient; +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class CreateEventThreatDetectionCustomModule { + + public static void main(String[] args) throws IOException { + // https://cloud.google.com/security-command-center/docs/reference/security-center-management/rest/v1/organizations.locations.eventThreatDetectionCustomModules/create + // TODO: Developer should replace project_id with a real project ID before running this code + String projectId = "project_id"; + + String customModuleDisplayName = "custom_module_display_name"; + + createEventThreatDetectionCustomModule(projectId, customModuleDisplayName); + } + + public static EventThreatDetectionCustomModule createEventThreatDetectionCustomModule( + String projectId, String customModuleDisplayName) throws IOException { + + // Initialize client that will be used to send requests. This client only needs + // to be created + // once, and can be reused for multiple requests. + try (SecurityCenterManagementClient client = SecurityCenterManagementClient.create()) { + + // define the metadata and other config parameters severity, description, + // recommendation and ips below + Map metadata = new HashMap<>(); + metadata.put("severity", Value.newBuilder().setStringValue("MEDIUM").build()); + metadata.put( + "description", Value.newBuilder().setStringValue("add your description here").build()); + metadata.put( + "recommendation", + Value.newBuilder().setStringValue("add your recommendation here").build()); + List ips = Arrays.asList(Value.newBuilder().setStringValue("0.0.0.0").build()); + + Value metadataVal = + Value.newBuilder() + .setStructValue(Struct.newBuilder().putAllFields(metadata).build()) + .build(); + Value ipsValue = + Value.newBuilder().setListValue(ListValue.newBuilder().addAllValues(ips).build()).build(); + + Struct configStruct = + Struct.newBuilder().putFields("metadata", metadataVal).putFields("ips", ipsValue).build(); + + // define the Event Threat Detection custom module configuration, update the EnablementState + // below + EventThreatDetectionCustomModule eventThreatDetectionCustomModule = + EventThreatDetectionCustomModule.newBuilder() + .setConfig(configStruct) + .setDisplayName(customModuleDisplayName) + .setEnablementState(EnablementState.ENABLED) + .setType("CONFIGURABLE_BAD_IP") + .build(); + + CreateEventThreatDetectionCustomModuleRequest request = + CreateEventThreatDetectionCustomModuleRequest.newBuilder() + .setParent(String.format("projects/%s/locations/global", projectId)) + .setEventThreatDetectionCustomModule(eventThreatDetectionCustomModule) + .build(); + + EventThreatDetectionCustomModule response = + client.createEventThreatDetectionCustomModule(request); + + return response; + } + } +} +// [END securitycenter_create_event_threat_detection_custom_module] diff --git a/security-command-center/snippets/src/main/java/management/api/DeleteEventThreatDetectionCustomModule.java b/security-command-center/snippets/src/main/java/management/api/DeleteEventThreatDetectionCustomModule.java new file mode 100644 index 00000000000..650b0d32a60 --- /dev/null +++ b/security-command-center/snippets/src/main/java/management/api/DeleteEventThreatDetectionCustomModule.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package management.api; + +// [START securitycenter_delete_event_threat_detection_custom_module] +import com.google.cloud.securitycentermanagement.v1.DeleteEventThreatDetectionCustomModuleRequest; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient; +import java.io.IOException; + +public class DeleteEventThreatDetectionCustomModule { + + public static void main(String[] args) throws IOException { + // https://cloud.google.com/security-command-center/docs/reference/security-center-management/rest/v1/organizations.locations.eventThreatDetectionCustomModules/delete + // TODO: Developer should replace project_id with a real project ID before running this code + String projectId = "project_id"; + + String customModuleId = "custom_module_id"; + + deleteEventThreatDetectionCustomModule(projectId, customModuleId); + } + + public static boolean deleteEventThreatDetectionCustomModule( + String projectId, String customModuleId) throws IOException { + + // Initialize client that will be used to send requests. This client only needs + // to be created + // once, and can be reused for multiple requests. + try (SecurityCenterManagementClient client = SecurityCenterManagementClient.create()) { + + String name = + String.format( + "projects/%s/locations/global/eventThreatDetectionCustomModules/%s", + projectId, customModuleId); + + DeleteEventThreatDetectionCustomModuleRequest request = + DeleteEventThreatDetectionCustomModuleRequest.newBuilder().setName(name).build(); + + client.deleteEventThreatDetectionCustomModule(request); + + return true; + } + } +} +// [END securitycenter_delete_event_threat_detection_custom_module] diff --git a/security-command-center/snippets/src/main/java/management/api/GetEventThreatDetectionCustomModule.java b/security-command-center/snippets/src/main/java/management/api/GetEventThreatDetectionCustomModule.java new file mode 100644 index 00000000000..1c9af776fba --- /dev/null +++ b/security-command-center/snippets/src/main/java/management/api/GetEventThreatDetectionCustomModule.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package management.api; + +// [START securitycenter_get_event_threat_detection_custom_module] +import com.google.cloud.securitycentermanagement.v1.EventThreatDetectionCustomModule; +import com.google.cloud.securitycentermanagement.v1.GetEventThreatDetectionCustomModuleRequest; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient; +import java.io.IOException; + +public class GetEventThreatDetectionCustomModule { + + public static void main(String[] args) throws IOException { + // https://cloud.google.com/security-command-center/docs/reference/security-center-management/rest/v1/organizations.locations.eventThreatDetectionCustomModules/get + // TODO: Developer should replace project_id with a real project ID before running this code + String projectId = "project_id"; + + String customModuleId = "custom_module_id"; + + getEventThreatDetectionCustomModule(projectId, customModuleId); + } + + public static EventThreatDetectionCustomModule getEventThreatDetectionCustomModule( + String projectId, String customModuleId) throws IOException { + + // Initialize client that will be used to send requests. This client only needs + // to be created + // once, and can be reused for multiple requests. + try (SecurityCenterManagementClient client = SecurityCenterManagementClient.create()) { + + String name = + String.format( + "projects/%s/locations/global/eventThreatDetectionCustomModules/%s", + projectId, customModuleId); + + GetEventThreatDetectionCustomModuleRequest request = + GetEventThreatDetectionCustomModuleRequest.newBuilder().setName(name).build(); + + EventThreatDetectionCustomModule response = + client.getEventThreatDetectionCustomModule(request); + + return response; + } + } +} +// [END securitycenter_get_event_threat_detection_custom_module] diff --git a/security-command-center/snippets/src/main/java/management/api/ListEventThreatDetectionCustomModules.java b/security-command-center/snippets/src/main/java/management/api/ListEventThreatDetectionCustomModules.java new file mode 100644 index 00000000000..ed2dfb01118 --- /dev/null +++ b/security-command-center/snippets/src/main/java/management/api/ListEventThreatDetectionCustomModules.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package management.api; + +// [START securitycenter_list_event_threat_detection_custom_module] +import com.google.cloud.securitycentermanagement.v1.ListEventThreatDetectionCustomModulesRequest; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient.ListEventThreatDetectionCustomModulesPagedResponse; +import java.io.IOException; + +public class ListEventThreatDetectionCustomModules { + + public static void main(String[] args) throws IOException { + // https://cloud.google.com/security-command-center/docs/reference/security-center-management/rest/v1/organizations.locations.eventThreatDetectionCustomModules/list + // TODO: Developer should replace project_id with a real project ID before running this code + String projectId = "project_id"; + + listEventThreatDetectionCustomModules(projectId); + } + + public static ListEventThreatDetectionCustomModulesPagedResponse + listEventThreatDetectionCustomModules(String projectId) throws IOException { + + // Initialize client that will be used to send requests. This client only needs + // to be created + // once, and can be reused for multiple requests. + try (SecurityCenterManagementClient client = SecurityCenterManagementClient.create()) { + + ListEventThreatDetectionCustomModulesRequest request = + ListEventThreatDetectionCustomModulesRequest.newBuilder() + .setParent(String.format("projects/%s/locations/global", projectId)) + .build(); + + ListEventThreatDetectionCustomModulesPagedResponse response = + client.listEventThreatDetectionCustomModules(request); + + return response; + } + } +} +// [END securitycenter_list_event_threat_detection_custom_module] diff --git a/security-command-center/snippets/src/test/java/management/api/EventThreatDetectionCustomModuleTest.java b/security-command-center/snippets/src/test/java/management/api/EventThreatDetectionCustomModuleTest.java new file mode 100644 index 00000000000..4f6330a572f --- /dev/null +++ b/security-command-center/snippets/src/test/java/management/api/EventThreatDetectionCustomModuleTest.java @@ -0,0 +1,168 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package management.api; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.securitycentermanagement.v1.EventThreatDetectionCustomModule; +import com.google.cloud.securitycentermanagement.v1.ListEventThreatDetectionCustomModulesRequest; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient; +import com.google.cloud.securitycentermanagement.v1.SecurityCenterManagementClient.ListEventThreatDetectionCustomModulesPagedResponse; +import com.google.cloud.testing.junit4.MultipleAttemptsRule; +import com.google.common.base.Strings; +import java.io.IOException; +import java.util.UUID; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.StreamSupport; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class EventThreatDetectionCustomModuleTest { + // TODO(Developer): Replace the below variable + private static final String PROJECT_ID = System.getenv("SCC_PROJECT_ID"); + private static final String CUSTOM_MODULE_DISPLAY_NAME = + "java_sample_etd_custom_module_test_" + UUID.randomUUID(); + private static final int MAX_ATTEMPT_COUNT = 3; + private static final int INITIAL_BACKOFF_MILLIS = 120000; // 2 minutes + + @Rule + public final MultipleAttemptsRule multipleAttemptsRule = + new MultipleAttemptsRule(MAX_ATTEMPT_COUNT, INITIAL_BACKOFF_MILLIS); + + // Check if the required environment variables are set. + public static void requireEnvVar(String envVarName) { + assertWithMessage(String.format("Missing environment variable '%s' ", envVarName)) + .that(System.getenv(envVarName)) + .isNotEmpty(); + } + + @BeforeClass + public static void setUp() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("SCC_PROJECT_ID"); + } + + @AfterClass + public static void cleanUp() throws IOException { + // Perform cleanup after running tests + cleanupExistingCustomModules(); + } + + // cleanupExistingCustomModules clean up all the existing custom module + private static void cleanupExistingCustomModules() throws IOException { + try (SecurityCenterManagementClient client = SecurityCenterManagementClient.create()) { + ListEventThreatDetectionCustomModulesRequest request = + ListEventThreatDetectionCustomModulesRequest.newBuilder() + .setParent(String.format("projects/%s/locations/global", PROJECT_ID)) + .build(); + ListEventThreatDetectionCustomModulesPagedResponse response = + client.listEventThreatDetectionCustomModules(request); + // Iterate over the response and delete custom module one by one which start with + // java_sample_custom_module + for (EventThreatDetectionCustomModule module : response.iterateAll()) { + try { + if (module.getDisplayName().startsWith("java_sample_etd_custom_module")) { + String customModuleId = extractCustomModuleId(module.getName()); + deleteCustomModule(PROJECT_ID, customModuleId); + } + } catch (Exception e) { + System.err.println("Failed to delete module: " + module.getDisplayName()); + e.printStackTrace(); + } + } + } catch (Exception e) { + System.err.println("Failed to process cleanupExistingCustomModules."); + e.printStackTrace(); + } + } + + // extractCustomModuleID extracts the custom module Id from the full name and below regex will + // parses suffix after the last slash character. + private static String extractCustomModuleId(String customModuleFullName) { + if (!Strings.isNullOrEmpty(customModuleFullName)) { + Pattern pattern = Pattern.compile(".*/([^/]+)$"); + Matcher matcher = pattern.matcher(customModuleFullName); + if (matcher.find()) { + return matcher.group(1); + } + } + return ""; + } + + // deleteCustomModule method is for deleting the custom module + private static void deleteCustomModule(String projectId, String customModuleId) + throws IOException { + if (!Strings.isNullOrEmpty(projectId) && !Strings.isNullOrEmpty(customModuleId)) { + DeleteEventThreatDetectionCustomModule.deleteEventThreatDetectionCustomModule( + projectId, customModuleId); + } + } + + @Test + public void testCreateEventThreatDetectionCustomModule() throws IOException { + EventThreatDetectionCustomModule response = + CreateEventThreatDetectionCustomModule.createEventThreatDetectionCustomModule( + PROJECT_ID, CUSTOM_MODULE_DISPLAY_NAME); + assertNotNull(response); + assertThat(response.getDisplayName()).isEqualTo(CUSTOM_MODULE_DISPLAY_NAME); + } + + @Test + public void testDeleteEventThreatDetectionCustomModule() throws IOException { + EventThreatDetectionCustomModule response = + CreateEventThreatDetectionCustomModule.createEventThreatDetectionCustomModule( + PROJECT_ID, CUSTOM_MODULE_DISPLAY_NAME); + String customModuleId = extractCustomModuleId(response.getName()); + assertTrue( + DeleteEventThreatDetectionCustomModule.deleteEventThreatDetectionCustomModule( + PROJECT_ID, customModuleId)); + } + + @Test + public void testListEventThreatDetectionCustomModules() throws IOException { + CreateEventThreatDetectionCustomModule.createEventThreatDetectionCustomModule( + PROJECT_ID, CUSTOM_MODULE_DISPLAY_NAME); + ListEventThreatDetectionCustomModulesPagedResponse response = + ListEventThreatDetectionCustomModules.listEventThreatDetectionCustomModules(PROJECT_ID); + assertTrue( + StreamSupport.stream(response.iterateAll().spliterator(), false) + .anyMatch(module -> CUSTOM_MODULE_DISPLAY_NAME.equals(module.getDisplayName()))); + } + + @Test + public void testGetEventThreatDetectionCustomModule() throws IOException { + EventThreatDetectionCustomModule response = + CreateEventThreatDetectionCustomModule.createEventThreatDetectionCustomModule( + PROJECT_ID, CUSTOM_MODULE_DISPLAY_NAME); + String customModuleId = extractCustomModuleId(response.getName()); + EventThreatDetectionCustomModule getCustomModuleResponse = + GetEventThreatDetectionCustomModule.getEventThreatDetectionCustomModule( + PROJECT_ID, customModuleId); + + assertThat(getCustomModuleResponse.getDisplayName()).isEqualTo(CUSTOM_MODULE_DISPLAY_NAME); + assertThat(extractCustomModuleId(getCustomModuleResponse.getName())).isEqualTo(customModuleId); + } +} diff --git a/tpu/src/main/java/tpu/CreateQueuedResource.java b/tpu/src/main/java/tpu/CreateQueuedResource.java index 24fc7802d52..421acff2d04 100644 --- a/tpu/src/main/java/tpu/CreateQueuedResource.java +++ b/tpu/src/main/java/tpu/CreateQueuedResource.java @@ -17,26 +17,25 @@ package tpu; //[START tpu_queued_resources_create] -import com.google.api.gax.retrying.RetrySettings; import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest; import com.google.cloud.tpu.v2alpha1.Node; import com.google.cloud.tpu.v2alpha1.QueuedResource; import com.google.cloud.tpu.v2alpha1.TpuClient; -import com.google.cloud.tpu.v2alpha1.TpuSettings; import java.io.IOException; import java.util.concurrent.ExecutionException; -import org.threeten.bp.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; public class CreateQueuedResource { public static void main(String[] args) - throws IOException, ExecutionException, InterruptedException { + throws IOException, ExecutionException, InterruptedException, TimeoutException { // TODO(developer): Replace these variables before running the sample. // Project ID or project number of the Google Cloud project you want to create a node. String projectId = "YOUR_PROJECT_ID"; // The zone in which to create the TPU. // For more information about supported TPU types for specific zones, // see https://cloud.google.com/tpu/docs/regions-zones - String zone = "europe-west4-a"; + String zone = "us-central1-f"; // The name for your TPU. String nodeName = "YOUR_NODE_ID"; // The accelerator type that specifies the version and size of the Cloud TPU you want to create. @@ -56,35 +55,19 @@ public static void main(String[] args) // Creates a Queued Resource public static QueuedResource createQueuedResource(String projectId, String zone, String queuedResourceId, String nodeName, String tpuType, String tpuSoftwareVersion) - throws IOException, ExecutionException, InterruptedException { - // With these settings the client library handles the Operation's polling mechanism - // and prevent CancellationException error - TpuSettings.Builder clientSettings = - TpuSettings.newBuilder(); - clientSettings - .createQueuedResourceSettings() - .setRetrySettings( - RetrySettings.newBuilder() - .setInitialRetryDelay(Duration.ofMillis(5000L)) - .setRetryDelayMultiplier(2.0) - .setInitialRpcTimeout(Duration.ZERO) - .setRpcTimeoutMultiplier(1.0) - .setMaxRetryDelay(Duration.ofMillis(45000L)) - .setTotalTimeout(Duration.ofHours(24L)) - .build()); + throws IOException, ExecutionException, InterruptedException, TimeoutException { + String resource = String.format("projects/%s/locations/%s/queuedResources/%s", + projectId, zone, queuedResourceId); // Initialize client that will be used to send requests. This client only needs to be created // once, and can be reused for multiple requests. - try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) { + try (TpuClient tpuClient = TpuClient.create()) { String parent = String.format("projects/%s/locations/%s", projectId, zone); Node node = Node.newBuilder() .setName(nodeName) .setAcceleratorType(tpuType) .setRuntimeVersion(tpuSoftwareVersion) - .setQueuedResource( - String.format( - "projects/%s/locations/%s/queuedResources/%s", - projectId, zone, queuedResourceId)) + .setQueuedResource(resource) .build(); QueuedResource queuedResource = @@ -99,9 +82,6 @@ public static QueuedResource createQueuedResource(String projectId, String zone, .setNodeId(nodeName) .build()) .build()) - // You can request a queued resource using a reservation by specifying it in code - //.setReservationName( - // "projects/YOUR_PROJECT_ID/locations/YOUR_ZONE/reservations/YOUR_RESERVATION_NAME") .build(); CreateQueuedResourceRequest request = @@ -111,11 +91,7 @@ public static QueuedResource createQueuedResource(String projectId, String zone, .setQueuedResource(queuedResource) .build(); - QueuedResource response = tpuClient.createQueuedResourceAsync(request).get(); - // You can wait until TPU Node is READY, - // and check its status using getTpuVm() from "tpu_vm_get" sample. - System.out.printf("Queued Resource created: %s\n", response.getName()); - return response; + return tpuClient.createQueuedResourceAsync(request).get(1, TimeUnit.MINUTES); } } } diff --git a/tpu/src/main/java/tpu/CreateSpotQueuedResource.java b/tpu/src/main/java/tpu/CreateSpotQueuedResource.java new file mode 100644 index 00000000000..b281d87abd9 --- /dev/null +++ b/tpu/src/main/java/tpu/CreateSpotQueuedResource.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +// [START tpu_queued_resources_create_spot] +import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest; +import com.google.cloud.tpu.v2alpha1.Node; +import com.google.cloud.tpu.v2alpha1.QueuedResource; +import com.google.cloud.tpu.v2alpha1.SchedulingConfig; +import com.google.cloud.tpu.v2alpha1.TpuClient; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class CreateSpotQueuedResource { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project you want to create a node. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which to create the TPU. + // For more information about supported TPU types for specific zones, + // see https://cloud.google.com/tpu/docs/regions-zones + String zone = "us-central1-f"; + // The name for your TPU. + String nodeName = "YOUR_TPU_NAME"; + // The accelerator type that specifies the version and size of the Cloud TPU you want to create. + // For more information about supported accelerator types for each TPU version, + // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions. + String tpuType = "v2-8"; + // Software version that specifies the version of the TPU runtime to install. + // For more information see https://cloud.google.com/tpu/docs/runtimes + String tpuSoftwareVersion = "tpu-vm-tf-2.14.1"; + // The name for your Queued Resource. + String queuedResourceId = "QUEUED_RESOURCE_ID"; + + createQueuedResource( + projectId, zone, queuedResourceId, nodeName, tpuType, tpuSoftwareVersion); + } + + // Creates a Queued Resource with --preemptible flag. + public static QueuedResource createQueuedResource( + String projectId, String zone, String queuedResourceId, + String nodeName, String tpuType, String tpuSoftwareVersion) + throws IOException, ExecutionException, InterruptedException { + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create()) { + String parent = String.format("projects/%s/locations/%s", projectId, zone); + String resourceName = String.format("projects/%s/locations/%s/queuedResources/%s", + projectId, zone, queuedResourceId); + SchedulingConfig schedulingConfig = SchedulingConfig.newBuilder() + .setPreemptible(true) + .build(); + + Node node = + Node.newBuilder() + .setName(nodeName) + .setAcceleratorType(tpuType) + .setRuntimeVersion(tpuSoftwareVersion) + .setSchedulingConfig(schedulingConfig) + .setQueuedResource(resourceName) + .build(); + + QueuedResource queuedResource = + QueuedResource.newBuilder() + .setName(queuedResourceId) + .setTpu( + QueuedResource.Tpu.newBuilder() + .addNodeSpec( + QueuedResource.Tpu.NodeSpec.newBuilder() + .setParent(parent) + .setNode(node) + .setNodeId(nodeName) + .build()) + .build()) + .build(); + + CreateQueuedResourceRequest request = + CreateQueuedResourceRequest.newBuilder() + .setParent(parent) + .setQueuedResourceId(queuedResourceId) + .setQueuedResource(queuedResource) + .build(); + + return tpuClient.createQueuedResourceAsync(request).get(); + } + } +} +// [END tpu_queued_resources_create_spot] diff --git a/tpu/src/main/java/tpu/DeleteQueuedResource.java b/tpu/src/main/java/tpu/DeleteQueuedResource.java index 0f592ff165f..9f0e123a43e 100644 --- a/tpu/src/main/java/tpu/DeleteQueuedResource.java +++ b/tpu/src/main/java/tpu/DeleteQueuedResource.java @@ -17,25 +17,19 @@ package tpu; //[START tpu_queued_resources_delete] -import com.google.api.gax.retrying.RetrySettings; -import com.google.api.gax.rpc.UnknownException; import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest; -import com.google.cloud.tpu.v2alpha1.GetQueuedResourceRequest; -import com.google.cloud.tpu.v2alpha1.QueuedResource; import com.google.cloud.tpu.v2alpha1.TpuClient; -import com.google.cloud.tpu.v2alpha1.TpuSettings; import java.io.IOException; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import org.threeten.bp.Duration; public class DeleteQueuedResource { - public static void main(String[] args) { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { // TODO(developer): Replace these variables before running the sample. // Project ID or project number of the Google Cloud project. String projectId = "YOUR_PROJECT_ID"; // The zone in which the TPU was created. - String zone = "europe-west4-a"; + String zone = "us-central1-f"; // The name for your Queued Resource. String queuedResourceId = "QUEUED_RESOURCE_ID"; @@ -43,46 +37,22 @@ public static void main(String[] args) { } // Deletes a Queued Resource asynchronously. - public static void deleteQueuedResource(String projectId, String zone, String queuedResourceId) { + public static void deleteQueuedResource(String projectId, String zone, String queuedResourceId) + throws ExecutionException, InterruptedException, IOException { String name = String.format("projects/%s/locations/%s/queuedResources/%s", projectId, zone, queuedResourceId); - // With these settings the client library handles the Operation's polling mechanism - // and prevent CancellationException error - TpuSettings.Builder clientSettings = - TpuSettings.newBuilder(); - clientSettings - .deleteQueuedResourceSettings() - .setRetrySettings( - RetrySettings.newBuilder() - .setInitialRetryDelay(Duration.ofMillis(5000L)) - .setRetryDelayMultiplier(2.0) - .setInitialRpcTimeout(Duration.ZERO) - .setRpcTimeoutMultiplier(1.0) - .setMaxRetryDelay(Duration.ofMillis(45000L)) - .setTotalTimeout(Duration.ofHours(24L)) - .build()); // Initialize client that will be used to send requests. This client only needs to be created // once, and can be reused for multiple requests. - try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) { - // Retrive node name - GetQueuedResourceRequest getRequest = - GetQueuedResourceRequest.newBuilder().setName(name).build(); - QueuedResource queuedResource = tpuClient.getQueuedResource(getRequest); - String nodeName = queuedResource.getTpu().getNodeSpec(0).getNode().getName(); + try (TpuClient tpuClient = TpuClient.create()) { // Before deleting the queued resource it is required to delete the TPU VM. - DeleteTpuVm.deleteTpuVm(projectId, zone, nodeName); - // Wait until TpuVm is deleted - TimeUnit.MINUTES.sleep(3); + // For more information about deleting TPU + // see https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm DeleteQueuedResourceRequest request = - DeleteQueuedResourceRequest.newBuilder().setName(name).build(); + DeleteQueuedResourceRequest.newBuilder().setName(name).build(); tpuClient.deleteQueuedResourceAsync(request).get(); - - } catch (UnknownException | InterruptedException | ExecutionException | IOException e) { - System.out.println(e.getMessage()); } - System.out.printf("Deleted Queued Resource: %s\n", name); } } //[END tpu_queued_resources_delete] diff --git a/tpu/src/main/java/tpu/GetQueuedResource.java b/tpu/src/main/java/tpu/GetQueuedResource.java index 80c6762327c..588987a25f0 100644 --- a/tpu/src/main/java/tpu/GetQueuedResource.java +++ b/tpu/src/main/java/tpu/GetQueuedResource.java @@ -28,7 +28,7 @@ public static void main(String[] args) throws IOException { // Project ID or project number of the Google Cloud project. String projectId = "YOUR_PROJECT_ID"; // The zone in which the TPU was created. - String zone = "europe-west4-a"; + String zone = "us-central1-f"; // The name for your Queued Resource. String queuedResourceId = "QUEUED_RESOURCE_ID"; diff --git a/tpu/src/test/java/tpu/QueuedResourceIT.java b/tpu/src/test/java/tpu/QueuedResourceIT.java index 0981e6078fb..7d164105f25 100644 --- a/tpu/src/test/java/tpu/QueuedResourceIT.java +++ b/tpu/src/test/java/tpu/QueuedResourceIT.java @@ -17,6 +17,7 @@ package tpu; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; @@ -33,6 +34,7 @@ import com.google.cloud.tpu.v2alpha1.TpuSettings; import java.io.IOException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; @@ -40,7 +42,7 @@ import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 10) +@Timeout(value = 2, unit = TimeUnit.MINUTES) public class QueuedResourceIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "europe-west4-a"; @@ -50,6 +52,30 @@ public class QueuedResourceIT { private static final String QUEUED_RESOURCE_NAME = "queued-resource"; private static final String NETWORK_NAME = "default"; + @Test + public void testCreateQueuedResource() throws Exception { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + QueuedResource mockQueuedResource = mock(QueuedResource.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get(anyLong(), any(TimeUnit.class))).thenReturn(mockQueuedResource); + + QueuedResource returnedQueuedResource = + CreateQueuedResource.createQueuedResource( + PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION); + + verify(mockTpuClient, times(1)) + .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); + verify(mockFuture, times(1)).get(anyLong(), any(TimeUnit.class)); + assertEquals(returnedQueuedResource, mockQueuedResource); + } + } + @Test public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { @@ -113,6 +139,25 @@ public void testDeleteForceQueuedResource() } } + @Test + public void testDeleteQueuedResource() + throws IOException, ExecutionException, InterruptedException { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(null); + + DeleteQueuedResource.deleteQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME); + + verify(mockTpuClient, times(1)) + .deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)); + } + } + @Test public void testCreateQueuedResourceWithStartupScript() throws Exception { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { @@ -137,6 +182,32 @@ public void testCreateQueuedResourceWithStartupScript() throws Exception { } } + @Test + public void testCreateSpotQueuedResource() throws Exception { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + QueuedResource mockQueuedResource = QueuedResource.newBuilder() + .setName("QueuedResourceName") + .build(); + TpuClient mockedClientInstance = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockedClientInstance); + when(mockedClientInstance.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(mockQueuedResource); + + QueuedResource returnedQueuedResource = + CreateSpotQueuedResource.createQueuedResource( + PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION); + + verify(mockedClientInstance, times(1)) + .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); + verify(mockFuture, times(1)).get(); + assertEquals(returnedQueuedResource.getName(), mockQueuedResource.getName()); + } + } + @Test public void testCreateTimeBoundQueuedResource() throws Exception { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) {