Skip to content

Commit

Permalink
Added automated Athena workgroup creation if missing (#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborschulz-aws authored Sep 28, 2023
1 parent ad78189 commit e218ebc
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 6 deletions.
4 changes: 4 additions & 0 deletions cid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def region(self) -> str:
def region_name(self) -> str:
return self.session.region_name

@property
def partition(self) -> str:
    """Partition reported by the session for this object's ``region_name``."""
    resolve_partition = self.session.get_partition_for_region
    return resolve_partition(region_name=self.region_name)

# Accessor exposing the session object stored on this instance as `_session`.
@property
def session(self) -> Session:
return self._session
Expand Down
4 changes: 3 additions & 1 deletion cid/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cid.helpers.quicksight import QuickSight, Dashboard, Dataset, Datasource, Template
from cid.helpers.csv2view import csv2view
from cid.helpers.organizations import Organizations
from cid.helpers.s3 import S3

__all__ = [
"Athena",
Expand All @@ -17,5 +18,6 @@
"Template",
"diff",
"csv2view",
"Organizations"
"Organizations",
"S3"
]
77 changes: 72 additions & 5 deletions cid/helpers/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pkg_resources import resource_string

from cid.base import CidBase
from cid.helpers.s3 import S3
from cid.utils import get_parameter, get_parameters, cid_print
from cid.helpers.diff import diff
from cid.exceptions import CidCritical, CidError
Expand Down Expand Up @@ -125,20 +126,28 @@ def WorkGroup(self) -> str:
logger.info('Selecting Athena workgroup...')
workgroups = self.list_work_groups()
logger.info(f'Found {len(workgroups)} workgroups: {", ".join([wg.get("Name") for wg in workgroups])}')
if len(workgroups) == 1:
if len(workgroups) == 0:
self.WorkGroup = self._ensure_workgroup(name=self.defaults.get('WorkGroup'))
elif len(workgroups) == 1:
# Silently choose the only workgroup that is available
self.WorkGroup = workgroups.pop().get('Name')
elif len(workgroups) > 1:
self.WorkGroup = self._ensure_workgroup(name=workgroups.pop().get('Name'))
else:
# Select default workgroup if present
if self.defaults.get('WorkGroup') not in {wgr['Name'] for wgr in workgroups}:
workgroups.append({'Name': f"{self.defaults.get('WorkGroup')} (create new)"})
default_workgroup = next(iter([wgr.get('Name') for wgr in workgroups if wgr['Name'] == self.defaults.get('WorkGroup')]), None)
if default_workgroup: logger.info(f'Found "{default_workgroup}" as a default workgroup')
# Ask user
self.WorkGroup = get_parameter(
selected_workgroup = get_parameter(
param_name='athena-workgroup',
message="Select AWS Athena workgroup to use",
message="Select Amazon Athena workgroup to use",
choices=[wgr['Name'] for wgr in workgroups],
default=default_workgroup
)
if ' (create new)' in selected_workgroup:
selected_workgroup = selected_workgroup.replace(' (create new)', '')
self.WorkGroup = self._ensure_workgroup(name=selected_workgroup)

logger.info(f'Selected workgroup: "{self._WorkGroup}"')
return self._WorkGroup

Expand All @@ -155,6 +164,64 @@ def WorkGroup(self, name: str):
self._WorkGroup = name
logger.info(f'Selected Athena WorkGroup: "{self._WorkGroup}"')

def _ensure_workgroup(self, name: str) -> str:
    """Make sure Athena workgroup ``name`` exists and has a query result location.

    * If the workgroup does not exist, it is created with a result bucket named
      ``<partition>-athena-query-results-cid-<account_id>-<region>`` (the bucket
      is created first when missing).
    * If the workgroup exists but has no ``OutputLocation``, the user is asked
      (parameter ``athena-result-bucket``) which bucket of the current region to
      use, and the workgroup is updated accordingly.
    * If the workgroup exists and already has an ``OutputLocation``, nothing is
      changed.

    Args:
        name: Name of the Athena workgroup.

    Returns:
        The workgroup name that was passed in.

    Raises:
        CidCritical: if the workgroup cannot be inspected, created or updated.
    """

    def result_configuration(bucket: str) -> dict:
        # Same settings are used for both create_work_group and update_work_group.
        return {
            'OutputLocation': f's3://{bucket}',
            'EncryptionConfiguration': {
                'EncryptionOption': 'SSE_S3',
            },
            'AclConfiguration': {
                'S3AclOption': 'BUCKET_OWNER_FULL_CONTROL'
            }
        }

    try:
        s3 = S3(session=self.session)
        # Mirrors "${AWS::Partition}-athena-query-results-cid-${AWS::AccountId}-${AWS::Region}"
        bucket_name = f'{self.partition}-athena-query-results-cid-{self.account_id}-{self.region}'

        try:
            workgroup = self.client.get_work_group(WorkGroup=name)
        except self.client.exceptions.InvalidRequestException as ex:
            # The message may be absent; treat that as "some other invalid request".
            message = ex.response.get('Error', {}).get('Message') or ''
            if 'WorkGroup is not found' not in message:
                # Not a "missing workgroup" situation: do not hide it (the outer
                # handler converts it to CidCritical) instead of returning None.
                raise
            # Workgroup does not exist: create it together with its result bucket.
            s3.ensure_bucket(name=bucket_name)
            self.client.create_work_group(
                Name=name,
                Configuration={
                    'ResultConfiguration': result_configuration(bucket_name),
                }
            )
            return name

        output_location = (
            workgroup.get('WorkGroup', {})
            .get('Configuration', {})
            .get('ResultConfiguration', {})
            .get('OutputLocation', None)
        )
        if output_location:
            # Workgroup is already usable; leave its configuration untouched.
            return name

        buckets = s3.list_buckets(region_name=self.region)
        if bucket_name not in buckets:
            buckets.append(f'{bucket_name} (create new)')
        bucket_name = get_parameter(
            param_name='athena-result-bucket',
            message=f"Select S3 bucket to use with Amazon Athena Workgroup [{name}]",
            choices=list(buckets)
        )
        if ' (create new)' in bucket_name:
            bucket_name = bucket_name.replace(' (create new)', '')
            s3.ensure_bucket(name=bucket_name)
        # No Description is sent: only the result configuration must change, the
        # existing workgroup description has to stay as it is.
        self.client.update_work_group(
            WorkGroup=name,
            ConfigurationUpdates={
                'ResultConfigurationUpdates': result_configuration(bucket_name),
            }
        )
        return name
    except Exception as ex:
        raise CidCritical('Failed to create Athena work group') from ex

def list_data_catalogs(self) -> list:
    """Return the ``DataCatalogsSummary`` entry of the client's list_data_catalogs response."""
    catalogs_response = self.client.list_data_catalogs()
    return catalogs_response.get('DataCatalogsSummary')

Expand Down
75 changes: 75 additions & 0 deletions cid/helpers/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import logging
import botocore
from typing import Optional, List

from cid.base import CidBase
from cid.exceptions import CidError

logger = logging.getLogger(__name__)


class S3(CidBase):
    """Small helper around the boto3 S3 client for looking up and creating buckets."""

    def __init__(self, session):
        super().__init__(session)
        self.client = self.session.client('s3', region_name=self.region)

    def ensure_bucket(self, name: str) -> str:
        """Make sure bucket ``name`` exists, creating and configuring it if missing.

        A newly created bucket is private, gets default AES256 server-side
        encryption and a lifecycle rule with a 14 day expiration.

        Args:
            name: Bucket name.

        Returns:
            The bucket name that was passed in.

        Raises:
            CidError: if the existence check fails for a reason other than
                "bucket not found".
        """
        try:
            self.client.head_bucket(Bucket=name)
            return name
        except botocore.exceptions.ClientError as ex:
            # Compare the code as text: a non-numeric code (e.g. "AccessDenied")
            # must end up as CidError, not as a ValueError from int().
            error_code = str(ex.response.get('Error', {}).get('Code'))
            if error_code not in ('404', 'NoSuchBucket'):
                raise CidError(f"Cannot check bucket {ex}!") from ex

        create_arguments = {
            'ACL': 'private',
            'Bucket': name,
        }
        # Outside us-east-1 S3 requires an explicit LocationConstraint, while
        # us-east-1 itself must not be passed as a constraint.
        if self.region != 'us-east-1':
            create_arguments['CreateBucketConfiguration'] = {
                'LocationConstraint': self.region,
            }
        self.client.create_bucket(**create_arguments)

        self.client.put_bucket_encryption(
            Bucket=name,
            ServerSideEncryptionConfiguration={
                'Rules': [
                    {
                        'ApplyServerSideEncryptionByDefault': {
                            'SSEAlgorithm': 'AES256',
                        },
                    },
                ]
            }
        )

        self.client.put_bucket_lifecycle_configuration(
            Bucket=name,
            LifecycleConfiguration={
                'Rules': [
                    {
                        'Expiration': {
                            'Days': 14,
                        },
                        # NOTE(review): this filter only matches keys that start
                        # with "/"; if the intent is to expire every object an
                        # empty prefix is probably needed - confirm before changing.
                        'Filter': {
                            'Prefix': '/',
                        },
                        'ID': 'ExpireAfter14Days',
                        'Status': 'Enabled',
                    },
                ],
            },
        )
        return name

    def _bucket_region(self, name: str) -> Optional[str]:
        """Return the region of bucket ``name`` or None when it cannot be determined."""
        try:
            location = self.client.get_bucket_location(Bucket=name).get('LocationConstraint', None)
        except botocore.exceptions.ClientError as ex:
            # One unreadable bucket should not break the whole listing.
            logger.warning(f'Cannot get location of bucket {name}: {ex}')
            return None
        if not location:
            # S3 reports no constraint for buckets in us-east-1.
            return 'us-east-1'
        if location == 'EU':
            # Legacy constraint value used for eu-west-1.
            return 'eu-west-1'
        return location

    def list_buckets(self, region_name: Optional[str] = None) -> List[str]:
        """List bucket names, optionally only those located in ``region_name``.

        Args:
            region_name: When given, only buckets whose location equals this
                region are returned; buckets whose location cannot be read are
                left out. When omitted, all bucket names are returned and no
                per-bucket location lookup is made.

        Returns:
            Bucket names in the order returned by S3.
        """
        all_names = [bucket['Name'] for bucket in self.client.list_buckets()['Buckets']]
        if not region_name:
            return all_names
        return [name for name in all_names if self._bucket_region(name) == region_name]

0 comments on commit e218ebc

Please sign in to comment.