Skip to content

Commit

Permalink
Moved S3Url util class to util.py
Browse files Browse the repository at this point in the history
  • Loading branch information
moomindani committed Nov 10, 2023
1 parent 3ad6d2a commit e61215e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 42 deletions.
47 changes: 5 additions & 42 deletions tests/functional/adapter/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest

import boto3
import os
from urllib.parse import urlparse
from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations
from dbt.tests.adapter.basic.test_singular_tests import BaseSingularTests
from dbt.tests.adapter.basic.test_singular_tests_ephemeral import BaseSingularTestsEphemeral
Expand All @@ -29,7 +27,7 @@
check_relations_equal,
)

from tests.util import get_s3_location, get_region
from tests.util import get_s3_location, get_region, cleanup_s3_location


s3bucket = get_s3_location()
Expand Down Expand Up @@ -61,11 +59,6 @@
base_materialized_var_sql = config_materialized_var + config_incremental_strategy + model_base


def cleanup_s3_location():
    """Delete every S3 object under this test schema's prefix.

    Builds an S3 client for the module-level ``region`` and removes all
    keys below ``s3bucket + schema_name`` so each test class starts from
    a clean S3 location.
    """
    s3_client = boto3.client("s3", region_name=region)
    location = S3Url(s3bucket + schema_name)
    location.delete_all_keys_v2(s3_client)


class TestSimpleMaterializationsGlue(BaseSimpleMaterializations):
# all tests within this test has the same schema
@pytest.fixture(scope="class")
Expand All @@ -92,7 +85,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

pass
Expand Down Expand Up @@ -131,7 +124,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

# test_ephemeral with refresh table
Expand Down Expand Up @@ -184,7 +177,7 @@ def unique_schema(request, prefix) -> str:
class TestIncrementalGlue(BaseIncremental):
@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -250,7 +243,7 @@ def unique_schema(request, prefix) -> str:

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

def test_generic_tests(self, project):
Expand Down Expand Up @@ -291,33 +284,3 @@ def test_generic_tests(self, project):

#class TestSnapshotTimestampGlue(BaseSnapshotTimestamp):
# pass

class S3Url(object):
    """Parse an ``s3://bucket/key`` URL and support bulk deletion of keys.

    Wraps :func:`urllib.parse.urlparse` to split an S3 URL into its
    bucket and key components.
    """

    def __init__(self, url: str):
        # allow_fragments=False so a '#' in an object key stays part of
        # the path instead of being split off as a URL fragment.
        self._parsed = urlparse(url, allow_fragments=False)

    @property
    def bucket(self):
        # The netloc portion of an s3:// URL is the bucket name.
        return self._parsed.netloc

    @property
    def key(self):
        # '?' is a legal character in S3 keys, so re-attach any query
        # portion that urlparse split off.
        if self._parsed.query:
            return self._parsed.path.lstrip("/") + "?" + self._parsed.query
        else:
            return self._parsed.path.lstrip("/")

    @property
    def url(self):
        # The original URL, reassembled from the parsed components.
        return self._parsed.geturl()

    def delete_all_keys_v2(self, client):
        """Delete every object under this URL's prefix using *client*.

        *client* is expected to be a boto3 S3 client; pagination handles
        prefixes with more than 1000 objects.
        """
        bucket = self.bucket
        prefix = self.key

        for response in client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix):
            if 'Contents' not in response:
                # Page with no matching objects (e.g. empty prefix).
                continue
            for content in response['Contents']:
                print("Deleting: s3://" + bucket + "/" + content['Key'])
                client.delete_object(Bucket=bucket, Key=content['Key'])
37 changes: 37 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import boto3
from urllib.parse import urlparse
from dbt.config.project import PartialProject


Expand Down Expand Up @@ -110,3 +112,38 @@ def get_s3_location():
def get_role_arn():
return os.environ.get("DBT_GLUE_ROLE_ARN", f"arn:aws:iam::{get_account_id()}:role/GlueInteractiveSessionRole")


def cleanup_s3_location(path, region):
    """Delete every S3 object stored under the given location.

    :param path: an ``s3://bucket/prefix`` URL whose keys should be removed
    :param region: AWS region used to construct the S3 client
    """
    s3_client = boto3.client("s3", region_name=region)
    location = S3Url(path)
    location.delete_all_keys_v2(s3_client)


class S3Url(object):
    """Split an ``s3://bucket/key`` URL into parts and bulk-delete keys.

    Thin wrapper over :func:`urllib.parse.urlparse` exposing the bucket,
    key, and reassembled URL, plus a helper that deletes every object
    under the key prefix.
    """

    def __init__(self, url):
        # allow_fragments=False keeps any '#' inside the object key.
        self._parsed = urlparse(url, allow_fragments=False)

    @property
    def bucket(self):
        """Bucket name (the netloc of the URL)."""
        return self._parsed.netloc

    @property
    def key(self):
        """Object key/prefix, with any query string re-attached."""
        stripped = self._parsed.path.lstrip("/")
        if not self._parsed.query:
            return stripped
        # '?' is legal in S3 keys; urlparse split it off as a query.
        return stripped + "?" + self._parsed.query

    @property
    def url(self):
        """The original URL, reassembled from its parsed parts."""
        return self._parsed.geturl()

    def delete_all_keys_v2(self, client):
        """Delete every object under this URL's prefix via *client*.

        *client* is a boto3 S3 client; pagination covers prefixes with
        more than 1000 objects.
        """
        bucket = self.bucket
        prefix = self.key

        paginator = client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            # Pages with no matching objects carry no 'Contents' entry.
            for entry in page.get('Contents', []):
                print("Deleting: s3://" + bucket + "/" + entry['Key'])
                client.delete_object(Bucket=bucket, Key=entry['Key'])

0 comments on commit e61215e

Please sign in to comment.