Issues with DeepSpeed using Pytorch Lightning #20

Open
PraljakReps opened this issue Dec 18, 2023 · 4 comments
PraljakReps commented Dec 18, 2023

Hello,

I am having issues using DeepSpeed (stage 2) in a 2-node configuration with 8 A100 GPUs. I followed https://github.com/argonne-lcf/GettingStarted/tree/master/DataScience/DeepSpeed, but I am implementing DeepSpeed through PyTorch Lightning instead: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html.

Training a model on 4 GPUs over 1 node with DeepSpeed (stage 2) and PyTorch Lightning works fine; however, when I use 2 nodes with 8 total GPUs, the second node appears to stall and the code freezes at the final GPU rank on the first node.

I tried to see whether this was reproducible with boring_model.py, PyTorch Lightning's minimal script for reproducing error messages, and I ran into the same issue.

Here is the boring_model.py: https://gist.github.com/PraljakReps/d699f5d16af00e35cf4c8b8abfb09b6c

Using the trainer found in the above Python script, I tried three configurations:

  • (works fine) gpus=4 and num_nodes=1
  • (error; second node hangs) gpus=4 and num_nodes=2
  • (error; second node hangs) mpiexec, gpus=4, and num_nodes=2

Depending on your PyTorch Lightning version (see below for my virtual environment), the trainer should look like this to reproduce my errors:

Initialize a trainer

trainer = pl.Trainer(
    gpus=4,
    max_epochs=1,
    num_nodes=2,
    precision=16,
    strategy="deepspeed_stage_2",
    callbacks=[lr_monitor]
)
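
(A rough sketch of the 2.x-style equivalent, in case you are on a newer PyTorch Lightning where the gpus= argument was removed; I have not tested this myself, since I am on 1.9.5:)

# Hypothetical 2.x-style equivalent of the trainer above (untested sketch).
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,           # GPUs per node; replaces gpus=
    num_nodes=2,
    max_epochs=1,
    precision=16,
    strategy="deepspeed_stage_2",
    callbacks=[lr_monitor],
)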

I also followed this link when running the mpiexec command: https://docs.alcf.anl.gov/polaris/data-science-workflows/frameworks/deepspeed/#:~:text=DeepSpeed.%20The%20base%20conda%20environment,cloning%20the%20base%20environment%20can

Thus, the code that I ran is the following:

NHOSTS=$(wc -l < "${PBS_NODEFILE}")
NGPU_PER_HOST=$(nvidia-smi -L | wc -l)
NGPUS="$((${NHOSTS}*${NGPU_PER_HOST}))"
mpiexec \
  --verbose \
  --envall \
  -n "${NGPUS}" \
  --ppn "${NGPU_PER_HOST}" \
  --hostfile="${PBS_NODEFILE}" \
  python \
  boring_model.py

With this, I am still getting the issue where the second node hangs.

Note: I am entering a compute node interactively with the following command:

qsub -I -l select=2:ngpus=4 -l filesystems=home:eagle -l walltime=1:00:00 -q debug -A <account>

Is there a way to run DeepSpeed over multiple nodes on Polaris for boring_model.py with PyTorch Lightning as a test case? Of course, my main goal is to conduct multi-node training for my research project, but I think getting boring_model.py running with pytorch-lightning + deepspeed is an easy first case.

Software packages:

pytorch-lightning==1.9.5
torch==2.0.1
torchmetrics==1.2.0
lightning-bolts==0.7.0
deepspeed==0.12.5
python==3.8.18
@saforem2

@PraljakReps I came across this (from our own @coreyjadams 😂, just 3 weeks ago) that might be helpful:

Lightning-AI/pytorch-lightning#19086
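
If you end up on a newer Lightning release, one thing that might be worth trying (an untested sketch on my part; MPIEnvironment ships with Lightning 2.x and needs mpi4py, neither of which is in your 1.9.5 environment) is handing the Trainer an MPI-aware cluster environment explicitly:

# Untested sketch: explicitly pass an MPI-aware cluster environment so that
# ranks launched via mpiexec are discovered through mpi4py rather than
# Lightning's default environment detection.
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import MPIEnvironment

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    num_nodes=2,
    precision=16,
    strategy="deepspeed_stage_2",
    plugins=[MPIEnvironment()],
)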

@saforem2

I've also been having trouble getting PyTorch Lightning to launch and work correctly on more than one node of Polaris.

If you just need DeepSpeed support and can do without PyTorch Lightning, you should be able to do the following:

$ module load conda/2023-10-04 ; conda activate base
$ # cd /path/to/your/project/
$ # and make virtual environment in your project dir:
$ mkdir -p venvs/polaris/2023-10-04
$ python3 -m venv venvs/polaris/2023-10-04 --system-site-packages
$ source venvs/polaris/2023-10-04/bin/activate
$ # install [`ezpz`](https://github.com/saforem2/ezpz)
$ # (a little python library I've been working on)
$ python3 -m pip install "git+https://github.com/saforem2/ezpz"
$ # determine available GPUs
$ NHOSTS=$(wc -l < "${PBS_NODEFILE}")
$ NGPU_PER_HOST=$(nvidia-smi -L | wc -l)
$ NGPUS="$((${NHOSTS}*${NGPU_PER_HOST}))"
$ mpiexec \
    --verbose \
    --envall \
    -n "${NGPUS}" \
    --ppn "${NGPU_PER_HOST}" \
    --hostfile="${PBS_NODEFILE}" \
    python3 \
    deepspeed_test.py

where deepspeed_test.py contains:

"""
deepspeed_test.py
"""
import os
from pathlib import Path
import deepspeed

from ezpz import setup_torch, get_world_size

import torch
from torch.utils.data import DataLoader  # , Dataset
from pl_bolts.datasets import RandomDataset

BACKEND = 'deepspeed'  # also tried with DDP

RANK = setup_torch(backend=BACKEND)
WORLD_SIZE = get_world_size()
GPUS_PER_NODE = torch.cuda.device_count()
NODE_ID = RANK // GPUS_PER_NODE  # index of the node this rank runs on
NUM_NODES = WORLD_SIZE // GPUS_PER_NODE

NUM_SAMPLES = 10000
train = RandomDataset(32, NUM_SAMPLES)
train = DataLoader(train, batch_size=32)
val = RandomDataset(32, NUM_SAMPLES)
val = DataLoader(val, batch_size=32)
test = RandomDataset(32, NUM_SAMPLES)
test = DataLoader(test, batch_size=32)

def main():
    from ezpz import load_ds_config
    from boring_model import BoringModel  # plain (non-relative) import so the script runs directly
    model = BoringModel()
    # lr_monitor = LearningRateMonitor(logging_interval='step')
    # ---- Couldn't get the `pl.Trainer` to work ---------------
    # Initialize a trainer
    # trainer = pl.Trainer(
    #     accelerator='gpu',
    #     strategy='deepspeed',
    #     # strategy="deepspeed_stage_2",
    #     # devices=WORLD_SIZE,
    #     gpus=GPUS_PER_NODE,
    #     max_epochs=1,
    #     num_nodes=NUM_NODES,
    #     # progress_bar_refresh_rate=20,
    #     precision=16,
    #     callbacks=[lr_monitor]
    # )
    # ---- Couldn't get the `pl.Trainer` to work ---------------
    if RANK == 0:
        print(f'{model=}')
    optimizer_, _ = model.configure_optimizers()
    optimizer = optimizer_[0]

    HERE = Path(os.path.abspath(__file__)).parent
    ds_config_path = HERE / 'ds_config.yaml'
    ds_config = load_ds_config(ds_config_path)
    model_engine, optimizer, *_ = deepspeed.initialize(
        model=model,
        config=ds_config,
        optimizer=optimizer,
        model_parameters=model.parameters(),
    )

    # Train loop
    device = model_engine.local_rank
    for step, x in enumerate(train):
        x = x.to(device)
        y = model_engine(x)
        loss = model_engine.loss(x, y)
        model_engine.backward(loss)
        model_engine.step()
        if RANK == 0:
            print(f'{step=}, {loss.item()=}')


if __name__ == '__main__':
    main()

and ds_config.yaml has:

---
dump_state: false
gradient_accumulation_steps: 1
wall_clock_breakdown: true
train_micro_batch_size_per_gpu: 1
fp16:
  enabled: false
  # min_loss_scale: 0
# bf16:
#   enabled: false
...
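
(Side note: if you'd rather skip the ezpz config loader, deepspeed.initialize also accepts the config as a plain Python dict; roughly, reusing model and optimizer from deepspeed_test.py above:)

# Same settings as ds_config.yaml, passed directly as a dict:
ds_config = {
    "dump_state": False,
    "gradient_accumulation_steps": 1,
    "wall_clock_breakdown": True,
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": False},
}
model_engine, optimizer, *_ = deepspeed.initialize(
    model=model,
    config=ds_config,
    optimizer=optimizer,
    model_parameters=model.parameters(),
)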

🤷🏻‍♂️ happy to chat more if you have any questions about anything

@saforem2

In the above, boring_model.py contains:

"""
boring_model.py
"""
import torch
from pytorch_lightning import LightningModule


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights
        # during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(
            prediction,
            torch.ones_like(prediction)
        )

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


PraljakReps commented Dec 20, 2023

Thanks @saforem2! The code works using the DeepSpeed boilerplate without pl.Trainer.

For now, I don't mind using the basic DeepSpeed boilerplate, but ideally I would like to get trainer.fit() to work, since pytorch_lightning has a lot of advantages (e.g. checkpointing, metric tracking, etc.).

Do you recall the error message you got when using pl.Trainer(...) and trainer.fit(...)?
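
In the meantime, I will probably handle checkpointing through the DeepSpeed engine directly. A rough sketch of what I have in mind, continuing the training loop in deepspeed_test.py above (the path and save interval are just placeholders):

# Inside the training loop from deepspeed_test.py: save a checkpoint every
# 1000 steps straight from the engine instead of via a Lightning callback.
if step % 1000 == 0:
    model_engine.save_checkpoint("checkpoints/", tag=f"step_{step}")

# To resume later:
# model_engine.load_checkpoint("checkpoints/", tag="step_1000")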
