Commit

Merge pull request #272 from aws-samples/olcf-6
Add OLCF-6 test case
verdimrc authored May 7, 2024
2 parents a36787a + d126672 commit ff820c0
Showing 25 changed files with 577 additions and 60 deletions.
1 change: 0 additions & 1 deletion 3.test_cases/15.gpt-neox/0.gpt-neox.dockerfile
@@ -172,4 +172,3 @@ RUN git clone https://github.com/EleutherAI/gpt-neox.git \
# Rebuild newer flash-attn
RUN MAX_JOBS=192 FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn==2.5.5 --upgrade
WORKDIR /workspace/gpt-neox
COPY src/c4_prepare_data.py c4_prepare_data.py
61 changes: 3 additions & 58 deletions 3.test_cases/15.gpt-neox/README.md
@@ -1,6 +1,6 @@
# Pythia GPT-NeoX Test Case <!-- omit in toc -->
# GPT-NeoX Test Cases <!-- omit in toc -->

GPT-NeoX is [EleutherAI](https://www.eleuther.ai)'s library for training large-scale language models on GPUs. This framework is based on [NVIDIA's Megatron Language Model](https://github.com/NVIDIA/Megatron-LM) and has been augmented with techniques from [DeepSpeed](https://www.deepspeed.ai) as well as some novel optimizations. This test case illustrates how to train the [Pythia](https://arxiv.org/abs/2304.01373) model using GPT-NeoX.
GPT-NeoX is [EleutherAI](https://www.eleuther.ai)'s library for training large-scale language models on GPUs. This framework is based on [NVIDIA's Megatron Language Model](https://github.com/NVIDIA/Megatron-LM) and has been augmented with techniques from [DeepSpeed](https://www.deepspeed.ai) as well as some novel optimizations.

## 1. Preparation

@@ -54,7 +54,7 @@ If you wish to reduce the memory footprint of the build process, consider tweaking `MAX_JOBS`
3. Convert the Docker image to a squash file with the command below.

```bash
enroot import -o ${ENROOT_IMAGE} dockerd://get-neox:latest
enroot import -o ${ENROOT_IMAGE} dockerd://gpt-neox:latest
```

The file will be stored in the `/apps` directory by default. The output should look like the following.
@@ -87,58 +87,3 @@ In this option, you will use a compute node to build the image. Submit the job as follows:
sbatch 1.build-image.sbatch
```
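
After the job finishes, you can quickly check that the squash file was produced. The exact path and file name depend on your `ENROOT_IMAGE` setting; the default shown below is an assumption.

```bash
# Verify that the Enroot squash file exists and has a plausible size.
# Adjust the path if you set ENROOT_IMAGE to a different location.
ls -lh /apps/gpt-neox.sqsh
```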


## 4. Dataset preparation

This test case uses the [C4 dataset](https://paperswithcode.com/paper/exploring-the-limits-of-transfer-learning). In this section, you will retrieve and tokenize the dataset.

We will use the GPT-NeoX-20B tokenizer. Place the tokenizer file as follows.

```bash
mkdir -p ${MODEL_PATH}/tokenizers
wget https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json -O ${MODEL_PATH}/tokenizers/20B_tokenizer.json
```

To retrieve and tokenize the C4 dataset, we use NeoX's `prepare_data.py` inside the container. The exact steps are described in `2.prepare-data.sbatch`.

```bash
sbatch 2.prepare-data.sbatch
```

By default, the script only downloads a subset of the dataset. Use the following if you wish to download the whole C4 dataset:
```bash
DATASET=c4 sbatch 2.prepare-data.sbatch
```

You will see the following data after the job completes.

```bash
$ ls ${DATA_PATH}/c4_openwebtext/
c4-train.00000-of-01024.jsonl c4-train.00002-of-01024.jsonl c4_openwebtext_text_document.bin
c4-train.00001-of-01024.jsonl c4-train.00003-of-01024.jsonl c4_openwebtext_text_document.idx
```

## 5. Model training

GPT-NeoX parameters are defined in a YAML configuration file which is passed to the `deepy.py` launcher.
Parameters originate from either the [DeepSpeed runner CLI (DSL)](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/launcher/runner.py#L33), the [DeepSpeed configuration file (DSC)](https://www.deepspeed.ai/docs/config-json/), the [Megatron-LM CLI (Meg)](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/arguments.py#L224), or are GPT-NeoX (NeoX) modifications. See [the configuration README](https://github.com/EleutherAI/gpt-neox/blob/main/configs/README.md) of the NeoX repository. You need to make a few changes to the config files to make them work on a Slurm cluster. First, you need to tell GPT-NeoX where to retrieve training data and where to store and load model checkpoints.

```json
"vocab_file": "/fsx/gpt-neox/tokenizers/20B_tokenizer.json",
"save": "/fsx/gpt-neox/models/pythia/1-4B_checkpoints",
"load": "/fsx/gpt-neox/models/pythia/1-4B_checkpoints",
"data_path": " /fsx/c4_subset/c4_openwebtext/c4_openwebtext_text_document",
```

Additionally, you need to modify all of your configs to conform to JSON. When launching a GPT-NeoX job you can specify multiple YAML config files. Internally, all of these files are merged into one config and then passed as a single long command-line argument to DeepSpeed. When using Slurm and its `srun` command, Python fails to parse this long command-line argument unless it is in the more restrictive JSON format. This test case prepares sample JSON configs in the `configs/pythia` directory.
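
For illustration, a minimal sketch of such a launch is shown below; the config file names are assumptions, so use whichever JSON configs this test case provides under `configs/pythia`.

```bash
# deepy.py merges every config file given on the command line into a single
# config before handing it to the DeepSpeed launcher. File names below are
# illustrative placeholders.
python ./deepy.py train.py \
    configs/pythia/1-4B.json \
    configs/pythia/local_setup.json
```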

Note: `gas` (`gradient_accumulation_steps`) in the original `pythia` config has been removed in the JSON configs. See https://github.com/EleutherAI/gpt-neox/pull/1144 for details.

Launch distributed training using `3.train.sbatch`.

```bash
sbatch 3.train.sbatch
```

By default, the 1.4B model will be trained. You may modify the `MODEL_CONFIG` variable in the script to train a different model size.
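
A hedged sketch of switching sizes (the config file names are illustrative; use whichever JSON configs exist under `configs/pythia`):

```bash
# Locate the MODEL_CONFIG assignment, point it at another sample config
# (file names are placeholders), then resubmit the job.
grep -n "MODEL_CONFIG" 3.train.sbatch
sbatch 3.train.sbatch
```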
2 changes: 2 additions & 0 deletions 3.test_cases/15.gpt-neox/examples/olcf-6/.gitigonre
@@ -0,0 +1,2 @@
*.csv
*.ipynb
61 changes: 61 additions & 0 deletions 3.test_cases/15.gpt-neox/examples/olcf-6/1.train.sbatch
@@ -0,0 +1,61 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#SBATCH --job-name="neox"
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8 # Number of GPUs per node
#SBATCH --output=logs/%x_%j.out # logfile for stdout
#SBATCH --error=logs/%x_%j.err # logfile for stderr, remove it to merge both outputs
#SBATCH --wait-all-nodes=1
#SBATCH --exclusive
set -uxo pipefail

# Default variables for Enroot; if these variables are already defined, use their values
: "${FSX_PATH:=/fsx}"
: "${IMAGE:=$FSX_PATH/apps/gpt-neox.sqsh}"
: "${CONTAINER_MOUNT:=$FSX_PATH:$FSX_PATH}"
## EFA settings
export FI_LOG_LEVEL=warn
export FI_PROVIDER=efa # change to eth if you want to use ENA for comparisons
export FI_EFA_USE_HUGE_PAGE=0
# https://discuss.pytorch.org/t/nccl-network-is-unreachable-connection-refused-when-initializing-ddp/137352
# https://github.com/pytorch/pytorch/issues/68893
export NCCL_SOCKET_IFNAME=en
export NCCL_ASYNC_ERROR_HANDLING=1
#export NCCL_DEBUG=INFO

export DATA_CONFIG=${PWD}/configs/frontier.yml
export MODEL_CONFIG=${PWD}/configs/forge-m.yml
# Some potentially useful distributed environment variables
export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
export NODES_ARRAY=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
export HEAD_NODE=${NODES_ARRAY[0]}
export MASTER_ADDR=$(hostname --ip-address)  # the batch script runs on the first allocated node
export MASTER_PORT=$RANDOM
export NNODES=$SLURM_JOB_NUM_NODES
export NPROC=$SLURM_GPUS_PER_NODE
export WORLD_SIZE=$(( NNODES * NPROC ))

declare -a ARGS=(
--container-image $IMAGE
--container-mounts $CONTAINER_MOUNT
)

declare -a TORCHRUN_ARGS=(
    # Note: this array is not consumed by the srun command below; deepy.py
    # (the DeepSpeed launcher wrapper) handles the distributed launch itself.
    --master_addr $MASTER_ADDR \
    --master_port $RANDOM \
    # change this to match the number of gpus per node:
    --nproc_per_node=8 \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$(hostname) \
)

srun -l "${ARGS[@]}" python deepy.py train.py ${MODEL_CONFIG} ${DATA_CONFIG}
31 changes: 31 additions & 0 deletions 3.test_cases/15.gpt-neox/examples/olcf-6/README.md
@@ -0,0 +1,31 @@

# FORGE GPT-NeoX Test Case <!-- omit in toc -->

This test case illustrates how to pre-train a FORGE-style foundation model ("FORGE: Pre-Training Open Foundation Models for Science", SC'23) on the OLCF tokenized dataset using GPT-NeoX.

## 1. Preparation

This test case assumes that you have built the GPT-NeoX container from [`../../0.gpt-neox.dockerfile`](https://github.com/aws-samples/awsome-distributed-training/tree/main/3.test_cases/15.gpt-neox) and imported it as an Enroot squash file.
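
If you have not done so yet, a minimal sketch of the build-and-import flow is shown below; the image tag, build context, and squash-file path are assumptions, chosen to match the default `IMAGE` in `1.train.sbatch`.

```bash
# Build the image from the parent test case's Dockerfile, then convert it to
# the Enroot squash file that 1.train.sbatch expects by default.
# Paths and tags here are illustrative; adjust to your environment.
cd ../..                                   # 3.test_cases/15.gpt-neox
docker build -t gpt-neox:latest -f 0.gpt-neox.dockerfile .
enroot import -o /fsx/apps/gpt-neox.sqsh dockerd://gpt-neox:latest
cd examples/olcf-6
```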

## 2. Download Dataset

This test case makes use of the [Tokenized Data for FORGE Foundation Models](https://doi.ccs.ornl.gov/ui/doi/453). Download the data and place it as follows:

```bash
/fsx/data/olcf
├── README.txt
├── all_text_document.bin
├── all_text_document.idx
└── all_vocab.json
```
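
A minimal staging sketch, assuming you have already downloaded the four files from the DOI landing page above (the local download directory below is a placeholder):

```bash
# Stage the downloaded FORGE dataset where the training configs expect it
# (see "data-path" and "vocab-file" in configs/forge-*.yml).
mkdir -p /fsx/data/olcf
cp ~/downloads/forge/README.txt \
   ~/downloads/forge/all_text_document.bin \
   ~/downloads/forge/all_text_document.idx \
   ~/downloads/forge/all_vocab.json \
   /fsx/data/olcf/
```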

This dataset comprises a corpus of 257 billion tokens, accompanied by the corresponding vocabulary file employed in the pre-training of FORGE foundation models. The primary data source is scientific documents from diverse origins, tokenized using the Hugging Face BPE tokenizer. Further details can be found in the publication "FORGE: Pre-Training Open Foundation Models for Science" by Junqi Yin, Sajal Dash, Feiyi Wang, and Mallikarjun (Arjun) Shankar, presented at SC'23. The data tokenization pipeline and resulting artifacts use CORE data [Ref: Knoth, P., & Zdrahal, Z. (2012). CORE: three access levels to underpin open access. D-Lib Magazine, 18(11/12)]. Before using these data sets for any purpose, please follow the guidelines at https://core.ac.uk/terms.

## 3. Train

You can now kick off the training with:

```bash
mkdir -p logs    # Slurm does not create the log directory for you
sbatch 1.train.sbatch
```
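
Once the job is running, you can follow its progress from the log files that `1.train.sbatch` writes under `logs/` (per its `--output` and `--error` directives), for example:

```bash
# Check the job state, then tail the stdout log of the "neox" job.
# Replace <jobid> with the id reported by sbatch or squeue.
squeue -u "$USER"
tail -f logs/neox_<jobid>.out
```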

101 changes: 101 additions & 0 deletions 3.test_cases/15.gpt-neox/examples/olcf-6/configs/forge-l.yml
@@ -0,0 +1,101 @@
# GPT-2 pretraining setup
{

"tokenizer_type": "HFTokenizer",
"data-path": "/fsx/data/olcf/all_text_document",
"vocab-file": "/fsx/data/olcf/all_vocab.json",
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
#"pipe-parallel-size": 1,
"model-parallel-size": 2,

# batch / data settings
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 1,
"data-impl": "mmap",

#aws-rccl workaround
"num_workers": 0,

# model settings
"num-layers": 48,
"hidden-size": 6144,
"num-attention-heads": 48,
"seq-length": 2048,
"max-position-embeddings": 2048,
"norm": "layernorm",
"pos-emb": "rotary",
"no-weight-tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# these should provide some speedup but take a while to build; set to true if desired
"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": true,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",


# optimizer settings
"optimizer": {
"type": "adam",
"params": {
"lr": 0.006,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 450000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 450000000,
"contiguous_gradients": True,
},


# activation checkpointing
"checkpoint-activations": true,
"checkpoint-num-layers": 1,
"partition-activations": true,
"synchronize-each-layer": true,

# regularization
"gradient_clipping": 1.0,
"weight-decay": 0.1,
"hidden-dropout": 0,
"attention-dropout": 0,

# precision settings
"fp16": {
"type": "bfloat16",
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train-iters": 15300,
"lr-decay-iters": 15300,
"distributed-backend": "nccl",
"lr-decay-style": "cosine",
"warmup": 0.01,
"checkpoint-factor": 50,
"eval-interval": 100,
"eval-iters": 10,

# logging
"log-interval": 20,
"steps_per_print": 10,
"keep-last-n-checkpoints": 200,
"wall_clock_breakdown": true,
}
102 changes: 102 additions & 0 deletions 3.test_cases/15.gpt-neox/examples/olcf-6/configs/forge-m.yml
@@ -0,0 +1,102 @@
# GPT-2 pretraining setup
{

"tokenizer_type": "HFTokenizer",
"data-path": "/fsx/data/olcf/all_text_document",
"vocab-file": "/fsx/data/olcf/all_vocab.json",

# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
#"pipe-parallel-size": 1,
"model-parallel-size": 2,

# batch / data settings
"train_micro_batch_size_per_gpu": 12,
"gradient_accumulation_steps": 2,
"data-impl": "mmap",

#aws-rccl workaround
"num_workers": 0,

# model settings
"num-layers": 40,
"hidden-size": 5120,
"num-attention-heads": 40,
"seq-length": 2048,
"max-position-embeddings": 2048,
"norm": "layernorm",
"pos-emb": "rotary",
"no-weight-tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# these should provide some speedup but take a while to build; set to true if desired
"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": true,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",


# optimizer settings
"optimizer": {
"type": "adam",
"params": {
"lr": 0.006,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},


# activation checkpointing
"checkpoint-activations": true,
"checkpoint-num-layers": 1,
"partition-activations": true,
"synchronize-each-layer": true,

# regularization
"gradient_clipping": 1.0,
"weight-decay": 0.1,
"hidden-dropout": 0,
"attention-dropout": 0,

# precision settings
"fp16": {
"type": "bfloat16",
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train-iters": 15300,
"lr-decay-iters": 15300,
"distributed-backend": "nccl",
"lr-decay-style": "cosine",
"warmup": 0.01,
"checkpoint-factor": 50,
"eval-interval": 100,
"eval-iters": 10,

# logging
"log-interval": 50,
"steps_per_print": 10,
"keep-last-n-checkpoints": 200,
"wall_clock_breakdown": true,
}