Skip to content

Commit

Permalink
Merge pull request #320 from aws-samples/fix-easy-setup
Browse files Browse the repository at this point in the history
Fix `easy-setup.sh` + make `region` and `profile` configurable in `validate-config.py`
  • Loading branch information
KeitaW authored May 8, 2024
2 parents 8214ab7 + ca85bf2 commit 0c5ee49
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
75 changes: 40 additions & 35 deletions 1.architectures/5.sagemaker-hyperpod/easy-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ declare -a HELP=(
"[-s|--stack-id-vpc]"
"[-i|--instance-type]"
"[-c|--instance-count]"
"[-a|--availability-zone]"
"[-d|--dry-run]"
"CLUSTER_NAME"
)
Expand All @@ -45,12 +44,11 @@ parse_args() {
exit 0
;;
-r|--region)
aws_cli_args+=(--region "$2")
AWS_REGION="$2"
shift 2
;;
-p|--profile)
aws_cli_args+=(--profile "$2")
AWS_PROFILE="$2"
shift 2
;;
-s|--stack-id-vpc)
Expand All @@ -62,11 +60,7 @@ parse_args() {
shift 2
;;
-c|--instance-count)
INSTANCE_COUNTS="$2"
shift 2
;;
-i|--availability-zone)
AZ="$2"
INSTANCE_COUNT="$2"
shift 2
;;
-d|--dry-run)
Expand All @@ -83,7 +77,8 @@ parse_args() {

parse_args $@

mkdir $CLUSTER_NAME && cd $CLUSTER_NAME
mkdir -p $CLUSTER_NAME
cd $CLUSTER_NAME

# Check for AWS CLI
if ! command -v aws &> /dev/null
Expand All @@ -105,21 +100,34 @@ if [ -z ${CLUSTER_NAME} ]; then
echo "[WARNING] CLUSTER_NAME environment variable is not set, automatically set to ml-cluster"
CLUSTER_NAME=ml-cluster
fi
echo "export CLUSTER_NAME=${CLUSTER_NAME}" > env_vars

# Define stack name
if [ -z ${STACK_ID_VPC} ]; then
echo "[WARNING] STACK_ID_VPC environment variable is not set, automatically set to sagemaker-hyperpod"
STACK_ID_VPC=sagemaker-hyperpod
fi
echo "export STACK_ID_VPC=${STACK_ID_VPC}" >> env_vars

# Define AWS Region
if [ -z ${AWS_REGION} ]; then
echo "[WARNING] AWS_REGION environment variable is not set, automatically set depending on aws cli default region."
export AWS_REGION=$(aws configure get region)
export AWS_REGION=$(aws "${aws_cli_args[@]}" configure get region)
fi
aws_cli_args+=(--region "${AWS_REGION}")
echo "export AWS_REGION=${AWS_REGION}" >> env_vars
echo "[INFO] AWS_REGION = ${AWS_REGION}"

# Define AWS Profile
if [ -z ${AWS_PROFILE} ];
then
echo "[WARNING] AWS_PROFILE environment variable is not set, ignore if you are using default profile."
else
echo "export AWS_PROFILE=${AWS_PROFILE}" >> env_vars
echo "[INFO] AWS_PROFILE = ${AWS_PROFILE}"
aws_cli_args+=(--profile "${AWS_PROFILE}")
fi

# Define instances, separated by ','.
if [ -z ${INSTANCE} ]; then
echo "[WARNING] INSTANCE environment variable is not set, automatically set to g5.12xlarge."
Expand All @@ -135,10 +143,9 @@ if [ -z ${INSTANCE_COUNT} ]; then
fi

# Retrieve VPC ID
export VPC_ID=`aws cloudformation describe-stacks \
export VPC_ID=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`VPC\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $VPC_ID ]]; then
Expand All @@ -150,10 +157,9 @@ else
fi

# Grab the subnet id
export SUBNET_ID=`aws cloudformation describe-stacks \
export SUBNET_ID=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`PrimaryPrivateSubnet\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $SUBNET_ID ]]; then
Expand All @@ -165,10 +171,9 @@ else
fi

# Grab the public subnet id
export PUBLIC_SUBNET_ID=`aws cloudformation describe-stacks \
export PUBLIC_SUBNET_ID=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`PublicSubnet\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $PUBLIC_SUBNET_ID ]]; then
Expand All @@ -180,10 +185,9 @@ else
fi

# Get FSx Filesystem id from CloudFormation
export FSX_ID=`aws cloudformation describe-stacks \
export FSX_ID=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`FSxLustreFilesystemId\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $FSX_ID ]]; then
Expand All @@ -195,10 +199,9 @@ else
fi

# Get FSx Filesystem Mountname from CloudFormation
export FSX_MOUNTNAME=`aws cloudformation describe-stacks \
export FSX_MOUNTNAME=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`FSxLustreFilesystemMountname\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $FSX_MOUNTNAME ]]; then
Expand All @@ -210,10 +213,9 @@ else
fi

# Get FSx Security Group from CloudFormation
export SECURITY_GROUP=`aws cloudformation describe-stacks \
export SECURITY_GROUP=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`SecurityGroup\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $SECURITY_GROUP ]]; then
Expand All @@ -225,10 +227,9 @@ else
fi

# Get sagemaker role ARN
export ROLE=`aws cloudformation describe-stacks \
export ROLE=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`AmazonSagemakerClusterExecutionRoleArn\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $ROLE ]]; then
Expand All @@ -251,10 +252,9 @@ else
fi

# Get s3 bucket name
export BUCKET=`aws cloudformation describe-stacks \
export BUCKET=`aws "${aws_cli_args[@]}" cloudformation describe-stacks \
--stack-name $STACK_ID_VPC \
--query 'Stacks[0].Outputs[?OutputKey==\`AmazonS3BucketName\`].OutputValue' \
--region ${AWS_REGION} \
--output text`

if [[ ! -z $BUCKET ]]; then
Expand All @@ -265,12 +265,17 @@ else
return 1
fi


git clone --depth=1 https://github.com/aws-samples/awsome-distributed-training/
if [ ! -d "awsome-distributed-training" ]
then
echo "Cloning the repository..."
git clone --depth=1 https://github.com/aws-samples/awsome-distributed-training/
else
echo "Repository already exists..."
fi
# Use pushd and popd to navigate directories https://en.wikipedia.org/wiki/Pushd_and_popd
pushd awsome-distributed-training/1.architectures/5.sagemaker-hyperpod/LifecycleScripts/
# upload data
aws s3 cp --recursive base-config/ s3://${BUCKET}/src
aws "${aws_cli_args[@]}" s3 cp --recursive base-config/ s3://${BUCKET}/src
# move back to the previous directory
popd

Expand All @@ -282,7 +287,7 @@ cat > provisioning_parameters.json << EOL
"worker_groups": [
{
"instance_group_name": "worker-group-1",
"partition_name": ${INSTANCE}
"partition_name": "${INSTANCE}"
}
],
"fsx_dns_name": "${FSX_ID}.fsx.${AWS_REGION}.amazonaws.com",
Expand All @@ -291,7 +296,7 @@ cat > provisioning_parameters.json << EOL
EOL

# copy to the S3 Bucket
aws s3 cp provisioning_parameters.json s3://${BUCKET}/src/
aws "${aws_cli_args[@]}" s3 cp provisioning_parameters.json s3://${BUCKET}/src/

cat > cluster-config.json << EOL
{
Expand Down Expand Up @@ -328,12 +333,12 @@ cat > cluster-config.json << EOL
EOL

# Validate Cluster configuration
wget https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/main/1.architectures/5.sagemaker-hyperpod/validate-config.py
wget --no-clobber https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/main/1.architectures/5.sagemaker-hyperpod/validate-config.py
# install boto3
pip3 install boto3
# check config for known issues
python3 validate-config.py --cluster-config cluster-config.json --provisioning-parameters provisioning_parameters.json
python3 validate-config.py --cluster-config cluster-config.json --provisioning-parameters provisioning_parameters.json ${aws_cli_args[@]}

echo "aws sagemaker create-cluster --cli-input-json file://cluster-config.json --region ${REGION}"
echo "aws ${aws_cli_args[@]} sagemaker create-cluster --cli-input-json file://cluster-config.json"
[[ DRY_RUN -eq 1 ]] && exit 0
aws sagemaker create-cluster --cli-input-json "file://cluster-config.json" --region ${REGION}
aws ${aws_cli_args[@]} sagemaker create-cluster --cli-input-json "file://cluster-config.json"
8 changes: 6 additions & 2 deletions 1.architectures/5.sagemaker-hyperpod/validate-config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def validate_fsx_lustre(fsx_client, cluster_config, provisioning_parameters):

def main():
parser = argparse.ArgumentParser(description="Validate cluster config.")
parser.add_argument("--region", help="AWS Region where the cluster will be created", default=None)
parser.add_argument("--profile", help="AWS Profile to use for creating the cluster", default=None)
parser.add_argument("--cluster-config", help="Path to the cluster config JSON file")
parser.add_argument("--provisioning-parameters", help="Path to the provisioning parameters JSON file")
args = parser.parse_args()
Expand All @@ -161,8 +163,10 @@ def main():
except json.decoder.JSONDecodeError:
print(f"❌ provisioning_parameters.json is invalid.")
return False

session = boto3.Session(profile_name=args.profile, region_name=args.region)

ec2_client = boto3.client('ec2')
ec2_client = session.client('ec2')

# check instance group name
valid = validate_instance_groups(cluster_config, provisioning_parameters)
Expand All @@ -174,7 +178,7 @@ def main():
valid = validate_sg(ec2_client, cluster_config) and valid

# Validate FSx Lustre
valid = validate_fsx_lustre(boto3.client('fsx'), cluster_config, provisioning_parameters) and valid
valid = validate_fsx_lustre(session.client('fsx'), cluster_config, provisioning_parameters) and valid

# validate provisioning_parameters
valid = validate_provisioning_parameters(provisioning_parameters)
Expand Down

0 comments on commit 0c5ee49

Please sign in to comment.