diff --git a/cid/base.py b/cid/base.py index 02044eb6..0086ae86 100644 --- a/cid/base.py +++ b/cid/base.py @@ -43,6 +43,10 @@ def region(self) -> str: def region_name(self) -> str: return self.session.region_name + @property + def partition(self) -> str: + return self.session.get_partition_for_region(region_name=self.region_name) + @property def session(self) -> Session: return self._session diff --git a/cid/helpers/__init__.py b/cid/helpers/__init__.py index 9654df4a..89946c66 100644 --- a/cid/helpers/__init__.py +++ b/cid/helpers/__init__.py @@ -5,6 +5,7 @@ from cid.helpers.quicksight import QuickSight, Dashboard, Dataset, Datasource, Template from cid.helpers.csv2view import csv2view from cid.helpers.organizations import Organizations +from cid.helpers.s3 import S3 __all__ = [ "Athena", @@ -17,5 +18,6 @@ "Template", "diff", "csv2view", - "Organizations" + "Organizations", + "S3" ] diff --git a/cid/helpers/athena.py b/cid/helpers/athena.py index 0b30ca44..da457e30 100644 --- a/cid/helpers/athena.py +++ b/cid/helpers/athena.py @@ -9,6 +9,7 @@ from pkg_resources import resource_string from cid.base import CidBase +from cid.helpers.s3 import S3 from cid.utils import get_parameter, get_parameters, cid_print from cid.helpers.diff import diff from cid.exceptions import CidCritical, CidError @@ -125,20 +126,28 @@ def WorkGroup(self) -> str: logger.info('Selecting Athena workgroup...') workgroups = self.list_work_groups() logger.info(f'Found {len(workgroups)} workgroups: {", ".join([wg.get("Name") for wg in workgroups])}') - if len(workgroups) == 1: + if len(workgroups) == 0: + self.WorkGroup = self._ensure_workgroup(name=self.defaults.get('WorkGroup')) + elif len(workgroups) == 1: # Silently choose the only workgroup that is available - self.WorkGroup = workgroups.pop().get('Name') - elif len(workgroups) > 1: + self.WorkGroup = self._ensure_workgroup(name=workgroups.pop().get('Name')) + else: # Select default workgroup if present + if self.defaults.get('WorkGroup') not in {wgr['Name'] for wgr in workgroups}: + workgroups.append({'Name': f"{self.defaults.get('WorkGroup')} (create new)"}) default_workgroup = next(iter([wgr.get('Name') for wgr in workgroups if wgr['Name'] == self.defaults.get('WorkGroup')]), None) if default_workgroup: logger.info(f'Found "{default_workgroup}" as a default workgroup') # Ask user - self.WorkGroup = get_parameter( + selected_workgroup = get_parameter( param_name='athena-workgroup', - message="Select AWS Athena workgroup to use", + message="Select Amazon Athena workgroup to use", choices=[wgr['Name'] for wgr in workgroups], default=default_workgroup ) + if ' (create new)' in selected_workgroup: + selected_workgroup = selected_workgroup.replace(' (create new)', '') + self.WorkGroup = self._ensure_workgroup(name=selected_workgroup) + logger.info(f'Selected workgroup: "{self._WorkGroup}"') return self._WorkGroup @@ -155,6 +164,64 @@ def WorkGroup(self, name: str): self._WorkGroup = name logger.info(f'Selected Athena WorkGroup: "{self._WorkGroup}"') + def _ensure_workgroup(self, name: str) -> str: + try: + s3 = S3(session=self.session) + bucket_name = f'{self.partition}-athena-query-results-cid-{self.account_id}-{self.region}' + + workgroup = self.client.get_work_group(WorkGroup=name) + # "${AWS::Partition}-athena-query-results-cid-${AWS::AccountId}-${AWS::Region}" + if not workgroup.get('WorkGroup', {}).get('Configuration', {}).get('ResultConfiguration', {}).get('OutputLocation', None): + s3 = S3(session=self.session) + buckets = s3.list_buckets(region_name=self.region) + if bucket_name not in buckets: + buckets.append(f'{bucket_name} (create new)') + bucket_name = get_parameter( + param_name='athena-result-bucket', + message=f"Select S3 bucket to use with Amazon Athena Workgroup [{name}]", + choices=[bucket for bucket in buckets] + ) + if ' (create new)' in bucket_name: + bucket_name = bucket_name.replace(' (create new)', '') + s3.ensure_bucket(name=bucket_name) + response = self.client.update_work_group( + WorkGroup=name, + Description='string', + ConfigurationUpdates={ + 'ResultConfigurationUpdates': { + 'OutputLocation': f's3://{bucket_name}', + 'EncryptionConfiguration': { + 'EncryptionOption': 'SSE_S3', + }, + 'AclConfiguration': { + 'S3AclOption': 'BUCKET_OWNER_FULL_CONTROL' + } + } + } + ) + return name + except self.client.exceptions.InvalidRequestException as ex: + # Workgroup does not exist + if 'WorkGroup is not found' in ex.response.get('Error', {}).get('Message'): + s3.ensure_bucket(name=bucket_name) + response = self.client.create_work_group( + Name=name, + Configuration={ + 'ResultConfiguration': { + 'OutputLocation': f's3://{bucket_name}', + 'EncryptionConfiguration': { + 'EncryptionOption': 'SSE_S3', + }, + 'AclConfiguration': { + 'S3AclOption': 'BUCKET_OWNER_FULL_CONTROL' + } + }, + } + ) + return name + except Exception as ex: + raise CidCritical('Failed to create Athena work group') from ex + def list_data_catalogs(self) -> list: return self.client.list_data_catalogs().get('DataCatalogsSummary') diff --git a/cid/helpers/s3.py b/cid/helpers/s3.py new file mode 100644 index 00000000..7d350733 --- /dev/null +++ b/cid/helpers/s3.py @@ -0,0 +1,75 @@ +import json +import logging +import botocore +from typing import Optional, List + +from cid.base import CidBase +from cid.exceptions import CidError + +logger = logging.getLogger(__name__) + + +class S3(CidBase): + + def __init__(self, session): + super().__init__(session) + self.client = self.session.client('s3', region_name=self.region) + + def ensure_bucket(self, name: str) -> str: + try: + response = self.client.head_bucket(Bucket=name) + return name + except botocore.exceptions.ClientError as ex: + if int(ex.response['Error']['Code']) != 404: + raise CidError(f"Cannot check bucket {ex}!") + + response = self.client.create_bucket( + ACL='private', + Bucket=name + ) + response = self.client.put_bucket_encryption( + Bucket=name, + ServerSideEncryptionConfiguration={ + 'Rules': [ + { + 'ApplyServerSideEncryptionByDefault': { + 'SSEAlgorithm': 'AES256', + }, + }, + ] + } + ) + + response = self.client.put_bucket_lifecycle_configuration( + Bucket=name, + LifecycleConfiguration={ + 'Rules': [ + { + 'Expiration': { + 'Days': 14, + }, + 'Filter': { + 'Prefix': '/', + }, + 'ID': 'ExpireAfter14Days', + 'Status': 'Enabled', + }, + ], + }, + ) + return name + + def list_buckets(self, region_name: Optional[str] = None) -> List[str]: + buckets = self.client.list_buckets() + bucket_regions = { + x['Name']: self.client.get_bucket_location(Bucket=x['Name']).get('LocationConstraint', None) for x in buckets['Buckets'] + } + for bucket in bucket_regions: + if bucket_regions[bucket] is None: + bucket_regions[bucket] = 'us-east-1' + + if region_name: + bucket_names = [x['Name'] for x in buckets['Buckets'] if bucket_regions.get(x['Name']) == region_name] + else: + bucket_names = [x['Name'] for x in buckets['Buckets']] + return bucket_names \ No newline at end of file