Merge pull request #212 from aws-samples/cpu_ddp
add DDP CPU example
KeitaW authored Mar 14, 2024
2 parents a0c3005 + 8405a03 commit 946c45f
Showing 3 changed files with 206 additions and 0 deletions.
23 changes: 23 additions & 0 deletions 3.test_cases/16.pytorch-cpu-ddp/1.train.sbatch
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --job-name=cpu-ddp
#SBATCH --exclusive
#SBATCH --wait-all-nodes=1
#SBATCH --nodes 2
#SBATCH --cpus-per-task=4
#SBATCH --output=logs/%x_%j.out # logfile for stdout/stderr

# Build the list of allocated nodes and use the first one as the rendezvous host
nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO

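# srun starts one torchrun launcher per node; the c10d rendezvous on the head
# node coordinates the 2 nodes x 4 workers, and ddp.py receives total_epochs=50
# and save_every=10 as positional arguments.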
srun /opt/conda/envs/pytorch/bin/torchrun \
--nnodes 2 \
--nproc_per_node 4 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
ddp.py 50 10
60 changes: 60 additions & 0 deletions 3.test_cases/16.pytorch-cpu-ddp/README.md
@@ -0,0 +1,60 @@
# PyTorch DDP on CPU <!-- omit in toc -->

This test case provides a simple distributed training example on CPU using [PyTorch DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html).
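
At its core the example follows the standard torchrun + DDP pattern with the `gloo` backend, which works on CPU-only instances. Below is a minimal sketch of that pattern (not the full training script; see `ddp.py` in this directory):

```python
# Minimal sketch: torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
# for every worker, so init_process_group() can read them from the environment.
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

init_process_group(backend="gloo")     # gloo supports CPU tensors
model = DDP(torch.nn.Linear(20, 1))    # no device_ids: gradients are synchronized on CPU
loss = model(torch.rand(8, 20)).sum()  # forward pass as usual
loss.backward()                        # DDP all-reduces gradients across ranks
destroy_process_group()
```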

## 1. Preparation

This guide assumes that you have the following:

* A functional Slurm cluster on AWS whose compute instances are based on the AWS Deep Learning AMI.
* An FSx for Lustre filesystem mounted on `/fsx`.

We recommend that you set up a Slurm cluster using the templates in the architectures [directory](../../1.architectures).


## 2. Submit training job

Create the `logs` directory if it does not already exist (`mkdir -p logs`), then submit the DDP training job with:

```bash
sbatch 1.train.sbatch
```
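
The batch script launches `torchrun` on two nodes with four workers each and passes `50 10` to `ddp.py`, i.e. `total_epochs=50` and `save_every=10`; every tenth epoch a `snapshot.pt` checkpoint is written to the job's working directory.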

The output of the training job can be found in the `logs` directory:

```bash
# cat logs/cpu-ddp_xxx.out
Node IP: 10.1.96.108
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING]
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] *****************************************
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] *****************************************
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] Starting elastic_operator with launch configs:
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] entrypoint : ddp.py
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] min_nodes : 2
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] max_nodes : 2
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] nproc_per_node : 4
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] run_id : 5982
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_backend : c10d
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_endpoint : 10.1.96.108:29500
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_configs : {'timeout': 900}
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] max_restarts : 0
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] monitor_interval : 5
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] log_dir : None
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] metrics_cfg : {}
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO]
[2024-03-12 08:22:45,552] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] log directory set to: /tmp/torchelastic_9g50nxjq/5982_tflt1tcd
[2024-03-12 08:22:45,552] torch.distributed.elastic.agent.server.api: [INFO] [default] starting workers for entrypoint: python
...
[RANK 3] Epoch 49 | Batchsize: 32 | Steps: 8
[RANK 5] Epoch 49 | Batchsize: 32 | Steps: 8
[RANK 4] Epoch 49 | Batchsize: 32 | Steps: 8
[2024-03-12 08:22:56,574] torch.distributed.elastic.agent.server.api: [INFO] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
[2024-03-12 08:22:56,574] torch.distributed.elastic.agent.server.api: [INFO] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Done waiting for other agents. Elapsed: 0.0010929107666015625 seconds
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Done waiting for other agents. Elapsed: 0.0005395412445068359 seconds
```
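
Each of the eight ranks (2 nodes × 4 workers per node) prints its per-epoch progress. The two `WorkerState.SUCCEEDED` messages come from the per-node elastic agents and indicate that the worker groups finished successfully on both nodes.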
123 changes: 123 additions & 0 deletions 3.test_cases/16.pytorch-cpu-ddp/ddp.py
@@ -0,0 +1,123 @@
# Modified version of https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py

import os

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

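# Synthetic dataset: `size` pairs of a random 20-dim input and a random scalar target.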
class MyTrainDataset(Dataset):
def __init__(self, size):
self.size = size
self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

def __len__(self):
return self.size

def __getitem__(self, index):
return self.data[index]

def ddp_setup():
    # torchrun provides MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE through the
    # environment; gloo is the backend to use for CPU-only training.
    init_process_group(backend="gloo")

class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
optimizer: torch.optim.Optimizer,
save_every: int,
snapshot_path: str,
) -> None:
        self.model = model
        self.rank = int(os.environ["RANK"])  # global rank, set by torchrun
self.train_data = train_data
self.optimizer = optimizer
self.save_every = save_every
self.epochs_run = 0
self.snapshot_path = snapshot_path
if os.path.exists(snapshot_path):
print("Loading snapshot")
self._load_snapshot(snapshot_path)

        self.model = DDP(self.model)  # no device_ids: gradients are synchronized over CPU via gloo

def _load_snapshot(self, snapshot_path):
snapshot = torch.load(snapshot_path)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

def _run_batch(self, source, targets):
self.optimizer.zero_grad()
output = self.model(source)
        loss = F.cross_entropy(output, targets)  # dummy loss on synthetic data; the example only exercises DDP mechanics
loss.backward()
self.optimizer.step()

def _run_epoch(self, epoch):
b_sz = len(next(iter(self.train_data))[0])
print(f"[RANK {self.rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
        for source, targets in self.train_data:
            self._run_batch(source, targets)

def _save_snapshot(self, epoch):
snapshot = {
"MODEL_STATE": self.model.module.state_dict(),
"EPOCHS_RUN": epoch,
}
torch.save(snapshot, self.snapshot_path)
print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            # only rank 0 writes the snapshot, to avoid every rank saving the same file
            if self.rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs():
train_set = MyTrainDataset(2048) # load your dataset
model = torch.nn.Linear(20, 1) # load your model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    # pin_memory is omitted: it only helps host-to-GPU copies and this example runs on CPU.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # shuffling is handled by the DistributedSampler
        sampler=DistributedSampler(dataset),
    )


def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
ddp_setup()
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
trainer.train(total_epochs)
destroy_process_group()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

main(args.save_every, args.total_epochs, args.batch_size)
