Sfno fix (#239)
* Add warning if jsbeautifier not installed, set default for h5 in inference, fix import

* Copy PyTorch patches instead of using monkeypatching

* Update README.md to include patching doc

---------

Co-authored-by: Mohammad Amin Nabian <m.a.nabiyan@gmail.com>
daviddpruitt and mnabian authored Nov 17, 2023
1 parent b615801 commit b9608e4
Showing 22 changed files with 25 additions and 102 deletions.
2 changes: 2 additions & 0 deletions modulus/experimental/sfno/README.md
@@ -11,6 +11,8 @@ This is a research code built for massively parallel training of SFNO for weather
 
 ## Getting started
 
+**For distributed training or inference, run `patch_pytorch.sh` in advance. This will patch the pytorch distributed utilities to support complex values.**
+
 ## Installing optional dependencies
 
 Install the optional dependencies by running
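
For background on why the patch is needed: collectives in stock `torch.distributed` have historically rejected or mishandled complex-valued tensors, which SFNO's spectral layers produce. A minimal, hedged sketch (not from this repository) of the usual workaround, reinterpreting a complex tensor as pairs of reals before a collective:

```python
# Sketch only (not from this commit): moving a complex tensor through an
# all_reduce by viewing it as pairs of reals. This illustrates the kind of
# complex-value gap the patched torch.distributed utilities address.
# Assumes torch.distributed.init_process_group() has already been called.
import torch
import torch.distributed as dist


def all_reduce_complex(t: torch.Tensor) -> torch.Tensor:
    # view_as_real maps complex64 of shape (...,) to float32 of shape
    # (..., 2) without copying, so the collective sees only real values.
    as_real = torch.view_as_real(t)
    dist.all_reduce(as_real)  # in-place sum across ranks
    return torch.view_as_complex(as_real)
```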
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/convert_legacy_to_flexible.py
@@ -28,6 +28,7 @@
 )
 from modulus.experimental.sfno.utils import logging_utils
 
+import torch.distributed as dist
 
 from modulus.experimental.sfno.networks.models import get_model
@@ -36,10 +37,6 @@
 from modulus.experimental.sfno.utils.trainer import Trainer
 from modulus.experimental.sfno.utils.YParams import ParamsBase
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 class CheckpointSaver(Trainer):
     """
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/inference/inferencer.py
@@ -36,10 +36,8 @@
 # distributed computing stuff
 from modulus.experimental.sfno.utils import comm
 from modulus.experimental.sfno.utils import visualize
+import torch.distributed as dist
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 class Inferencer(Trainer):
     """
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/mpu/helpers.py
@@ -14,14 +14,12 @@
 
 import torch
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from modulus.experimental.sfno.utils import comm
 
 from torch._utils import _flatten_dense_tensors
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 def get_memory_format(tensor):
     if tensor.is_contiguous(memory_format=torch.channels_last):
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/mpu/layers.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda.amp import custom_fwd, custom_bwd
 from modulus.experimental.sfno.utils import comm
 
@@ -28,10 +29,6 @@
 from modulus.experimental.sfno.mpu.helpers import pad_helper
 from modulus.experimental.sfno.mpu.helpers import truncate_helper
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 class distributed_transpose_w(torch.autograd.Function):
 
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/mpu/mappings.py
@@ -18,6 +18,7 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel
 from modulus.experimental.sfno.utils import comm
+import torch.distributed as dist
 
 # torch utils
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -28,9 +29,6 @@
 from modulus.experimental.sfno.mpu.helpers import _split
 from modulus.experimental.sfno.mpu.helpers import _gather
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 # generalized
 class _CopyToParallelRegion(torch.autograd.Function):
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/networks/helpers.py
@@ -15,10 +15,7 @@
 import torch
 
 from utils import comm
-
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
+import torch.distributed as dist
 
 def count_parameters(model, device):
     with torch.no_grad():
3 changes: 3 additions & 0 deletions modulus/experimental/sfno/patch_pytorch.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+cp third_party/torch/distributed/utils.py /usr/local/lib/python3.10/dist-packages/torch/distributed/
+echo "Patching complete"
@@ -20,15 +20,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from modulus.experimental.sfno.utils import comm
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 # profile stuff
 from ctypes import cdll
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/perf_tests/distributed/dist_fft.py
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -31,10 +32,6 @@
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
 from modulus.experimental.sfno.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2
 
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def main(args, verify):
     # parameters
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -30,10 +31,6 @@
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
 from modulus.experimental.sfno.mpu.fft3d import RealFFT3, InverseRealFFT3, DistributedRealFFT3, DistributedInverseRealFFT3
 
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def main(args, verify):
     # parameters
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -31,10 +32,6 @@
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
 from modulus.experimental.sfno.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2
 
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def main(args, verify):
     # parameters
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -30,10 +31,6 @@
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
 from modulus.experimental.sfno.mpu.fft3d import RealFFT3, InverseRealFFT3, DistributedRealFFT3, DistributedInverseRealFFT3
 
-# imprt patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def main(args, verify):
     # parameters
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/perf_tests/primitives/comp_mult.py
@@ -20,15 +20,13 @@
 from torch.cuda import amp
 import time
 import apex
+import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 sys.path.append(os.path.join("/opt", "ERA5_wind"))
 
 from modulus.experimental.sfno.mpu.layers import compl_mul_add_fwd, compl_mul_add_fwd_c
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 class ComplexMult(nn.Module):
     def __init__(self, num_blocks, block_size, hidden_size_factor, use_complex_kernels=True):
6 changes: 2 additions & 4 deletions modulus/experimental/sfno/perf_tests/sfno/shtfilter.py
@@ -21,6 +21,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
+
 from torch.cuda import amp
 
 sys.path.append(os.path.join("/opt", "makani"))
@@ -31,10 +33,6 @@
 from torch_harmonics import RealSHT as RealSphericalHarmonicTransform
 from torch_harmonics import InverseRealSHT as InverseRealSphericalHarmonicTransform
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 # profile stuff
 from ctypes import cdll
 libcudart = cdll.LoadLibrary('libcudart.so')
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/utils/comm.py
@@ -18,14 +18,11 @@
 from modulus.experimental.sfno.utils.logging_utils import disable_logging
 import math
 import torch
+import torch.distributed as dist
 import datetime as dt
 from typing import Union
 import numpy as np
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 # dummy placeholders
 _COMM_LIST = []
 _COMM_NAMES = {}
5 changes: 1 addition & 4 deletions modulus/experimental/sfno/utils/dataloader.py
@@ -19,12 +19,9 @@
 from torch.utils.data import DataLoader
 
 # distributed stuff
+import torch.distributed as dist
 from modulus.experimental.sfno.utils import comm
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 def init_distributed_io(params):
     # set up sharding
@@ -23,6 +23,7 @@
 #import cv2
 
 # distributed stuff
+import torch.distributed as dist
 from modulus.experimental.sfno.utils import comm
 
 # DALI stuff
@@ -35,10 +36,6 @@
 import modulus.experimental.sfno.utils.dataloaders.dali_es_helper_2d as esh
 from modulus.experimental.sfno.utils.grids import GridConverter
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 
 class ERA5DaliESDataloader(object):
 
33 changes: 0 additions & 33 deletions modulus/experimental/sfno/utils/distributed_patch.py

This file was deleted.
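
The deleted module's contents are not shown in this diff; as an illustration of the pattern being retired, here is a hypothetical sketch of a `dist_patch()`-style shim (not the actual file):

```python
# Hypothetical sketch of a dist_patch()-style shim; the real
# distributed_patch.py is not shown in this diff, so this only
# illustrates the monkeypatching pattern the commit removes.
import torch.distributed as dist


def dist_patch():
    # A shim like this rebinds selected functions on the imported module
    # object, e.g. swapping in collectives that tolerate complex tensors:
    #     dist.all_gather = complex_safe_all_gather
    # Every caller must then remember to use the returned module in place
    # of importing torch.distributed directly.
    return dist
```

Copying the patched `utils.py` into the installed torch package (as `patch_pytorch.sh` does) makes the fix apply process-wide instead of depending on each call site routing through the shim, at the cost of modifying the environment.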

5 changes: 1 addition & 4 deletions modulus/experimental/sfno/utils/metric.py
@@ -18,12 +18,9 @@
 # distributed computing stuff
 from modulus.experimental.sfno.utils import comm
 from modulus.experimental.sfno.utils.metrics.functions import GeometricL1, GeometricRMSE, GeometricACC, Quadrature
+import torch.distributed as dist
 from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
-
 class MetricsHandler():
     """
     Handler object which takes care of computation of metrics. Keeps buffers for the computation of
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/utils/trainer.py
@@ -40,6 +40,7 @@
 # distributed computing stuff
 from modulus.experimental.sfno.utils import comm
 from modulus.experimental.sfno.utils import visualize
+import torch.distributed as dist
 
 # for the manipulation of state dict
 from collections import OrderedDict
@@ -51,9 +52,6 @@
 from modulus.experimental.sfno.third_party.torch.optim.adam import Adam as CustomAdam
 from modulus.experimental.sfno.third_party.torch.optim.adamw import AdamW as CustomAdamW
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 class Trainer():
     """
4 changes: 1 addition & 3 deletions modulus/experimental/sfno/utils/trainer_profile.py
@@ -39,6 +39,7 @@
 # distributed computing stuff
 from modulus.experimental.sfno.utils import comm
 from modulus.experimental.sfno.utils import visualize
+import torch.distributed as dist
 
 # for the manipulation of state dict
 from collections import OrderedDict
@@ -50,9 +51,6 @@
 from modulus.experimental.sfno.third_party.torch.optim.adam import Adam as CustomAdam
 from modulus.experimental.sfno.third_party.torch.optim.adamw import AdamW as CustomAdamW
 
-# import patched distributed
-from modulus.experimental.sfno.utils.distributed_patch import dist_patch
-dist = dist_patch()
 
 # profile stuff
 from ctypes import cdll
