Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
feat: add kill_infrastructure implementation for vertex ai worker (#…
Browse files Browse the repository at this point in the history
…213)

* feat: add kill_infrastructure implementation for vertex

* add unit tests

* fix tests

* maybe assert like this

* Updates changelog for upcoming release

* add default Nones to optional props

---------

Co-authored-by: Alexander Streed <alex.s@prefect.io>
  • Loading branch information
parkedwards and desertaxle authored Sep 22, 2023
1 parent 6327600 commit a21aca4
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 11 deletions.
18 changes: 15 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

### Changed
- Use flow run name for name of created custom jobs - [#208](https://github.com/PrefectHQ/prefect-gcp/pull/208)

### Deprecated

Expand All @@ -20,13 +19,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Security

## 0.4.7

Released September 22nd, 2023.

### Added

- Vertex AI `CustomJob` worker - [#211](https://github.com/PrefectHQ/prefect-gcp/pull/211)
- Add `kill_infrastructure` method to Vertex AI worker - [#213](https://github.com/PrefectHQ/prefect-gcp/pull/213)

### Changed

- Use flow run name for name of created custom jobs - [#208](https://github.com/PrefectHQ/prefect-gcp/pull/208)

## 0.4.6

Not yet released
Released September 5th, 2023.

### Changed

- Vertex AI CustomJob sets labels specified by Prefect Agent when Deployment triggered on infrastructure.
- Persist Labels to Vertex AI Custom Job - [#198](https://github.com/PrefectHQ/prefect-gcp/pull/208)

## 0.4.5

Expand Down
65 changes: 58 additions & 7 deletions prefect_gcp/workers/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from uuid import uuid4

import anyio
from prefect.exceptions import InfrastructureNotFound
from prefect.logging.loggers import PrefectLogAdapter
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.pydantic import JsonPatch
Expand All @@ -52,6 +53,7 @@
Scheduling,
WorkerPoolSpec,
)
from google.cloud.aiplatform_v1.types.job_service import CancelCustomJobRequest
from google.cloud.aiplatform_v1.types.job_state import JobState
from google.cloud.aiplatform_v1.types.machine_resources import DiskSpec, MachineSpec
from google.protobuf.duration_pb2 import Duration
Expand Down Expand Up @@ -111,15 +113,17 @@ class VertexAIWorkerVariables(BaseVariables):
"The type of accelerator to attach to the machine. "
"See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
),
default="NVIDIA_TESLA_K80",
example="NVIDIA_TESLA_K80",
default=None,
)
accelerator_count: int = Field(
accelerator_count: Optional[int] = Field(
title="Accelerator Count",
description=(
"The number of accelerators to attach to the machine. "
"See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
),
default=0,
example=1,
default=None,
)
boot_disk_type: str = Field(
title="Boot Disk Type",
Expand Down Expand Up @@ -188,8 +192,6 @@ def _get_base_job_spec() -> Dict[str, Any]:
},
"machine_spec": {
"machine_type": "n1-standard-4",
"accelerator_type": "NVIDIA_TESLA_K80",
"accelerator_count": "1",
},
"disk_spec": {
"boot_disk_type": "pd-ssd",
Expand Down Expand Up @@ -410,7 +412,7 @@ async def run(
)

if task_status:
task_status.started(job_name)
task_status.started(job_run.name)

final_job_run = await self._watch_job_run(
job_name=job_name,
Expand All @@ -433,7 +435,10 @@ async def run(
)

error_msg = final_job_run.error.message
if error_msg:

# Vertex will include an error message upon valid
# flow cancellations, so we'll avoid raising an error in that case
if error_msg and "CANCELED" not in error_msg:
raise RuntimeError(error_msg)

status_code = 0 if final_job_run.state == JobState.JOB_STATE_SUCCEEDED else 1
Expand Down Expand Up @@ -590,3 +595,49 @@ def _get_compatible_labels(
regex_pattern=_DISALLOWED_GCP_LABEL_CHARACTERS,
)
return compatible_labels

async def kill_infrastructure(
self,
infrastructure_pid: str,
configuration: VertexAIWorkerJobConfiguration,
grace_seconds: int = 30,
):
"""
Stops a job running in Vertex AI upon flow cancellation,
based on the provided infrastructure PID + run configuration.
"""
if grace_seconds != 30:
self._logger.warning(
f"Kill grace period of {grace_seconds}s requested, but GCP does not "
"support dynamic grace period configuration. See here for more info: "
"https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs/cancel" # noqa
)

client_options = ClientOptions(
api_endpoint=f"{configuration.region}-aiplatform.googleapis.com"
)
with configuration.credentials.get_job_service_client(
client_options=client_options
) as job_service_client:
await run_sync_in_worker_thread(
self._stop_job,
client=job_service_client,
vertex_job_name=infrastructure_pid,
)

def _stop_job(self, client: "JobServiceClient", vertex_job_name: str):
"""
Calls the `cancel_custom_job` method on the Vertex AI Job Service Client.
"""
cancel_custom_job_request = CancelCustomJobRequest(name=vertex_job_name)
try:
client.cancel_custom_job(
request=cancel_custom_job_request,
)
except Exception as exc:
if "does not exist" in str(exc):
raise InfrastructureNotFound(
f"Cannot stop Vertex AI job; the job name {vertex_job_name!r} "
"could not be found."
) from exc
raise
59 changes: 58 additions & 1 deletion tests/test_vertex_worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock

import anyio
import pydantic
import pytest
from google.cloud.aiplatform_v1.types.job_service import CancelCustomJobRequest
from google.cloud.aiplatform_v1.types.job_state import JobState
from prefect.client.schemas import FlowRun
from prefect.exceptions import InfrastructureNotFound

from prefect_gcp.workers.vertex import (
VertexAIWorker,
Expand Down Expand Up @@ -98,7 +102,6 @@ async def test_validate_incomplete_worker_pool_spec(self, gcp_credentials):
"Job is missing required attributes at the following paths: "
"/worker_pool_specs/0/container_spec/image_uri, "
"/worker_pool_specs/0/disk_spec, "
"/worker_pool_specs/0/machine_spec/accelerator_count, "
"/worker_pool_specs/0/machine_spec/machine_type"
),
"type": "value_error",
Expand Down Expand Up @@ -205,3 +208,57 @@ async def test_cancelled_worker_run(self, flow_run, job_config):
assert result == VertexAIWorkerResult(
status_code=1, identifier=job_display_name
)

async def test_kill_infrastructure(self, flow_run, job_config):
mock = job_config.credentials.job_service_client.create_custom_job
# the CancelCustomJobRequest class seems to reject a MagicMock value
# so here, we'll use a SimpleNamespace as the mocked return values
mock.return_value = SimpleNamespace(
name="foobar", state=JobState.JOB_STATE_PENDING
)

async with VertexAIWorker("test-pool") as worker:
with anyio.fail_after(10):
async with anyio.create_task_group() as tg:
result = await tg.start(worker.run, flow_run, job_config)
await worker.kill_infrastructure(result, job_config)

mock = job_config.credentials.job_service_client.cancel_custom_job
assert mock.call_count == 1
mock.assert_called_with(request=CancelCustomJobRequest(name="foobar"))

async def test_kill_infrastructure_no_grace_seconds(
self, flow_run, job_config, caplog
):
mock = job_config.credentials.job_service_client.create_custom_job
mock.return_value = SimpleNamespace(
name="bazzbar", state=JobState.JOB_STATE_PENDING
)
async with VertexAIWorker("test-pool") as worker:

input_grace_period = 32

with anyio.fail_after(10):
async with anyio.create_task_group() as tg:
identifier = await tg.start(worker.run, flow_run, job_config)
await worker.kill_infrastructure(
identifier, job_config, input_grace_period
)
for record in caplog.records:
if (
f"Kill grace period of {input_grace_period}s "
"requested, but GCP does not"
) in record.msg:
break
else:
raise AssertionError("Expected message not found.")

async def test_kill_infrastructure_not_found(self, job_config):
async with VertexAIWorker("test-pool") as worker:
job_config.credentials.job_service_client.cancel_custom_job.side_effect = (
Exception("does not exist")
)
with pytest.raises(
InfrastructureNotFound, match="Cannot stop Vertex AI job"
):
await worker.kill_infrastructure("foobarbazz", job_config)

0 comments on commit a21aca4

Please sign in to comment.