[test] Refine unit test and integ tests #268

Merged (5 commits) on Nov 10, 2023
2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -24,6 +24,8 @@ black==23.11.0
# Adapter specific dependencies
waiter
boto3
+moto~=4.2.7
+pyparsing

dbt-core~=1.7.1
dbt-spark~=1.7.1
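moto and pyparsing are added as dev dependencies so the new unit tests can fake the Glue Data Catalog locally instead of calling AWS. A minimal sketch of the moto pattern those tests rely on, with placeholder names rather than anything from this PR:

```python
# Hedged sketch of the moto usage pattern; the database name is illustrative.
import boto3
from moto import mock_glue


@mock_glue
def test_create_database_is_mocked():
    # all Glue calls below hit moto's in-memory catalog, not AWS
    glue = boto3.client("glue", region_name="us-east-1")
    glue.create_database(DatabaseInput={"Name": "example_db"})
    names = [db["Name"] for db in glue.get_databases()["DatabaseList"]]
    assert "example_db" in names
```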
49 changes: 7 additions & 42 deletions tests/functional/adapter/test_basic.py
@@ -1,8 +1,6 @@
import pytest

-import boto3
import os
-from urllib.parse import urlparse
from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations
from dbt.tests.adapter.basic.test_singular_tests import BaseSingularTests
from dbt.tests.adapter.basic.test_singular_tests_ephemeral import BaseSingularTestsEphemeral
@@ -29,7 +27,7 @@
check_relations_equal,
)

-from tests.util import get_s3_location, get_region
+from tests.util import get_s3_location, get_region, cleanup_s3_location


s3bucket = get_s3_location()
@@ -61,11 +59,6 @@
base_materialized_var_sql = config_materialized_var + config_incremental_strategy + model_base


-def cleanup_s3_location():
-    client = boto3.client("s3", region_name=region)
-    S3Url(s3bucket + schema_name).delete_all_keys_v2(client)


class TestSimpleMaterializationsGlue(BaseSimpleMaterializations):
    # all tests within this class share the same schema
@pytest.fixture(scope="class")
@@ -92,7 +85,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
-        cleanup_s3_location()
+        cleanup_s3_location(s3bucket + schema_name, region)
yield

pass
@@ -131,7 +124,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
-        cleanup_s3_location()
+        cleanup_s3_location(s3bucket + schema_name, region)
yield

# test_ephemeral with refresh table
@@ -184,7 +177,7 @@ def unique_schema(request, prefix) -> str:
class TestIncrementalGlue(BaseIncremental):
@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
-        cleanup_s3_location()
+        cleanup_s3_location(s3bucket + schema_name, region)
yield

@pytest.fixture(scope="class")
@@ -250,16 +243,18 @@ def unique_schema(request, prefix) -> str:

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
-        cleanup_s3_location()
+        cleanup_s3_location(s3bucket + schema_name, region)
yield

def test_generic_tests(self, project):
# seed command
results = run_dbt(["seed"])

relation = relation_from_name(project.adapter, "base")
+        relation_table_model = relation_from_name(project.adapter, "table_model")
        # run refresh table to invalidate the previous parquet file paths
        project.run_sql(f"refresh table {relation}")
+        project.run_sql(f"refresh table {relation_table_model}")

# test command selecting base model
results = run_dbt(["test", "-m", "base"])
@@ -291,33 +286,3 @@ def test_generic_tests(self, project):

#class TestSnapshotTimestampGlue(BaseSnapshotTimestamp):
# pass

-class S3Url(object):
-    def __init__(self, url):
-        self._parsed = urlparse(url, allow_fragments=False)
-
-    @property
-    def bucket(self):
-        return self._parsed.netloc
-
-    @property
-    def key(self):
-        if self._parsed.query:
-            return self._parsed.path.lstrip("/") + "?" + self._parsed.query
-        else:
-            return self._parsed.path.lstrip("/")
-
-    @property
-    def url(self):
-        return self._parsed.geturl()
-
-    def delete_all_keys_v2(self, client):
-        bucket = self.bucket
-        prefix = self.key
-
-        for response in client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix):
-            if 'Contents' not in response:
-                continue
-            for content in response['Contents']:
-                print("Deleting: s3://" + bucket + "/" + content['Key'])
-                client.delete_object(Bucket=bucket, Key=content['Key'])
4 changes: 4 additions & 0 deletions tests/unit/constants.py
@@ -0,0 +1,4 @@
CATALOG_ID = "1234567890101"
DATABASE_NAME = "test_dbt_glue"
BUCKET_NAME = "test-dbt-glue"
AWS_REGION = "us-east-1"
42 changes: 39 additions & 3 deletions tests/unit/test_adapter.py
@@ -1,12 +1,15 @@
from typing import Any, Dict, Optional
import unittest
from unittest import mock
from moto import mock_glue

from dbt.config import RuntimeConfig

import dbt.flags as flags
from dbt.adapters.glue import GlueAdapter
from dbt.adapters.glue.relation import SparkRelation
from tests.util import config_from_parts_or_dicts
from .util import MockAWSService


class TestGlueAdapter(unittest.TestCase):
@@ -33,8 +36,8 @@ def setUp(self):
"region": "us-east-1",
"workers": 2,
"worker_type": "G.1X",
"schema": "dbt_functional_test_01",
"database": "dbt_functional_test_01",
"schema": "dbt_unit_test_01",
"database": "dbt_unit_test_01",
}
},
"target": "test",
@@ -56,5 +59,38 @@ def test_glue_connection(self):

self.assertEqual(connection.state, "open")
self.assertEqual(connection.type, "glue")
self.assertEqual(connection.credentials.schema, "dbt_functional_test_01")
self.assertEqual(connection.credentials.schema, "dbt_unit_test_01")
self.assertIsNotNone(connection.handle)


@mock_glue
def test_get_table_type(self):
config = self._get_config()
adapter = GlueAdapter(config)

database_name = "dbt_unit_test_01"
table_name = "test_table"
mock_aws_service = MockAWSService()
mock_aws_service.create_database(name=database_name)
mock_aws_service.create_iceberg_table(table_name=table_name, database_name=database_name)
target_relation = SparkRelation.create(
schema=database_name,
identifier=table_name,
)
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
self.assertEqual(adapter.get_table_type(target_relation), "iceberg_table")

@mock_glue
def test_hudi_merge_table(self):
config = self._get_config()
adapter = GlueAdapter(config)
target_relation = SparkRelation.create(
schema="dbt_unit_test_01",
name="test_hudi_merge_table",
)
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
adapter.hudi_merge_table(target_relation, "SELECT 1", "id", "category", "empty", None, None)
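The new tests combine two layers of isolation: moto's mock_glue fakes the Data Catalog that get_table_type queries, while mock.patch on dbt.adapters.glue.connections.open keeps acquire_connection from starting a real Glue Interactive Session. A condensed, illustrative sketch of that recipe; the config argument stands in for a dbt RuntimeConfig built the way _get_config() does in the test class above:

```python
# Illustrative only: config is assumed to be a RuntimeConfig like the one
# assembled in TestGlueAdapter._get_config(); nothing here talks to AWS.
from unittest import mock

from moto import mock_glue

from dbt.adapters.glue import GlueAdapter


@mock_glue
def open_adapter_connection_offline(config):
    adapter = GlueAdapter(config)
    # patching connections.open avoids creating a real interactive session
    with mock.patch("dbt.adapters.glue.connections.open"):
        connection = adapter.acquire_connection("dummy")
        connection.handle  # lazy-load resolves against the patched session
        return adapter, connection
```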
96 changes: 96 additions & 0 deletions tests/unit/util.py
@@ -0,0 +1,96 @@
from typing import Optional
import boto3

from .constants import AWS_REGION, BUCKET_NAME, CATALOG_ID, DATABASE_NAME


class MockAWSService:
def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID):
glue = boto3.client("glue", region_name=AWS_REGION)
glue.create_database(DatabaseInput={"Name": name}, CatalogId=catalog_id)

def create_table(
self,
table_name: str,
database_name: str = DATABASE_NAME,
catalog_id: str = CATALOG_ID,
location: Optional[str] = "auto",
):
glue = boto3.client("glue", region_name=AWS_REGION)
if location == "auto":
location = f"s3://{BUCKET_NAME}/tables/{table_name}"
glue.create_table(
CatalogId=catalog_id,
DatabaseName=database_name,
TableInput={
"Name": table_name,
"StorageDescriptor": {
"Columns": [
{
"Name": "id",
"Type": "string",
},
{
"Name": "country",
"Type": "string",
},
],
"Location": location,
},
"PartitionKeys": [
{
"Name": "dt",
"Type": "date",
},
],
"TableType": "table",
"Parameters": {
"compressionType": "snappy",
"classification": "parquet",
"projection.enabled": "false",
"typeOfData": "file",
},
},
)

def create_iceberg_table(
self,
table_name: str,
database_name: str = DATABASE_NAME,
catalog_id: str = CATALOG_ID):
glue = boto3.client("glue", region_name=AWS_REGION)
glue.create_table(
CatalogId=catalog_id,
DatabaseName=database_name,
TableInput={
"Name": table_name,
"StorageDescriptor": {
"Columns": [
{
"Name": "id",
"Type": "string",
},
{
"Name": "country",
"Type": "string",
},
{
"Name": "dt",
"Type": "date",
},
],
"Location": f"s3://{BUCKET_NAME}/tables/data/{table_name}",
},
"PartitionKeys": [
{
"Name": "dt",
"Type": "date",
},
],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"metadata_location": f"s3://{BUCKET_NAME}/tables/metadata/{table_name}/123.json",
"table_type": "iceberg",
},
},
)
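MockAWSService only wraps plain boto3 Glue calls, so it has to run inside an active moto mock, as the adapter tests above do. A hypothetical standalone check of the Iceberg helper might look like this; the database and table names are placeholders:

```python
# Hedged sketch: inspects the table metadata MockAWSService writes into moto's
# fake Glue catalog; names below are illustrative, not from this PR.
import boto3
from moto import mock_glue

from tests.unit.constants import AWS_REGION, CATALOG_ID
from tests.unit.util import MockAWSService


@mock_glue
def check_iceberg_table_is_registered():
    service = MockAWSService()
    service.create_database(name="example_db")
    service.create_iceberg_table(table_name="example_table", database_name="example_db")

    glue = boto3.client("glue", region_name=AWS_REGION)
    table = glue.get_table(CatalogId=CATALOG_ID, DatabaseName="example_db", Name="example_table")["Table"]
    # the "table_type": "iceberg" parameter is what get_table_type keys off
    assert table["Parameters"]["table_type"] == "iceberg"
```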
37 changes: 37 additions & 0 deletions tests/util.py
@@ -1,4 +1,6 @@
import os
import boto3
from urllib.parse import urlparse
from dbt.config.project import PartialProject


@@ -110,3 +112,38 @@ def get_s3_location():
def get_role_arn():
return os.environ.get("DBT_GLUE_ROLE_ARN", f"arn:aws:iam::{get_account_id()}:role/GlueInteractiveSessionRole")


def cleanup_s3_location(path, region):
client = boto3.client("s3", region_name=region)
S3Url(path).delete_all_keys_v2(client)


class S3Url(object):
def __init__(self, url):
self._parsed = urlparse(url, allow_fragments=False)

@property
def bucket(self):
return self._parsed.netloc

@property
def key(self):
if self._parsed.query:
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
else:
return self._parsed.path.lstrip("/")

@property
def url(self):
return self._parsed.geturl()

def delete_all_keys_v2(self, client):
bucket = self.bucket
prefix = self.key

for response in client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix):
if 'Contents' not in response:
continue
for content in response['Contents']:
print("Deleting: s3://" + bucket + "/" + content['Key'])
client.delete_object(Bucket=bucket, Key=content['Key'])
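The relocated helper now takes the S3 path and region explicitly, which also makes it easy to exercise against moto's S3 mock. A hypothetical check, with placeholder bucket and key names:

```python
# Hedged sketch: runs the relocated cleanup helper against a fake S3 bucket so
# nothing real gets deleted; all names below are illustrative.
import boto3
from moto import mock_s3

from tests.util import cleanup_s3_location


@mock_s3
def check_cleanup_empties_prefix():
    client = boto3.client("s3", region_name="us-east-1")
    client.create_bucket(Bucket="example-bucket")
    client.put_object(Bucket="example-bucket", Key="example_schema/table/part-0000.parquet", Body=b"data")

    # delete everything under s3://example-bucket/example_schema
    cleanup_s3_location("s3://example-bucket/example_schema", "us-east-1")

    assert client.list_objects_v2(Bucket="example-bucket").get("KeyCount", 0) == 0
```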