From 3c06654ec8be4f9e9bf5c304814f163b7727d28e Mon Sep 17 00:00:00 2001 From: nate nowack Date: Mon, 2 Dec 2024 18:57:11 -0600 Subject: [PATCH] migrate `prefect_aws.secrets_manager` off `sync_compatible` (#16169) --- .../prefect_aws/secrets_manager.py | 164 +++++++++++-- .../prefect-aws/tests/test_secrets_manager.py | 220 ++++++++++++------ 2 files changed, 291 insertions(+), 93 deletions(-) diff --git a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py index 82a695d02141..c13404b3246f 100644 --- a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py +++ b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py @@ -6,9 +6,10 @@ from pydantic import Field from prefect import task +from prefect._internal.compatibility.async_dispatch import async_dispatch from prefect.blocks.abstract import SecretBlock from prefect.logging import get_run_logger -from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_aws import AwsCredentials @@ -365,22 +366,21 @@ class AwsSecret(SecretBlock): secret_name: The name of the secret. """ - _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # type: ignore _block_type_name = "AWS Secret" - _documentation_url = "https://docs.prefect.io/integrations/prefect-aws" # noqa + _documentation_url = "https://docs.prefect.io/integrations/prefect-aws" # type: ignore aws_credentials: AwsCredentials secret_name: str = Field(default=..., description="The name of the secret.") - @sync_compatible - async def read_secret( + async def aread_secret( self, version_id: Optional[str] = None, version_stage: Optional[str] = None, - **read_kwargs: Dict[str, Any], + **read_kwargs: Any, ) -> bytes: """ - Reads the secret from the secret storage service. + Asynchronously reads the secret from the secret storage service. Args: version_id: The version of the secret to read. If not provided, the latest @@ -397,7 +397,7 @@ async def read_secret( Reads a secret. ```python secrets_manager = SecretsManager.load("MY_BLOCK") - secrets_manager.read_secret() + await secrets_manager.aread_secret() ``` """ client = self.aws_credentials.get_secrets_manager_client() @@ -416,12 +416,53 @@ async def read_secret( self.logger.info(f"The secret {arn!r} data was successfully read.") return secret - @sync_compatible - async def write_secret( + @async_dispatch(aread_secret) + def read_secret( + self, + version_id: Optional[str] = None, + version_stage: Optional[str] = None, + **read_kwargs: Any, + ) -> bytes: + """ + Reads the secret from the secret storage service. + + Args: + version_id: The version of the secret to read. If not provided, the latest + version will be read. + version_stage: The version stage of the secret to read. If not provided, + the latest version will be read. + read_kwargs: Additional keyword arguments to pass to the + `get_secret_value` method of the boto3 client. + + Returns: + The secret data. + + Examples: + Reads a secret. + ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.read_secret() + ``` + """ + client = self.aws_credentials.get_secrets_manager_client() + if version_id is not None: + read_kwargs["VersionId"] = version_id + if version_stage is not None: + read_kwargs["VersionStage"] = version_stage + response = client.get_secret_value(SecretId=self.secret_name, **read_kwargs) + if "SecretBinary" in response: + secret = response["SecretBinary"] + elif "SecretString" in response: + secret = response["SecretString"] + arn = response["ARN"] + self.logger.info(f"The secret {arn!r} data was successfully read.") + return secret + + async def awrite_secret( self, secret_data: bytes, **put_or_create_secret_kwargs: Dict[str, Any] ) -> str: """ - Writes the secret to the secret storage service as a SecretBinary; + Asynchronously writes the secret to the secret storage service as a SecretBinary; if it doesn't exist, it will be created. Args: @@ -436,7 +477,7 @@ async def write_secret( Write some secret data. ```python secrets_manager = SecretsManager.load("MY_BLOCK") - secrets_manager.write_secret(b"my_secret_data") + await secrets_manager.awrite_secret(b"my_secret_data") ``` """ client = self.aws_credentials.get_secrets_manager_client() @@ -461,15 +502,57 @@ async def write_secret( self.logger.info(f"The secret data was written successfully to {arn!r}.") return arn - @sync_compatible - async def delete_secret( + @async_dispatch(awrite_secret) + def write_secret( + self, secret_data: bytes, **put_or_create_secret_kwargs: Dict[str, Any] + ) -> str: + """ + Writes the secret to the secret storage service as a SecretBinary; + if it doesn't exist, it will be created. + + Args: + secret_data: The secret data to write. + **put_or_create_secret_kwargs: Additional keyword arguments to pass to + put_secret_value or create_secret method of the boto3 client. + + Returns: + The path that the secret was written to. + + Examples: + Write some secret data. + ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.write_secret(b"my_secret_data") + ``` + """ + client = self.aws_credentials.get_secrets_manager_client() + try: + response = client.put_secret_value( + SecretId=self.secret_name, + SecretBinary=secret_data, + **put_or_create_secret_kwargs, + ) + except client.exceptions.ResourceNotFoundException: + self.logger.info( + f"The secret {self.secret_name!r} does not exist yet, creating it now." + ) + response = client.create_secret( + Name=self.secret_name, + SecretBinary=secret_data, + **put_or_create_secret_kwargs, + ) + arn = response["ARN"] + self.logger.info(f"The secret data was written successfully to {arn!r}.") + return arn + + async def adelete_secret( self, recovery_window_in_days: int = 30, force_delete_without_recovery: bool = False, **delete_kwargs: Dict[str, Any], ) -> str: """ - Deletes the secret from the secret storage service. + Asynchronously deletes the secret from the secret storage service. Args: recovery_window_in_days: The number of days to wait before permanently @@ -486,7 +569,7 @@ async def delete_secret( Deletes the secret with a recovery window of 15 days. ```python secrets_manager = SecretsManager.load("MY_BLOCK") - secrets_manager.delete_secret(recovery_window_in_days=15) + await secrets_manager.adelete_secret(recovery_window_in_days=15) ``` """ if force_delete_without_recovery and recovery_window_in_days: @@ -510,3 +593,52 @@ async def delete_secret( arn = response["ARN"] self.logger.info(f"The secret {arn} was deleted successfully.") return arn + + @async_dispatch(adelete_secret) + def delete_secret( + self, + recovery_window_in_days: int = 30, + force_delete_without_recovery: bool = False, + **delete_kwargs: Dict[str, Any], + ) -> str: + """ + Deletes the secret from the secret storage service. + + Args: + recovery_window_in_days: The number of days to wait before permanently + deleting the secret. Must be between 7 and 30 days. + force_delete_without_recovery: If True, the secret will be deleted + immediately without a recovery window. + **delete_kwargs: Additional keyword arguments to pass to the + delete_secret method of the boto3 client. + + Returns: + The path that the secret was deleted from. + + Examples: + Deletes the secret with a recovery window of 15 days. + ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.delete_secret(recovery_window_in_days=15) + ``` + """ + if force_delete_without_recovery and recovery_window_in_days: + raise ValueError( + "Cannot specify recovery window and force delete without recovery." + ) + elif not (7 <= recovery_window_in_days <= 30): + raise ValueError( + "Recovery window must be between 7 and 30 days, got " + f"{recovery_window_in_days}." + ) + + client = self.aws_credentials.get_secrets_manager_client() + response = client.delete_secret( + SecretId=self.secret_name, + RecoveryWindowInDays=recovery_window_in_days, + ForceDeleteWithoutRecovery=force_delete_without_recovery, + **delete_kwargs, + ) + arn = response["ARN"] + self.logger.info(f"The secret {arn} was deleted successfully.") + return arn diff --git a/src/integrations/prefect-aws/tests/test_secrets_manager.py b/src/integrations/prefect-aws/tests/test_secrets_manager.py index b1a479125eb2..9c1f2dc41a6d 100644 --- a/src/integrations/prefect-aws/tests/test_secrets_manager.py +++ b/src/integrations/prefect-aws/tests/test_secrets_manager.py @@ -8,8 +8,6 @@ AwsSecret, create_secret, delete_secret, - read_secret, - update_secret, ) from prefect import flow @@ -57,47 +55,159 @@ def secret_under_test(secretsmanager_client, request): ) -async def test_read_secret(secret_under_test, aws_credentials): - expected_value = secret_under_test.pop("expected_value") +class TestAwsSecretSync: + """Test synchronous AwsSecret methods""" - @flow - async def test_flow(): - return await read_secret( - aws_credentials=aws_credentials, - **secret_under_test, + async def test_read_secret(self, secret_under_test, aws_credentials): + expected_value = secret_under_test.pop("expected_value") + secret_name = secret_under_test.pop( + "secret_name" + ) # Remove secret_name from kwargs + + @flow + async def test_flow(): + secret = AwsSecret( + aws_credentials=aws_credentials, + secret_name=secret_name, # Use for AwsSecret initialization + ) + # Pass remaining kwargs (version_id, version_stage) if present + return await secret.read_secret(**secret_under_test) + + assert (await test_flow()) == expected_value + + async def test_write_secret(self, aws_credentials, secretsmanager_client): + secret = AwsSecret(aws_credentials=aws_credentials, secret_name="my-test") + secret_value = b"test-secret-value" + + @flow + async def test_flow(): + return await secret.write_secret(secret_value) + + arn = await test_flow() + assert arn.startswith("arn:aws:secretsmanager") + + # Verify the secret was written correctly + response = secretsmanager_client.get_secret_value(SecretId="my-test") + assert response["SecretBinary"] == secret_value + + async def test_delete_secret(self, aws_credentials, secretsmanager_client): + # First create a secret to delete + secret = AwsSecret(aws_credentials=aws_credentials, secret_name="test-delete") + secret_value = b"delete-me" + + @flow + async def setup_flow(): + return await secret.write_secret(secret_value) + + arn = await setup_flow() + + # Now test deletion + @flow + async def test_flow(): + return await secret.delete_secret( + recovery_window_in_days=7, force_delete_without_recovery=False + ) + + deleted_arn = await test_flow() + assert deleted_arn == arn + + # Verify the secret is scheduled for deletion + with pytest.raises(secretsmanager_client.exceptions.InvalidRequestException): + secretsmanager_client.get_secret_value(SecretId="test-delete") + + async def test_delete_secret_validation(self, aws_credentials): + secret = AwsSecret( + aws_credentials=aws_credentials, secret_name="test-validation" ) - assert (await test_flow()) == expected_value + with pytest.raises(ValueError, match="Cannot specify recovery window"): + await secret.delete_secret( + force_delete_without_recovery=True, recovery_window_in_days=10 + ) + with pytest.raises( + ValueError, match="Recovery window must be between 7 and 30 days" + ): + await secret.delete_secret(recovery_window_in_days=42) -async def test_update_secret(secret_under_test, aws_credentials, secretsmanager_client): - current_secret_value = secret_under_test["expected_value"] - new_secret_value = ( - current_secret_value + "2" - if isinstance(current_secret_value, str) - else current_secret_value + b"2" - ) - @flow - async def test_flow(): - return await update_secret( - aws_credentials=aws_credentials, - secret_name=secret_under_test["secret_name"], - secret_value=new_secret_value, +class TestAwsSecretAsync: + """Test asynchronous AwsSecret methods""" + + async def test_read_secret(self, secret_under_test, aws_credentials): + expected_value = secret_under_test.pop("expected_value") + secret_name = secret_under_test.pop( + "secret_name" + ) # Remove secret_name from kwargs + + @flow + async def test_flow(): + secret = AwsSecret( + aws_credentials=aws_credentials, + secret_name=secret_name, # Use for AwsSecret initialization + ) + # Pass remaining kwargs (version_id, version_stage) if present + return await secret.aread_secret(**secret_under_test) + + assert (await test_flow()) == expected_value + + async def test_write_secret(self, aws_credentials, secretsmanager_client): + secret = AwsSecret(aws_credentials=aws_credentials, secret_name="my-test") + secret_value = b"test-secret-value" + + @flow + async def test_flow(): + return await secret.awrite_secret(secret_value) + + arn = await test_flow() + assert arn.startswith("arn:aws:secretsmanager") + + # Verify the secret was written correctly + response = secretsmanager_client.get_secret_value(SecretId="my-test") + assert response["SecretBinary"] == secret_value + + async def test_delete_secret(self, aws_credentials, secretsmanager_client): + # First create a secret to delete + secret = AwsSecret(aws_credentials=aws_credentials, secret_name="test-delete") + secret_value = b"delete-me" + + @flow + async def setup_flow(): + return await secret.awrite_secret(secret_value) + + arn = await setup_flow() + + # Now test deletion + @flow + async def test_flow(): + return await secret.adelete_secret( + recovery_window_in_days=7, force_delete_without_recovery=False + ) + + deleted_arn = await test_flow() + assert deleted_arn == arn + + # Verify the secret is scheduled for deletion + with pytest.raises(secretsmanager_client.exceptions.InvalidRequestException): + secretsmanager_client.get_secret_value(SecretId="test-delete") + + async def test_delete_secret_validation(self, aws_credentials): + secret = AwsSecret( + aws_credentials=aws_credentials, secret_name="test-validation" ) - flow_state = await test_flow() - assert flow_state.get("Name") == secret_under_test["secret_name"] + with pytest.raises(ValueError, match="Cannot specify recovery window"): + await secret.adelete_secret( + force_delete_without_recovery=True, recovery_window_in_days=10 + ) - updated_secret = secretsmanager_client.get_secret_value( - SecretId=secret_under_test["secret_name"] - ) - assert ( - updated_secret.get("SecretString") == new_secret_value - or updated_secret.get("SecretBinary") == new_secret_value - ) + with pytest.raises( + ValueError, match="Recovery window must be between 7 and 30 days" + ): + await secret.adelete_secret(recovery_window_in_days=42) +# Keep existing task-based tests @pytest.mark.parametrize( ["secret_name", "secret_value"], [["string_secret", "42"], ["binary_secret", b"42"]] ) @@ -134,7 +244,7 @@ async def test_flow(): [29, True], ], ) -async def test_delete_secret( +async def test_delete_secret_task( aws_credentials, secret_under_test, recovery_window_in_days, @@ -163,47 +273,3 @@ async def test_flow(): ) else: assert deletion_date.date() == pendulum.now("UTC").date() - - -class TestAwsSecret: - @pytest.fixture - def aws_secret(self, aws_credentials, secretsmanager_client): - yield AwsSecret(aws_credentials=aws_credentials, secret_name="my-test") - - def test_roundtrip_read_write_delete(self, aws_secret): - arn = "arn:aws:secretsmanager:us-east-1:123456789012:secret" - assert aws_secret.write_secret("my-secret").startswith(arn) - assert aws_secret.read_secret() == b"my-secret" - assert aws_secret.write_secret("my-updated-secret").startswith(arn) - assert aws_secret.read_secret() == b"my-updated-secret" - assert aws_secret.delete_secret().startswith(arn) - - def test_read_secret_version_id(self, aws_secret: AwsSecret): - client = aws_secret.aws_credentials.get_secrets_manager_client() - client.create_secret(Name="my-test", SecretBinary="my-secret") - response = client.update_secret( - SecretId="my-test", SecretBinary="my-updated-secret" - ) - assert ( - aws_secret.read_secret(version_id=response["VersionId"]) - == b"my-updated-secret" - ) - - def test_delete_secret_conflict(self, aws_secret: AwsSecret): - with pytest.raises(ValueError, match="Cannot specify recovery window"): - aws_secret.delete_secret( - force_delete_without_recovery=True, recovery_window_in_days=10 - ) - - def test_delete_secret_recovery_window(self, aws_secret: AwsSecret): - with pytest.raises( - ValueError, match="Recovery window must be between 7 and 30 days" - ): - aws_secret.delete_secret(recovery_window_in_days=42) - - async def test_read_secret(self, secret_under_test, aws_credentials): - secret = AwsSecret( - aws_credentials=aws_credentials, - secret_name=secret_under_test["secret_name"], - ) - assert await secret.read_secret() == secret_under_test["expected_value"]