From c6dab482d61408153b11cb81595d3078b0d205c9 Mon Sep 17 00:00:00 2001 From: Peter Harrington <48932392+pzharrington@users.noreply.github.com> Date: Mon, 16 Dec 2024 09:42:49 -0800 Subject: [PATCH] StormCast training code improvements (#738) * adding stormcast raw files * major cleanup, refactor and consolidation Signed-off-by: Peter Harrington * More cleanup and init readme Signed-off-by: Peter Harrington * port command line args to standard argparse Signed-off-by: Peter Harrington * remove unused network and loss wrappers Signed-off-by: Peter Harrington * add torchrun instructions Signed-off-by: Peter Harrington * drop dnnlib utils Signed-off-by: Peter Harrington * use Modulus DistributedManager, streamline cmd args Signed-off-by: Peter Harrington * Use standard torch checkpoints instead of pickles Signed-off-by: Peter Harrington * Standardize model configs and channel selection across training and inference Signed-off-by: Peter Harrington * checkpoint format standardization for train/inference Signed-off-by: Peter Harrington * finalize additional deps Signed-off-by: Peter Harrington * format and linting Signed-off-by: Peter Harrington * drop docker and update changelog Signed-off-by: Peter Harrington * Address feedback Signed-off-by: Peter Harrington * add variables to readme, rename network types Signed-off-by: Peter Harrington * swap stormcast to modulus nn and loss defs Signed-off-by: Peter Harrington * Swap to modulus checkpoint save and load utils Signed-off-by: Peter Harrington * Swap to modulus networks/losses, use modulus checkpointing and logging Signed-off-by: Peter Harrington * add power spectrum to modulus metrics, remove unused utils Signed-off-by: Peter Harrington * Readme update and unit tests Signed-off-by: Peter Harrington * drop unused files Signed-off-by: Peter Harrington * drop unused diffusions files Signed-off-by: Peter Harrington * update changelog Signed-off-by: Peter Harrington --------- Signed-off-by: Peter Harrington Co-authored-by: nvssh nssswitch user account --- CHANGELOG.md | 2 + examples/generative/stormcast/README.md | 89 +- .../stormcast/config/dataset/hrrr_era5.yaml | 70 ++ .../stormcast/config/diffusion.yaml | 45 + .../stormcast/config/hydra/default.yaml | 18 + .../stormcast/config/inference/stormcast.yaml | 33 + .../stormcast/config/model/stormcast.yaml | 30 + .../stormcast/config/regression.yaml | 35 + .../config/sampler/edm_deterministic.yaml | 32 + .../stormcast/config/stormcast_inference.yaml | 36 + .../stormcast/config/training/default.yaml | 41 + examples/generative/stormcast/inference.py | 218 ++-- examples/generative/stormcast/train.py | 190 +-- .../generative/stormcast/utils/YParams.py | 75 -- .../stormcast/utils/data_loader_hrrr_era5.py | 50 +- .../stormcast/utils/diffusions/generate.py | 85 -- .../stormcast/utils/diffusions/losses.py | 100 -- .../stormcast/utils/diffusions/networks.py | 1043 ----------------- .../stormcast/utils/diffusions/run_edm.py | 95 -- .../utils/diffusions/training_loop.py | 557 --------- examples/generative/stormcast/utils/misc.py | 241 ---- examples/generative/stormcast/utils/nn.py | 127 ++ .../generative/stormcast/utils/spectrum.py | 74 +- .../generative/stormcast/utils/trainer.py | 472 ++++++++ .../stormcast/utils/training_stats.py | 304 ----- modulus/metrics/diffusion/loss.py | 13 +- modulus/metrics/general/power_spectrum.py | 116 ++ modulus/models/diffusion/__init__.py | 2 +- modulus/models/diffusion/preconditioning.py | 44 +- modulus/models/diffusion/song_unet.py | 20 +- modulus/models/diffusion/unet.py | 101 ++ .../utils/generative/deterministic_sampler.py | 32 +- test/metrics/diffusion/test_losses.py | 11 +- test/metrics/test_metrics_general.py | 25 + test/models/diffusion/test_preconditioning.py | 38 + test/models/diffusion/test_song_unet.py | 20 + test/models/diffusion/test_unet_wrappers.py | 134 +++ 37 files changed, 1677 insertions(+), 2941 deletions(-) create mode 100644 examples/generative/stormcast/config/dataset/hrrr_era5.yaml create mode 100644 examples/generative/stormcast/config/diffusion.yaml create mode 100644 examples/generative/stormcast/config/hydra/default.yaml create mode 100644 examples/generative/stormcast/config/inference/stormcast.yaml create mode 100644 examples/generative/stormcast/config/model/stormcast.yaml create mode 100644 examples/generative/stormcast/config/regression.yaml create mode 100644 examples/generative/stormcast/config/sampler/edm_deterministic.yaml create mode 100644 examples/generative/stormcast/config/stormcast_inference.yaml create mode 100644 examples/generative/stormcast/config/training/default.yaml delete mode 100644 examples/generative/stormcast/utils/YParams.py delete mode 100644 examples/generative/stormcast/utils/diffusions/generate.py delete mode 100644 examples/generative/stormcast/utils/diffusions/losses.py delete mode 100644 examples/generative/stormcast/utils/diffusions/networks.py delete mode 100644 examples/generative/stormcast/utils/diffusions/run_edm.py delete mode 100644 examples/generative/stormcast/utils/diffusions/training_loop.py delete mode 100644 examples/generative/stormcast/utils/misc.py create mode 100644 examples/generative/stormcast/utils/nn.py create mode 100644 examples/generative/stormcast/utils/trainer.py delete mode 100644 examples/generative/stormcast/utils/training_stats.py create mode 100644 modulus/metrics/general/power_spectrum.py create mode 100644 test/models/diffusion/test_unet_wrappers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b5d5d74e09..b65314d743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Refactored StormCast training example + ### Deprecated ### Removed diff --git a/examples/generative/stormcast/README.md b/examples/generative/stormcast/README.md index 8ae06e28e0..970e8c8029 100644 --- a/examples/generative/stormcast/README.md +++ b/examples/generative/stormcast/README.md @@ -1,8 +1,6 @@ ## StormCast: Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling -**Note: this example is an initial release of the StormCast code and will be heavily refactored in future releases** - ## Problem overview Convection-allowing models (CAMs) are essential tools for forecasting severe thunderstorms and @@ -18,11 +16,18 @@ accuracy, demonstrating ability to replicate storm dynamics, observed radar refl atmospheric structure via deep learning-based CAM emulation. StormCast enables high-resolution ML-driven regional weather forecasting and climate risk analysis. -

+The design of StormCast relies on two neural networks: + 1. A regression model, which provides a deterministic estimate of the next HRRR timestep given the previous timestep's HRRR and background ERA5 states + 2. A diffusion model, which is given the previous HRRR timestep as well as the estimate from the regression model, and provides a correction to the regression model estimate to produce a final high-quality prediction of the next high-resolution atmospheric state. + +Much like other data-driven weather models, StormCast can make longer forecasts (more than one timestep) during inference by feeding its predictions back into the model as input for the next step (autoregressive rollout). The regression and diffusion components are trained separately (with the diffusion model training requiring a regression model as prerequisite), then coupled together in inference. Note in the above description, we specifically name HRRR and ERA5 as the regional high-resolution and global coarse-resolution data sources/targets, respectively, but the StormCast setting should generalize to any regional/global coupling of interest. + + + ## Getting started ### Preliminaries @@ -30,66 +35,76 @@ Start by installing Modulus (if not already installed) and copying this folder ( ### Configuration basics -StormCast training is handled by `train.py` and controlled by a YAML configuration file in `config/config.yaml` and command line arguments. You can choose the configuration file using the `--config_file` option, and a specific configuration within that file with the `--config-name` option. The main configuration file specifies the training dataset, the model configuration and the training options. To change a configuration option, you can either edit the existing configurations directly or make new ones by inheriting from the existing configs and overriding specific options. For example, one could create a new config for training the diffusion model in StormCast by creating a new config that inherits from the existing `diffusion` config in `config/config.yaml`: -``` -diffusion_bs64: - <<: *diffusion - batch_size: 1 +StormCast training is handled by `train.py`, configured using [hydra](https://hydra.cc/docs/intro/) based on the contents of the `config` directory. Hydra allows for YAML-based modular and hierarchical configuration management and supports command-line overrides for quick testing and experimentation. The `config` directory includes the following subdirectories: + - `dataset`: specifies the resolution, number of variables, and other parameters of the dataset + - `model`: specifies the model type and model-specific hyperparameters + - `sampler`: specifies hyperparameters used in the sampling process for diffusion models + - `training`: specifies training-specific hyperparameters and settings like checkpoint/log frequency and where to save training outputs + - `inference` specifies inference-specific settings like which initial condition to run, which model checkpoints to use, etc. + - `hydra`: specifies basic hydra settings, like where to store outputs (based on the training or inference outputs directories) + +Also in the `config` directory are several top-level configs which show how to train a `regression` model or `diffusion` model, and run inference (`stormcast-inference`). One can select any of these by specifying it as a config name at the command line (e.g., `--config-name=regression`); optionally one can also override any specific items of interest via command line args, e.g.: +```bash +python train.py --config-name regression training.batch_size=4 ``` -The basic configuration file currently contains configurations for just the `regression` and `diffusion` components of StormCast. Note any diffusion model you train will need a pretrained regression model to use, due to how StormCast is designed (you can refer to the paper for more details), thus there are two config items that must be defined to train a diffusion model: - 1. `regression_weights` -- The path to a checkpoint with model weights for the regression model. This file should be a pytorch checkpoint saved by your training script, with the `state_dict` for the regression network saved under the `net` key. - 2. `regression_config` -- the config name used to train this regression model +More extensive configuration modifications can be made by creating a new top-level configuration file similar to `regression` or `diffusion`. See `diffusion.yaml` for an example of how to specify a top-level config that uses default configuration settings with additional custom modifications added on top. -All configuration items related to the dataset are also contained in `config/config.yaml`, most importantly the location on the filesystem of the prepared HRRR/ERA5 Dataset (see [Dataset section](#dataset) for details). +Note any diffusion model you train will need a pretrained regression model to use, so there are two config items that must be defined to train a diffusion model: + 1. `model.use_regession_net = True` + 2. `model.regression_weights` set to the path of a Modulus (`.mdlus`) checkpoint with model weights for the regression model. These are saved in the checkpoints directory during training. -There is also a model registry `config/registry.json` which can be used to index different model versions to be used in inference/evaluation. For simplicity, there is just a single model version specified there currently, which matches the StormCast model used to generate results in the paper. +Once again, the reference `diffusion.yaml` top-level config shows an example of how to specify these settings. + +At runtime, hydra will parse the config subdirectory and command line over-rides into a runtime configuration object `cfg`, which will have all settings accessible via both attribute or dictionary-like interfaces. For example, the total training batch size can be accessed either as `cfg.training.batch_size` or `cfg['training']['batch_size']`. ### Training the regression model -To train the StormCast regression model, we use the default configuration file `config.yaml` and specify the `regression` config, along with the `--outdir` argument to choose where training logs and checkpoints should be saved. -We also can use command line options defined in `train.py` to specify other details, like a unique run ID to use for the experiment (`--run_id`). On a single GPU machine, for example, run: +To train the StormCast regression model, we simply specify the example `regression` config and an optional name for the training experiment. On a single GPU machine, for example, run: ```bash -python train.py --outdir rundir --config_file ./config/config.yaml --config_name regression --run_id 0 +python train.py --config-name regression training.experiment_name=regression +``` + +This will initialize training experiment and launch the main training loop, which is defined in `utils/trainer.py`. Outputs (training logs, checkpoints, etc.) will be saved to a directory specified by the following `training` config items: +```yaml +training.outdir: 'rundir' # Root path under which to save training outputs +training.experiment_name: 'stormcast' # Name for the training experiment +training.run_id: '0' # Unique ID to use for this training run +training.rundir: ./${training.outdir}/${training.experiment_name}/${training.run_id} # Path where experiement outputs will be saved ``` -This will initialize training experiment and launch the main training loop, which is defined in `utils/diffusions/training_loop.py`. +As you can see, the `training.run_id` setting can be used for distinguishing between different runs of the same configuration. The final training output directory is constructed by composing together the `training.outdir` root path (defaults to `rundir`), the `training.experiment_name`, and the `training.run_id`. ### Training the diffusion model -The method for launching a diffusion model training looks almost identical, and we just have to change the configuration name appropriately. However, since we need a pre-trained regression model for the diffusion model training, this config must define `regression_pickle` to point to a compatible pickle file with network weights for the regression model. Once that is taken care of, launching diffusion training looks nearly identical as previously: +The method for launching a diffusion model training looks almost identical, and we just have to change the configuration name appropriately. However, since we need a pre-trained regression model for the diffusion model training, the specified config must include the settings mentioned above in [Configuration Basics](#configuration-basics) to provide network weights for the regression model. With that, launching diffusion training looks something like: ```bash -python train.py --outdir rundir --config_file ./config/config.yaml --config_name diffusion --run_id 0 +python train.py --config-name diffusion training.experiment_name=diffusion ``` Note that the full training pipeline for StormCast is fairly lengthy, requiring about 120 hours on 64 NVIDIA H100 GPUs. However, more lightweight trainings can still produce decent models if the diffusion model is not trained for as long. -Both regression and diffusion training can be distributed easily with data parallelism via `torchrun`. One just needs to ensure the configuration being run has a large enough batch size to be distributed over the number of available GPUs/processes. The example `regression` and `diffusion` configs in `config/config.yaml` just use a batch size of 1 for simplicity, but new configs can be easily added [as described above](#configuration-basics). For example, distributed training over 8 GPUs on one node would look something like: +Both regression and diffusion training can be distributed easily with data parallelism via `torchrun` or other launchers (e.g., SLURM `srun`). One just needs to ensure the configuration being run has a large enough batch size to be distributed over the number of available GPUs/processes. The example `regression` and `diffusion` configs just use a batch size of 1 for simplicity, but new configs can be easily added [as described above](#configuration-basics). For example, distributed training over 8 GPUs on one node would look something like: ```bash -torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --outdir rundir --config_file ./config/config.yaml --config_name --run_id 0 +torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --config-name ``` -Once the training is completed, you can enter a new model into `config/registry.json` that points to the checkpoints (`.pt` file in your training output directory), and you are ready to run inference. - ### Inference -A simple demonstrative inference script is given in `inference.py`, which loads a pretrained model from a local directory named `stormcast_checkpoints`. -Yout should update this path to the checkpoints saved by your training runs that you want to run inference for. -The `inference.py` script will run a 12-hour forecast and save outputs as a `zarr` file along with a few plots saved as `png` files. +A simple demonstrative inference script is given in `inference.py`, which is also configured using hydra in a manner similar to training. The reference `stormcast_inference` config shows an example inference config, which looks largely the same as a training config except the output directory is now controlled by the settings from `inference` rather than `training` config: +```yaml +inference.outdir: 'rundir' # Root path under which to save inference outputs +inference.experiment_name: 'stormcast-inference' # Name for the inference experiment being run +inference.run_id: '0' # Unique identifier for the inference run +inference.rundir: ./${inference.outdir}/${inference.experiment_name}/${inference.run_id} # Path where experiment outputs will be saved +``` To run inference, simply do: - -```bash -python inference.py -``` -This inference script is configured by the contents of a model registry, which specifies config files and names to use for each of the diffusion and regression networks, along with other inference options which specify architecure types and a short description of the model. The `inference.py` script will automatically use the default file for the model registry (`config/registry.json`) and evaluate the `stormcast` example model, but you can configure it to run your desired inference case(s) with the following command-line options: ```bash - --outdir DIR Where to save the results - --registry_file FILE Path to model registry file - --model_name MODEL Name of model to evaluate from the registry +python inference.py --config-name ``` -We also recommend bringing your checkpoints to [earth2studio](https://github.com/NVIDIA/earth2studio) -for further anaylysis and visualizations. +This will load regression and diffusion models from directories specified by `inference.regression_checkpoint` and `inference.diffusion_checkpoint` respectively; each of these should be a path to a Modulus checkpoint (`.mdlus` file) from your training runs. The `inference.py` script will use these models to run a forecast and save outputs as a `zarr` file along with a few plots saved as `png` files. We also recommend bringing your checkpoints to [earth2studio](https://github.com/NVIDIA/earth2studio) +for further analysis and visualizations. ## Dataset @@ -133,7 +148,7 @@ A custom dataset object is defined in `utils/data_loader_hrrr_era5.py`, which lo ## Logging -These scripts use Weights & Biases for experiment tracking, which can be enabled by passing the `--log_to_wandb` argument to `train.py`. Academic accounts are free to create at [wandb.ai](https://wandb.ai/). +These scripts use Weights & Biases for experiment tracking, which can be enabled by setting `training.log_to_wandb=True`. Academic accounts are free to create at [wandb.ai](https://wandb.ai/). Once you have an account set up, you can adjust `entity` and `project` in `train.py` to the appropriate names for your `wandb` workspace. diff --git a/examples/generative/stormcast/config/dataset/hrrr_era5.yaml b/examples/generative/stormcast/config/dataset/hrrr_era5.yaml new file mode 100644 index 0000000000..7e5af370b5 --- /dev/null +++ b/examples/generative/stormcast/config/dataset/hrrr_era5.yaml @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Main dataset +location: 'data' # Path to the dataset +conus_dataset_name: 'hrrr_v3' # Version name for the dataset +hrrr_stats: 'stats_v3_2019_2021' # Summary stats name for the dataset + +# Domain +hrrr_img_size: [512, 640] # Image dimensions of the HRRR region of interest +boundary_padding_pixels: 0 # set this to 0 for no padding of ERA5 beyond HRRR domain, + # 32 for 32 pixels of padding in each direction, etc. + +# Temporal selection +dt: 1 # Timestep between samples (in multiples of the base HRRR 1hr timestep) +train_years: [2018, 2019, 2020, 2021] # Years to use for training +valid_years: [2022] # Years to use for validation + +# Variable selection +invariants: ["lsm", "orog"] # Invariant quantitites to include +input_channels: 'all' #'all' or list of channels to condition on +diffusion_channels: "all" #'all' or list of channels to condition on +exclude_channels: # Dataset channels to exclude from inputs/predicitons + - u35 + - u40 + - v35 + - v40 + - t35 + - t40 + - q35 + - q40 + - w1 + - w2 + - w3 + - w4 + - w5 + - w6 + - w7 + - w8 + - w9 + - w10 + - w11 + - w13 + - w15 + - w20 + - w25 + - w30 + - w35 + - w40 + - p25 + - p30 + - p35 + - p40 + - z35 + - z40 + - tcwv + - vil diff --git a/examples/generative/stormcast/config/diffusion.yaml b/examples/generative/stormcast/config/diffusion.yaml new file mode 100644 index 0000000000..23a0288557 --- /dev/null +++ b/examples/generative/stormcast/config/diffusion.yaml @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Defaults +defaults: + + # Dataset + - dataset/hrrr_era5 + + # Model + - model/stormcast + + # Training + - training/default + + # Sampler + - sampler/edm_deterministic + + # Hydra + - hydra/default + + - _self_ + +# Diffusion model specific changes +model: + use_regression_net: True + regression_weights: "stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus" + previous_step_conditioning: True + spatial_pos_embed: True + +training: + loss: 'edm' \ No newline at end of file diff --git a/examples/generative/stormcast/config/hydra/default.yaml b/examples/generative/stormcast/config/hydra/default.yaml new file mode 100644 index 0000000000..07d11b190f --- /dev/null +++ b/examples/generative/stormcast/config/hydra/default.yaml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +run: + dir: ${training.outdir}/${training.experiment_name}/${training.run_id} diff --git a/examples/generative/stormcast/config/inference/stormcast.yaml b/examples/generative/stormcast/config/inference/stormcast.yaml new file mode 100644 index 0000000000..861c7a331d --- /dev/null +++ b/examples/generative/stormcast/config/inference/stormcast.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# General inference config items +outdir: 'rundir' # Root path under which to save inference outputs +experiment_name: 'stormcast-inference' # Name for the inference experiment being run +run_id: 0 # Unique identifier for the inference run +rundir: ./${inference.outdir}/${inference.experiment_name}/${inference.run_id} # Path where experiement outputs will be saved +regression_checkpoint: stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus +diffusion_checkpoint: stormcast_checkpoints/diffusion/EDMPrecond.0.0.mdlus + +# Initial and lead times +initial_time: "2022-11-04T21:00:00" # datetime to intialize forecast with (YYYY-MM-DDTHH:MM:SS) + # note minimum time resolution of HRRR data is 1hr +n_steps: 12 # number of steps (in units of 1hr timesteps) to forecast + +# I/O +plot_var_hrrr: "refc" # HRRR variable to plot +plot_var_era5: "t2m" # ERA5 variable to plot +output_hrrr_channels: [] # HRRR variables to save to disk (empty list == all channels saved) \ No newline at end of file diff --git a/examples/generative/stormcast/config/model/stormcast.yaml b/examples/generative/stormcast/config/model/stormcast.yaml new file mode 100644 index 0000000000..c07d883907 --- /dev/null +++ b/examples/generative/stormcast/config/model/stormcast.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Model name +model_name: 'regression' # 'regression', 'diffusion', or custom model added by user + +# Model hyperparameters +previous_step_conditioning: False # Whether or not to condition on the previous step outputs as well (enable for diffusion) +spatial_pos_embed: False # Whether or not to add an additive position embed after the first conv in the U-Net +P_mean: -1.2 # Center of the EDM lognormal noise sampling distribution +attn_resolutions: [] # Internal resolutions within the U-Net to apply self-attention (empty list == no self-attention) + +# Pretrained regression model +use_regression_net: False # Whether or not to use a regression net as a first step (enable for diffusion training) +regression_weights: "stormcast_checkpoints/regression/UNet.0.0.mdlus" # Path to pretrained regression network, + # used if use_regression_net=True diff --git a/examples/generative/stormcast/config/regression.yaml b/examples/generative/stormcast/config/regression.yaml new file mode 100644 index 0000000000..c6afbf81a1 --- /dev/null +++ b/examples/generative/stormcast/config/regression.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Defaults +defaults: + + # Dataset + - dataset/hrrr_era5 + + # Model + - model/stormcast + + # Training + - training/default + + # Sampler + - sampler/edm_deterministic + + # Hydra + - hydra/default + + - _self_ diff --git a/examples/generative/stormcast/config/sampler/edm_deterministic.yaml b/examples/generative/stormcast/config/sampler/edm_deterministic.yaml new file mode 100644 index 0000000000..921e33fac9 --- /dev/null +++ b/examples/generative/stormcast/config/sampler/edm_deterministic.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Sampler args +# Below are passed as kwargs to modulus.utils.generative.determinisitic_sampler +# Also supports stochastic sampling via S_churn and related args. +# See EDM paper for details (https://arxiv.org/abs/2206.00364) + +name: 'EDM Deterministic' +args: + num_steps: 18 + sigma_min: 0.002 + sigma_max: 800 + rho: 7 + S_churn: 0. + S_min: 0. + S_max: .inf + S_noise: 1 \ No newline at end of file diff --git a/examples/generative/stormcast/config/stormcast_inference.yaml b/examples/generative/stormcast/config/stormcast_inference.yaml new file mode 100644 index 0000000000..b3d33ee78d --- /dev/null +++ b/examples/generative/stormcast/config/stormcast_inference.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Defaults +defaults: + + # Dataset + - dataset/hrrr_era5 + + # Model + - model/stormcast + + # Sampler + - sampler/edm_deterministic + + # Inference + - inference/stormcast + + - _self_ + +hydra: + run: + dir: ${inference.rundir} \ No newline at end of file diff --git a/examples/generative/stormcast/config/training/default.yaml b/examples/generative/stormcast/config/training/default.yaml new file mode 100644 index 0000000000..40d9eb946f --- /dev/null +++ b/examples/generative/stormcast/config/training/default.yaml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# General training config items +outdir: 'rundir' # Root path under which to save training outputs +experiment_name: 'stormcast-training' # Name for the training experiment +run_id: '0' # Unique ID to use for this training run +rundir: ./${training.outdir}/${training.experiment_name}/${training.run_id} # Path where experiement outputs will be saved +num_data_workers: 4 # Number of dataloader worker threads per proc +log_to_wandb: False # Whether or not to log to Weights & Biases (requires wandb account) +seed: -1 # Specify a random seed by setting this to an int > 0 +cudnn_benchmark: True # Enable/disable CuDNN benchmark mode +resume_checkpoint: null # Specify a path to a training checkpoint to resume from + +# Logging frequency +print_progress_freq: 5 # How often to print progress, measured in number of training steps +checkpoint_freq: 5 # How often to save the checkpoints, measured in number of training steps +validation_freq: 5 # how often to record the validation loss, measured in number of training steps + +# Optimization hyperparameters +batch_size: 1 # Total training batch size -- must be >= (and divisble by) number of GPUs being used +lr: 4E-4 # Initial learning rate +lr_rampup_steps: 1000 # Number of training steps over which to perform linear LR warmup +total_train_steps: 20 # Number of total training steps +clip_grad_norm: -1 # Threshold for gradient clipping, set to -1 to disable +loss: 'regression' # Loss type; use 'regression' or 'edm' for the regression and diffusion, respectively + diff --git a/examples/generative/stormcast/inference.py b/examples/generative/stormcast/inference.py index 7e961b8d1a..29366ad8d4 100644 --- a/examples/generative/stormcast/inference.py +++ b/examples/generative/stormcast/inference.py @@ -22,152 +22,68 @@ import xarray as xr import zarr import pandas as pd -import json -import argparse +import hydra from modulus.distributed import DistributedManager +from omegaconf import DictConfig +from modulus.models import Module -from utils.diffusions.run_edm import EDMRunner -from utils.diffusions.networks import get_preconditioned_architecture, EasyRegressionV2 -from utils.data_loader_hrrr_era5 import get_dataset -from utils.YParams import YParams - - -def main( - device: str = "cuda:0", - initial_time: datetime = datetime(2022, 11, 4, 21, 0), - plot_var_hrrr: str = "refc", - plot_var_era5: str = "t2m", - n_steps: int = 12, - output_hrrr_channels: list[str] = [], -): - parser = argparse.ArgumentParser(description="Run a StormCast inference") - parser.add_argument( - "--outdir", - help="Where to save the results", - metavar="DIR", - type=str, - default="./rundir", - ) - parser.add_argument( - "--registry_file", - help="Path to model registry file", - metavar="FILE", - type=str, - default="config/registry.json", - ) - parser.add_argument( - "--model_name", - help="Name of model to evaluate from the registry", - metavar="MODEL", - type=str, - default="stormcast", - ) - opts = parser.parse_args() +from utils.nn import regression_model_forward, diffusion_model_forward +from utils.data_loader_hrrr_era5 import HrrrEra5Dataset - DistributedManager.initialize() - # load model registry: - with open(opts.registry_file, "r") as f: - registry = json.load(f) - model_info = registry["models"][opts.model_name] +@hydra.main(version_base=None, config_path="config", config_name="stormcast_inference") +def main(cfg: DictConfig): - params = YParams( - model_info["regression_config_file"], model_info["regression_config_name"] - ) - params.local_batch_size = 1 - params.valid_years = [2022] - residual = params.residual - - edm_config = model_info["edm_config_name"] - edm_params = YParams(model_info["edm_config_file"], edm_config) - diffusion_channels = edm_params.diffusion_channels - input_channels = edm_params.input_channels - diffusion_path = model_info["edm_checkpoint_path"] - edm_runner = EDMRunner(edm_params, checkpoint_path=diffusion_path) - residual = edm_params.residual - - os.makedirs(opts.outdir, exist_ok=True) - - if params.boundary_padding_pixels > 0: - params.era5_img_size = ( - params.hrrr_img_size[0] + 2 * params.boundary_padding_pixels, - params.hrrr_img_size[1] + 2 * params.boundary_padding_pixels, - ) - else: - params.era5_img_size = params.hrrr_img_size + # Initialize + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + initial_time = datetime.fromisoformat(cfg.inference.initial_time) + n_steps = cfg.inference.n_steps - dataset = get_dataset(params, train=False) + # Dataset prep + dataset = HrrrEra5Dataset(cfg.dataset, train=False) - resolution = params.hrrr_img_size[0] - _, hrrr_channels = dataset._get_hrrr_channel_names() + base_hrrr_channels, hrrr_channels = dataset._get_hrrr_channel_names() diffusion_channels = ( hrrr_channels - if params.diffusion_channels == "all" - else params.diffusion_channels + if cfg.dataset.diffusion_channels == "all" + else cfg.dataset.diffusion_channels ) input_channels = ( - hrrr_channels if params.input_channels == "all" else params.input_channels + hrrr_channels + if cfg.dataset.input_channels == "all" + else cfg.dataset.input_channels ) - input_channel_indices = [hrrr_channels.index(channel) for channel in input_channels] + diffusion_channel_indices = [ hrrr_channels.index(channel) for channel in diffusion_channels ] - conditional_channels = ( - len(diffusion_channels) + len(params.invariants) + params.n_era5_channels - ) - net = get_preconditioned_architecture( - name="regression", - resolution=resolution, - target_channels=len(diffusion_channels), - conditional_channels=conditional_channels, - label_dim=0, - spatial_embedding=params.spatial_pos_embed, - attn_resolutions=params.attn_resolutions, - ) - - # Load pretrained regression model - regression_path = model_info["regression_checkpoint_path"] - chkpt = torch.load(regression_path, weights_only=True) - net.load_state_dict(chkpt["net"], strict=True) - model = EasyRegressionV2(net).to(device) + input_channel_indices = [ + list(hrrr_channels).index(channel) for channel in input_channels + ] hrrr_data = xr.open_zarr( - os.path.join(params.location, params.conus_dataset_name, "valid", "2021.zarr") + os.path.join( + cfg.dataset.location, cfg.dataset.conus_dataset_name, "valid", "2021.zarr" + ) ) - dataset_obj = get_dataset(params, train=False) - invariant_array = dataset._get_invariants() invariant_tensor = torch.from_numpy(invariant_array).to(device).repeat(1, 1, 1, 1) - model.set_invariant(invariant_tensor) - base_hrrr_channels, hrrr_channels = dataset_obj._get_hrrr_channel_names() - hrrr_channel_indices = [ - list(base_hrrr_channels).index(channel) for channel in hrrr_channels - ] - if len(output_hrrr_channels) == 0: - output_hrrr_channels = hrrr_channels.copy() - - diffusion_channels, input_channels = hrrr_channels, hrrr_channels - diffusion_channel_indices = [ - list(hrrr_channels).index(channel) for channel in diffusion_channels - ] - input_channel_indices = [ - list(hrrr_channels).index(channel) for channel in input_channels - ] + if len(cfg.inference.output_hrrr_channels) == 0: + output_hrrr_channels = diffusion_channels.copy() vardict: dict[str, int] = { hrrr_channel: i for i, hrrr_channel in enumerate(hrrr_channels) } - era5_data_path = os.path.join(params.location, "era5", "valid", "2021.zarr") - - era5_data = xr.open_zarr(era5_data_path) - - era5_channels = era5_data.channel.values - - vardict_era5 = {era5_channel: i for i, era5_channel in enumerate(era5_channels)} + vardict_era5 = { + era5_channel: i for i, era5_channel in enumerate(dataset.era5_channels.values) + } color_limits = { "u10m": (-5, 5), @@ -182,16 +98,22 @@ def main( (initial_time - datetime(initial_time.year, 1, 1, 0, 0)).total_seconds() / 3600 ) + hrrr_channel_indices = [ + list(base_hrrr_channels).index(channel) for channel in hrrr_channels + ] means_hrrr = dataset.means_hrrr[hrrr_channel_indices] stds_hrrr = dataset.stds_hrrr[hrrr_channel_indices] - means_era5 = dataset.means_era5 stds_era5 = dataset.stds_era5 + # Load pretrained models + net = Module.from_checkpoint(cfg.inference.regression_checkpoint) + regression_model = net.to(device) + net = Module.from_checkpoint(cfg.inference.diffusion_checkpoint) + diffusion_model = net.to(device) + # initialize zarr - zarr_output_path = os.path.join( - opts.outdir, initial_time.strftime("%Y-%m-%dT%H_%M_%S"), "data.zarr" - ) + zarr_output_path = os.path.join(cfg.inference.rundir, "data.zarr") group = zarr.open_group(zarr_output_path, mode="w") group.array("latitude", data=hrrr_data["latitude"].values) group.array("longitude", data=hrrr_data["longitude"].values) @@ -247,7 +169,9 @@ def main( break hrrr_0 = out - out = model(hrrr_0, boundary, mask=None) + out = regression_model_forward( + regression_model, hrrr_0, boundary, invariant_tensor + ) out_noedm = out.clone() hrrr_0 = torch.cat( ( @@ -256,15 +180,18 @@ def main( ), dim=1, ) - edm_corrected_outputs, _ = edm_runner.run(hrrr_0) - if residual: - out[0, diffusion_channel_indices] += edm_corrected_outputs[0].float() - else: - out[0, diffusion_channel_indices] = edm_corrected_outputs[0].float() + edm_corrected_outputs = diffusion_model_forward( + diffusion_model, + hrrr_0, + diffusion_channel_indices, + invariant_tensor, + sampler_args=dict(cfg.sampler.args), + ) + out[0, diffusion_channel_indices] += edm_corrected_outputs[0].float() out_edm = out.clone() boundary = data["era5"][0].cuda().float().unsqueeze(0) - varidx = vardict[plot_var_hrrr] + varidx = vardict[cfg.inference.plot_var_hrrr] fig, ax = plt.subplots(1, 4, figsize=(20, 5)) @@ -277,12 +204,12 @@ def main( error = pred - tar - if plot_var_hrrr in color_limits: + if cfg.inference.plot_var_hrrr in color_limits: im = ax[0].imshow( pred[0, varidx], origin="lower", cmap="magma", - clim=color_limits[plot_var_hrrr], + clim=color_limits[cfg.inference.plot_var_hrrr], ) else: im = ax[0].imshow(pred[0, varidx], origin="lower", cmap="magma") @@ -290,33 +217,35 @@ def main( fig.colorbar(im, ax=ax[0], fraction=0.046, pad=0.04) ax[0].set_title( "Predicted, {}, \n initial time {} \n lead_time {} hours".format( - plot_var_hrrr, initial_time, i + cfg.inference.plot_var_hrrr, initial_time, i ) ) - if plot_var_hrrr in color_limits: + if cfg.inference.plot_var_hrrr in color_limits: im = ax[1].imshow( tar[0, varidx], origin="lower", cmap="magma", - clim=color_limits[plot_var_hrrr], + clim=color_limits[cfg.inference.plot_var_hrrr], ) else: im = ax[1].imshow(tar[0, varidx], origin="lower", cmap="magma") fig.colorbar(im, ax=ax[1], fraction=0.046, pad=0.04) - ax[1].set_title("Actual, {}".format(plot_var_hrrr)) - if plot_var_era5 in color_limits: + ax[1].set_title("Actual, {}".format(cfg.inference.plot_var_hrrr)) + if cfg.inference.plot_var_era5 in color_limits: im = ax[2].imshow( - era5[0, vardict_era5[plot_var_era5]], + era5[0, vardict_era5[cfg.inference.plot_var_era5]], origin="lower", cmap="magma", - clim=color_limits[plot_var_era5], + clim=color_limits[cfg.inference.plot_var_era5], ) else: im = ax[2].imshow( - era5[0, vardict_era5[plot_var_era5]], origin="lower", cmap="magma" + era5[0, vardict_era5[cfg.inference.plot_var_era5]], + origin="lower", + cmap="magma", ) fig.colorbar(im, ax=ax[2], fraction=0.046, pad=0.04) - ax[2].set_title("ERA5, {}".format(plot_var_era5)) + ax[2].set_title("ERA5, {}".format(cfg.inference.plot_var_era5)) maxerror = np.max(np.abs(error[0, varidx])) im = ax[3].imshow( error[0, varidx], @@ -326,12 +255,9 @@ def main( vmin=-maxerror, ) fig.colorbar(im, ax=ax[3], fraction=0.046, pad=0.04) - ax[3].set_title("Error, {}".format(plot_var_hrrr)) - - plt.savefig(f"{opts.outdir}/out_{i}.png") + ax[3].set_title("Error, {}".format(cfg.inference.plot_var_hrrr)) - # create output name with config, edm_config, initial_time - edm_config = model_info["edm_config_name"] + plt.savefig(f"{cfg.inference.rundir}/out_{i}.png") level_names = [ "1", @@ -504,7 +430,7 @@ def convert_strings_to_ints(string_list): ds_targ["latitude"] = xr.DataArray(lats, dims=("y", "x")) ds_targ = ds_targ.assign_coords(levels=model_levels) - ds_out_path = os.path.join(opts.outdir, initial_time.strftime("%Y-%m-%dT%H_%M_%S")) + ds_out_path = cfg.inference.rundir ds_pred_edm.to_netcdf(f"{ds_out_path}/ds_pred_edm.nc", format="NETCDF4") ds_pred_noedm.to_netcdf(f"{ds_out_path}/ds_pred_noedm.nc", format="NETCDF4") diff --git a/examples/generative/stormcast/train.py b/examples/generative/stormcast/train.py index 9d4a66d698..aa455f9c9e 100644 --- a/examples/generative/stormcast/train.py +++ b/examples/generative/stormcast/train.py @@ -18,180 +18,74 @@ paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" import os -import re -import json +import hydra import torch import wandb import glob -import argparse +from omegaconf import DictConfig, OmegaConf from modulus.distributed import DistributedManager +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper -from utils.misc import EasyDict, print0 -from utils.diffusions import training_loop +from utils.trainer import training_loop -def main(**kwargs): +@hydra.main(version_base=None, config_path="config", config_name="regression") +def main(cfg: DictConfig) -> None: """Train regression or diffusion models for use in the StormCast (https://arxiv.org/abs/2408.10958) ML-based weather model""" - parser = argparse.ArgumentParser( - description="Train regression or diffusion models for use in StormCast" - ) - - # Main options. - parser.add_argument( - "--outdir", - help="Where to save the results", - metavar="DIR", - type=str, - required=True, - ) - parser.add_argument( - "--config_file", - help="Path to config file", - metavar="FILE", - type=str, - required=True, - ) - parser.add_argument( - "--config_name", - help="Name of config to use", - metavar="NAME", - type=str, - required=True, - ) - parser.add_argument( - "--log_to_wandb", help="Log to wandb", default=False, action="store_true" - ) - parser.add_argument( - "--run_id", help="run id", metavar="INT", type=int, default=None - ) - - # Performance-related. - parser.add_argument( - "--bench", - help="Enable cuDNN benchmarking", - metavar="BOOL", - type=bool, - default=True, - ) - - # I/O-related. - parser.add_argument( - "--desc", help="String to include in result dir name", metavar="STR", type=str - ) - parser.add_argument( - "--dump", help="How often to dump state", metavar="TICKS", type=int, default=10 - ) - parser.add_argument( - "--seed", help="Random seed [default: random]", metavar="INT", type=int - ) - parser.add_argument( - "--resume", help="Resume from previous training state", metavar="PT", type=str - ) - parser.add_argument( - "-n", "--dry-run", help="Print training options and exit", action="store_true" - ) - # Initialize - opts = parser.parse_args() DistributedManager.initialize() dist = DistributedManager() - - # Initialize config dict. - c = EasyDict() - - # Training options. - c.optimizer_kwargs = EasyDict(betas=[0.9, 0.999], eps=1e-8) - c.update(cudnn_benchmark=opts.bench, state_dump_ticks=opts.dump) + logger = PythonLogger("main") + logger0 = RankZeroLoggingWrapper(logger, dist) # Log only from rank 0 # Random seed. - if opts.seed is not None: - c.seed = opts.seed - else: + if cfg.training.seed < 0: seed = torch.randint(1 << 31, size=[], device=torch.device("cuda")) torch.distributed.broadcast(seed, src=0) - c.seed = int(seed) - - # Description string. - desc = f"hrrr-gpus{dist.world_size:d}" - if opts.desc is not None: - desc += f"-{opts.desc}" - - desc = opts.config_name + "-" + desc + cfg.training.seed = int(seed) - # Pick output directory. - cur_run_id = opts.run_id if opts.run_id is not None else 0 - c.run_dir = os.path.join(opts.outdir, f"{cur_run_id}-{desc}") + # Resume from specified checkpoint, if provided + if cfg.training.resume_checkpoint is not None: + resume = cfg.training.resume_checkpoint + if not os.path.isfile(resume) or not resume.endswith(".pt"): + raise ValueError( + "training.resume_checkpoint must point to a modulus .pt checkpoint from a previous training run" + ) - # if run_dir exists, then resume training - if os.path.exists(c.run_dir): + # If run directory already exists, then resume training from last checkpoint + wandb_resume = False + if os.path.exists(cfg.training.rundir): training_states = sorted( - glob.glob(os.path.join(c.run_dir, "training-state-*.pt")) + glob.glob(os.path.join(cfg.training.rundir, "checkpoints/checkpoint*.pt")) ) if training_states: - print0("Resuming training from previous run_dir: " + c.run_dir) - last_training_state = sorted( - glob.glob(os.path.join(c.run_dir, "training-state-*.pt")) - )[-1] - last_kimg = int( - re.fullmatch( - r"training-state-(\d+).pt", os.path.basename(last_training_state) - ).group(1) + logger0.info( + "Resuming training from previous run_dir: " + cfg.training.rundir ) - c.resume_kimg = last_kimg - c.resume_state_dump = last_training_state - print0( - "Resuming training from previous training-state-*.pt file: " + last_training_state = training_states[-1] + cfg.training.resume_checkpoint = last_training_state + logger0.info( + "Resuming training from previous checkpoint file: " + last_training_state ) - - if opts.resume is not None: - match = re.fullmatch(r"training-state-(\d+).pt", os.path.basename(opts.resume)) - if not match or not os.path.isfile(opts.resume): - raise ValueError( - "--resume must point to training-state-*.pt from a previous training run" - ) - c.resume_kimg = int(match.group(1)) - c.resume_state_dump = opts.resume - - # Print options. - if opts.dry_run: - print0("Dry run; exiting.") - return - - # Create output directory. - print0("Creating output directory...") - if dist.rank == 0: - os.makedirs(c.run_dir, exist_ok=True) - with open(os.path.join(c.run_dir, "training_options.json"), "wt") as f: - json.dump(c, f, indent=2) - # Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) - - if opts.log_to_wandb: - entity, project = "nv-research-climate", "hrrr" - entity = entity - wandb_project = project - wandb_name = opts.config_name + "_" + desc - wandb_group = opts.config_name + "_" + str(cur_run_id) - os.makedirs(os.path.join(c.run_dir, "wandb"), exist_ok=True) - wandb.init( - dir=os.path.join(c.run_dir, "wandb"), - config=c, - name=wandb_name, - group=wandb_group, - project=wandb_project, - entity=entity, - resume=opts.resume, - mode="online", - ) - - # config options - c.config_file = opts.config_file - c.config_name = opts.config_name - c.log_to_wandb = opts.log_to_wandb + wandb_resume = True + + # Setup wandb, if enabled + if dist.rank == 0 and cfg.training.log_to_wandb: + entity, project = "wandb_entity", "wandb_project" + wandb.init( + dir=cfg.training.rundir, + config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), + name=os.path.basename(cfg.training.rundir), + project=project, + entity=entity, + resume=wandb_resume, + mode="online", + ) # Train. - training_loop.training_loop(**c) + training_loop(cfg) # ---------------------------------------------------------------------------- diff --git a/examples/generative/stormcast/utils/YParams.py b/examples/generative/stormcast/utils/YParams.py deleted file mode 100644 index f7c504054d..0000000000 --- a/examples/generative/stormcast/utils/YParams.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ruamel.yaml import YAML -import logging - - -class YParams: - """Yaml file parser""" - - def __init__(self, yaml_filename, config_name, print_params=False): - self._yaml_filename = yaml_filename - self._config_name = config_name - self.params = {} - - if print_params: - print("------------------ Configuration ------------------") - - with open(yaml_filename) as _file: - - for key, val in YAML().load(_file)[config_name].items(): - if print_params: - print(key, val) - if val == "None": - val = None - - self.params[key] = val - self.__setattr__(key, val) - - if print_params: - print("---------------------------------------------------") - - # override setattr now so both the dict and the attrs get updated - self.__setattr__ = self.__custom_setattr__ - - def __custom_setattr__(self, key, val): - self.params[key] = val - super().__setattr__(key, val) - - def __getitem__(self, key): - return self.params[key] - - def __setitem__(self, key, val): - self.params[key] = val - self.__setattr__(key, val) - - def __contains__(self, key): - return key in self.params - - def update_params(self, config): - """Update the params according to configuraiton dict config""" - for key, val in config.items(): - self.params[key] = val - self.__setattr__(key, val) - - def log(self): - logging.info("------------------ Configuration ------------------") - logging.info("Configuration file: " + str(self._yaml_filename)) - logging.info("Configuration name: " + str(self._config_name)) - for key, val in self.params.items(): - logging.info(str(key) + " " + str(val)) - logging.info("---------------------------------------------------") diff --git a/examples/generative/stormcast/utils/data_loader_hrrr_era5.py b/examples/generative/stormcast/utils/data_loader_hrrr_era5.py index e245318c69..daf32005a5 100644 --- a/examples/generative/stormcast/utils/data_loader_hrrr_era5.py +++ b/examples/generative/stormcast/utils/data_loader_hrrr_era5.py @@ -15,31 +15,25 @@ # limitations under the License. import os -import logging import glob import torch import numpy as np from torch.utils.data import Dataset -from utils.misc import print0 +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper +from modulus.distributed import DistributedManager from datetime import datetime, timedelta import dask import xarray as xr -def worker_init(wrk_id): - np.random.seed(torch.utils.data.get_worker_info().seed % (2**32 - 1)) +logger = PythonLogger("dataset") -def get_dataset(params, train): - """Return the train or validation dataset specified by params""" - return HrrrEra5DatasetForecast( - params, - train=train, - location=params.location, - ) +def worker_init(wrk_id): + np.random.seed(torch.utils.data.get_worker_info().seed % (2**32 - 1)) -class HrrrEra5DatasetForecast(Dataset): +class HrrrEra5Dataset(Dataset): """ Paired dataset object serving time-synchronized pairs of ERA5 and HRRR samples Expects data to be stored under directory specified by 'location' with the @@ -70,12 +64,16 @@ class HrrrEra5DatasetForecast(Dataset): year containing the data of interest. """ - def __init__(self, params, train, location: str): + def __init__(self, params, train): + + dist = DistributedManager() + self.logger0 = RankZeroLoggingWrapper(logger, dist) + dask.config.set( scheduler="synchronous" ) # for threadsafe multiworker dataloaders self.params = params - self.location = location + self.location = self.params.location self.train = train self.path_suffix = "train" if train else "valid" self.dt = params.dt @@ -147,7 +145,7 @@ def _get_files_stats(self): self.era5_paths, key=lambda x: int(os.path.basename(x).replace(".zarr", "")) ) - print0("list of all era5 paths: ", self.era5_paths) + self.logger0.info(f"list of all era5 paths: {self.era5_paths}") if self.train: # keep only years specified in the params.train_years list @@ -172,7 +170,7 @@ def _get_files_stats(self): int(os.path.basename(x).replace(".zarr", "")) for x in self.era5_paths ] - print0("list of all era5 paths after filtering: ", self.era5_paths) + self.logger0.info(f"list of all era5 paths after filtering: {self.era5_paths}") self.n_years = len(self.era5_paths) with xr.open_zarr(self.era5_paths[0], consolidated=True) as ds: @@ -191,7 +189,7 @@ def _get_files_stats(self): os.path.join(self.location, self.conus_dataset_name, "**", "????.zarr"), recursive=True, ) - print0("list of all hrrr paths: ", self.hrrr_paths) + self.logger0.info(f"list of all hrrr paths: {self.hrrr_paths}") self.hrrr_paths = sorted( self.hrrr_paths, key=lambda x: int(os.path.basename(x).replace(".zarr", "")) ) @@ -218,11 +216,11 @@ def _get_files_stats(self): int(os.path.basename(x).replace(".zarr", "")) for x in self.hrrr_paths ] - print0("list of all hrrr paths after filtering: ", self.hrrr_paths) + self.logger0.info(f"list of all hrrr paths after filtering: {self.hrrr_paths}") years = [int(os.path.basename(x).replace(".zarr", "")) for x in self.hrrr_paths] - print0("years: ", years) - print0("self.years: ", self.years) + self.logger0.info(f"years: {years}") + self.logger0.info(f"self.years: {self.years}") assert ( years == self.years ), "Number of years for ERA5 in %s and HRRR in %s must match" % ( @@ -273,12 +271,12 @@ def compute_total_samples(self): first_sample = datetime( year=first_year, month=8, day=1, hour=1, minute=0, second=0 ) # marks transition of hrrr model version - logging.info("First sample is {}".format(first_sample)) + self.logger0.info("First sample is {}".format(first_sample)) else: first_sample = datetime( year=first_year, month=1, day=1, hour=0, minute=0, second=0 ) - logging.info("First sample is {}".format(first_sample)) + self.logger0.info("First sample is {}".format(first_sample)) last_sample = datetime( year=last_year, month=12, day=31, hour=23, minute=0, second=0 @@ -287,7 +285,7 @@ def compute_total_samples(self): last_sample = datetime( year=last_year, month=12, day=15, hour=0, minute=0, second=0 ) - logging.info("Last sample is {}".format(last_sample)) + self.logger0.info("Last sample is {}".format(last_sample)) all_datetimes = [ first_sample + timedelta(hours=x) for x in range(int((last_sample - first_sample).total_seconds() / 3600) + 1) @@ -310,7 +308,7 @@ def compute_total_samples(self): and (x + timedelta(hours=self.dt) not in missing_samples) ] - logging.info( + self.logger0.info( "Total datetimes in training set are {} of which {} are valid".format( len(all_datetimes), len(self.valid_samples) ) @@ -414,7 +412,7 @@ def _construct_era5_window(self): Build custom indexing window to subselect HRRR region from ERA5 """ - logging.info( + self.logger0.info( "Constructing ERA5 window, extending HRRR domain by {} pixels in each direction".format( self.boundary_padding_pixels ) @@ -459,7 +457,7 @@ def _construct_era5_window(self): added_pixels_x = new_x.shape[1] - self.hrrr_lon.shape[1] added_pixels_y = new_x.shape[0] - self.hrrr_lon.shape[0] - logging.info( + self.logger0.info( "Added {} pixels in x, {} pixels in y".format( added_pixels_x, added_pixels_y ) diff --git a/examples/generative/stormcast/utils/diffusions/generate.py b/examples/generative/stormcast/utils/diffusions/generate.py deleted file mode 100644 index 048840991b..0000000000 --- a/examples/generative/stormcast/utils/diffusions/generate.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Generate random images using the techniques described in the paper -"Elucidating the Design Space of Diffusion-Based Generative Models".""" - -import numpy as np -import torch - -# ---------------------------------------------------------------------------- -# Proposed EDM sampler (Algorithm 2). - - -def edm_sampler( - net, - latents, - condition=None, - class_labels=None, - randn_like=torch.randn_like, - num_steps=18, - sigma_min=0.002, - sigma_max=800, - rho=7, - S_churn=0, - S_min=0, - S_max=float("inf"), - S_noise=1, - **kwargs, -): - # Adjust noise levels based on what's supported by the network. - sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) - - # Time step discretization. - step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) - t_steps = ( - sigma_max ** (1 / rho) - + step_indices - / (num_steps - 1) - * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) - ) ** rho - t_steps = torch.cat( - [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] - ) # t_N = 0 - - # Main sampling loop. - x_next = latents.to(torch.float64) * t_steps[0] - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 - x_cur = x_next - - # Increase noise temporarily. - gamma = ( - min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 - ) - t_hat = net.round_sigma(t_cur + gamma * t_cur) - x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - - # Euler step. - denoised = net(x_hat, t_hat, class_labels=class_labels, condition=condition).to( - torch.float64 - ) - d_cur = (x_hat - denoised) / t_hat - x_next = x_hat + (t_next - t_hat) * d_cur - - # Apply 2nd order correction. - if i < num_steps - 1: - denoised = net( - x_next, t_next, class_labels=class_labels, condition=condition - ).to(torch.float64) - d_prime = (x_next - denoised) / t_next - x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) - - return x_next diff --git a/examples/generative/stormcast/utils/diffusions/losses.py b/examples/generative/stormcast/utils/diffusions/losses.py deleted file mode 100644 index 777b998bc5..0000000000 --- a/examples/generative/stormcast/utils/diffusions/losses.py +++ /dev/null @@ -1,100 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from utils.diffusions import networks - - -class EDMLoss: - """Improved loss function proposed in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).""" - - def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data - print("using P_mean value of {}".format(P_mean)) - - def __call__( - self, - net: networks.EDMPrecond, - x, - condition=None, - class_labels=None, - augment_pipe=None, - ): - """ - Args: - net: - x: The latent data (to be denoised). shape [batch_size, target_channels, w, h] - class_labels: optional, shape [batch_size, label_dim] - condition: optional, the conditional inputs, - shape=[batch_size, condition_channel, w, h] - Returns: - out: loss function with shape [batch_size, target_channels, w, h] - This should be averaged to get the mean loss for gradient descent. - """ - rnd_normal = torch.randn([x.shape[0], 1, 1, 1], device=x.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - y, augment_labels = augment_pipe(x) if augment_pipe is not None else (x, None) - n = torch.randn_like(y) * sigma - D_yn = net( - y + n, - sigma, - condition=condition, - class_labels=class_labels, - augment_labels=augment_labels, - ) - loss = weight * ((D_yn - y) ** 2) - return loss - - -class RegressionLossV2: - """Loss wrapper for training the StormCast regression model, so that it has a similar call signature as - the EDMLoss and the same training loop can be used to train both regression and diffusion models""" - - def __call__( - self, - net: networks.RegressionWrapperV2, - x, - condition=None, - class_labels=None, - augment_pipe=None, - ): - """ - Args: - net: - x: The latent data (to be denoised). shape [batch_size, target_channels, w, h] - class_labels: optional, shape [batch_size, label_dim] - condition: optional, the conditional inputs, - shape=[batch_size, condition_channel, w, h] - Returns: - out: loss function with shape [batch_size, target_channels, w, h] - This should be averaged to get the mean loss for gradient descent. - """ - - sigma = torch.ones([x.shape[0], 1, 1, 1], device=x.device) - weight = 1.0 # (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - y, augment_labels = augment_pipe(x) if augment_pipe is not None else (x, None) - - D_yn = net( - sigma, - condition=condition, - class_labels=class_labels, - augment_labels=augment_labels, - ) - loss = weight * ((D_yn - y) ** 2) - return loss diff --git a/examples/generative/stormcast/utils/diffusions/networks.py b/examples/generative/stormcast/utils/diffusions/networks.py deleted file mode 100644 index fcd48fcc16..0000000000 --- a/examples/generative/stormcast/utils/diffusions/networks.py +++ /dev/null @@ -1,1043 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Model architectures and preconditioning schemes used in the paper -# "Elucidating the Design Space of Diffusion-Based Generative Models". - -import numpy as np -import torch -from torch.nn.functional import silu -import einops - -# ---------------------------------------------------------------------------- -# Unified routine for initializing weights and biases. - - -def weight_init(shape, mode, fan_in, fan_out): - if mode == "xavier_uniform": - return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) - if mode == "xavier_normal": - return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) - if mode == "kaiming_uniform": - return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) - if mode == "kaiming_normal": - return np.sqrt(1 / fan_in) * torch.randn(*shape) - raise ValueError(f'Invalid init mode "{mode}"') - - -class Linear(torch.nn.Module): - """Fully-connected layer""" - - def __init__( - self, - in_features, - out_features, - bias=True, - init_mode="kaiming_normal", - init_weight=1, - init_bias=0, - ): - super().__init__() - self.in_features = in_features - self.out_features = out_features - init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) - self.weight = torch.nn.Parameter( - weight_init([out_features, in_features], **init_kwargs) * init_weight - ) - self.bias = ( - torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) - if bias - else None - ) - - def forward(self, x): - x = x @ self.weight.to(x.dtype).t() - if self.bias is not None: - x = x.add_(self.bias.to(x.dtype)) - return x - - -class Conv2d(torch.nn.Module): - """Convolutional layer with optional up/downsampling""" - - def __init__( - self, - in_channels, - out_channels, - kernel, - bias=True, - up=False, - down=False, - resample_filter=[1, 1], - fused_resample=False, - init_mode="kaiming_normal", - init_weight=1, - init_bias=0, - ): - assert not (up and down) - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.up = up - self.down = down - self.fused_resample = fused_resample - init_kwargs = dict( - mode=init_mode, - fan_in=in_channels * kernel * kernel, - fan_out=out_channels * kernel * kernel, - ) - self.weight = ( - torch.nn.Parameter( - weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) - * init_weight - ) - if kernel - else None - ) - self.bias = ( - torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) - if kernel and bias - else None - ) - f = torch.as_tensor(resample_filter, dtype=torch.float32) - f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() - self.register_buffer("resample_filter", f if up or down else None) - - def forward(self, x): - w = self.weight.to(x.dtype) if self.weight is not None else None - b = self.bias.to(x.dtype) if self.bias is not None else None - f = ( - self.resample_filter.to(x.dtype) - if self.resample_filter is not None - else None - ) - w_pad = w.shape[-1] // 2 if w is not None else 0 - f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 - - if self.fused_resample and self.up and w is not None: - x = torch.nn.functional.conv_transpose2d( - x, - f.mul(4).tile([self.in_channels, 1, 1, 1]), - groups=self.in_channels, - stride=2, - padding=max(f_pad - w_pad, 0), - ) - x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) - elif self.fused_resample and self.down and w is not None: - x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) - x = torch.nn.functional.conv2d( - x, - f.tile([self.out_channels, 1, 1, 1]), - groups=self.out_channels, - stride=2, - ) - else: - if self.up: - x = torch.nn.functional.conv_transpose2d( - x, - f.mul(4).tile([self.in_channels, 1, 1, 1]), - groups=self.in_channels, - stride=2, - padding=f_pad, - ) - if self.down: - x = torch.nn.functional.conv2d( - x, - f.tile([self.in_channels, 1, 1, 1]), - groups=self.in_channels, - stride=2, - padding=f_pad, - ) - if w is not None: - x = torch.nn.functional.conv2d(x, w, padding=w_pad) - if b is not None: - x = x.add_(b.reshape(1, -1, 1, 1)) - return x - - -class GroupNorm(torch.nn.Module): - "Group Normalization layer" - - def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): - super().__init__() - self.num_groups = min(num_groups, num_channels // min_channels_per_group) - self.eps = eps - self.weight = torch.nn.Parameter(torch.ones(num_channels)) - self.bias = torch.nn.Parameter(torch.zeros(num_channels)) - - def forward(self, x): - x = torch.nn.functional.group_norm( - x, - num_groups=self.num_groups, - weight=self.weight.to(x.dtype), - bias=self.bias.to(x.dtype), - eps=self.eps, - ) - return x - - -class AttentionOp(torch.autograd.Function): - """Attention weight computation, i.e., softmax(Q^T * K). - Performs all computation using FP32, but uses the original datatype for - inputs/outputs/gradients to conserve memory.""" - - @staticmethod - def forward(ctx, q, k): - w = ( - torch.einsum( - "ncq,nck->nqk", - q.to(torch.float32), - (k / np.sqrt(k.shape[1])).to(torch.float32), - ) - .softmax(dim=2) - .to(q.dtype) - ) - ctx.save_for_backward(q, k, w) - return w - - @staticmethod - def backward(ctx, dw): - q, k, w = ctx.saved_tensors - db = torch._softmax_backward_data( - grad_output=dw.to(torch.float32), - output=w.to(torch.float32), - dim=2, - input_dtype=torch.float32, - ) - dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( - q.dtype - ) / np.sqrt(k.shape[1]) - dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( - k.dtype - ) / np.sqrt(k.shape[1]) - return dq, dk - - -# ---------------------------------------------------------------------------- -# - - -class UNetBlock(torch.nn.Module): - """Unified U-Net block with optional up/downsampling and self-attention. - Represents the union of all features employed by the DDPM++, NCSN++, and - ADM architectures.""" - - def __init__( - self, - in_channels, - out_channels, - emb_channels, - up=False, - down=False, - attention=False, - num_heads=None, - channels_per_head=64, - dropout=0, - skip_scale=1, - eps=1e-5, - resample_filter=[1, 1], - resample_proj=False, - adaptive_scale=True, - init=dict(), - init_zero=dict(init_weight=0), - init_attn=None, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.emb_channels = emb_channels - self.num_heads = ( - 0 - if not attention - else num_heads - if num_heads is not None - else out_channels // channels_per_head - ) - self.dropout = dropout - self.skip_scale = skip_scale - self.adaptive_scale = adaptive_scale - - self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) - self.conv0 = Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel=3, - up=up, - down=down, - resample_filter=resample_filter, - **init, - ) - self.affine = Linear( - in_features=emb_channels, - out_features=out_channels * (2 if adaptive_scale else 1), - **init, - ) - self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) - self.conv1 = Conv2d( - in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero - ) - - self.skip = None - if out_channels != in_channels or up or down: - kernel = 1 if resample_proj or out_channels != in_channels else 0 - self.skip = Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel=kernel, - up=up, - down=down, - resample_filter=resample_filter, - **init, - ) - - if self.num_heads: - self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) - self.qkv = Conv2d( - in_channels=out_channels, - out_channels=out_channels * 3, - kernel=1, - **(init_attn if init_attn is not None else init), - ) - self.proj = Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel=1, - **init_zero, - ) - - def forward(self, x, emb): - orig = x - x = self.conv0(silu(self.norm0(x))) - - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) - if self.adaptive_scale: - scale, shift = params.chunk(chunks=2, dim=1) - x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) - else: - x = silu(self.norm1(x.add_(params))) - - x = self.conv1( - torch.nn.functional.dropout(x, p=self.dropout, training=self.training) - ) - x = x.add_(self.skip(orig) if self.skip is not None else orig) - x = x * self.skip_scale - - if self.num_heads: - q, k, v = ( - self.qkv(self.norm2(x)) - .reshape( - x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 - ) - .unbind(2) - ) - w = AttentionOp.apply(q, k) - a = torch.einsum("nqk,nck->ncq", w, v) - x = self.proj(a.reshape(*x.shape)).add_(x) - x = x * self.skip_scale - return x - - -class PositionalEmbedding(torch.nn.Module): - """Timestep embedding used in the DDPM++ and ADM architectures.""" - - def __init__(self, num_channels, max_positions=10000, endpoint=False): - super().__init__() - self.num_channels = num_channels - self.max_positions = max_positions - self.endpoint = endpoint - - def forward(self, x): - freqs = torch.arange( - start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device - ) - freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) - freqs = (1 / self.max_positions) ** freqs - x = x.ger(freqs.to(x.dtype)) - x = torch.cat([x.cos(), x.sin()], dim=1) - return x - - -class FourierEmbedding(torch.nn.Module): - """Timestep embedding used in the NCSN++ architecture.""" - - def __init__(self, num_channels, scale=16): - super().__init__() - self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) - - def forward(self, x): - x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) - x = torch.cat([x.cos(), x.sin()], dim=1) - return x - - -class SongUNetRegression(torch.nn.Module): - """U-Net implementation of the ADM architecture from the paper "Diffusion Models Beat GANS on Image Synthesis", - modified to be used as the regression model in StormCast (removes the diffusion time embedding ops in the network) - """ - - def __init__( - self, - img_resolution, # Image resolution at input/output. - in_channels, # Number of color channels at input. - out_channels, # Number of color channels at output. - label_dim=0, # Number of class labels, 0 = unconditional. - augment_dim=0, # Augmentation label dimensionality, 0 = no augmentation. - model_channels=128, # Base multiplier for the number of channels. - channel_mult=[ - 1, - 2, - 2, - 2, - ], # Per-resolution multipliers for the number of channels. - channel_mult_emb=4, # Multiplier for the dimensionality of the embedding vector. - num_blocks=4, # Number of residual blocks per resolution. - attn_resolutions=[16], # List of resolutions with self-attention. - dropout=0.10, # Dropout probability of intermediate activations. - label_dropout=0, # Dropout probability of class labels for classifier-free guidance. - embedding_type="zero", # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. - channel_mult_noise=1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++. - encoder_type="standard", # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. - decoder_type="standard", # Decoder architecture: 'standard' for both DDPM++ and NCSN++. - resample_filter=[ - 1, - 1, - ], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - spatial_embedding=None, # None or 'add' or 'concat' - hrrr_resolution=(512, 640), - ): - assert embedding_type in ["zero"] - assert encoder_type in ["standard", "skip", "residual"] - assert decoder_type in ["standard", "skip"] - - super().__init__() - self.label_dropout = label_dropout - self.emb_channels = model_channels * channel_mult_emb - noise_channels = model_channels * channel_mult_noise - init = dict(init_mode="xavier_uniform") - init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) - init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) - block_kwargs = dict( - emb_channels=self.emb_channels, - num_heads=1, - dropout=dropout, - skip_scale=np.sqrt(0.5), - eps=1e-6, - resample_filter=resample_filter, - resample_proj=True, - adaptive_scale=False, - init=init, - init_zero=init_zero, - init_attn=init_attn, - ) - - # Encoder. - self.enc = torch.nn.ModuleDict() - cout = in_channels - caux = in_channels - for level, mult in enumerate(channel_mult): - res = img_resolution >> level - if level == 0: - cin = cout - cout = model_channels - self.enc[f"{res}x{res}_conv"] = Conv2d( - in_channels=cin, out_channels=cout, kernel=3, **init - ) - else: - self.enc[f"{res}x{res}_down"] = UNetBlock( - in_channels=cout, out_channels=cout, down=True, **block_kwargs - ) - if encoder_type == "skip": - self.enc[f"{res}x{res}_aux_down"] = Conv2d( - in_channels=caux, - out_channels=caux, - kernel=0, - down=True, - resample_filter=resample_filter, - ) - self.enc[f"{res}x{res}_aux_skip"] = Conv2d( - in_channels=caux, out_channels=cout, kernel=1, **init - ) - if encoder_type == "residual": - self.enc[f"{res}x{res}_aux_residual"] = Conv2d( - in_channels=caux, - out_channels=cout, - kernel=3, - down=True, - resample_filter=resample_filter, - fused_resample=True, - **init, - ) - caux = cout - for idx in range(num_blocks): - cin = cout - cout = model_channels * mult - attn = res in attn_resolutions - self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( - in_channels=cin, out_channels=cout, attention=attn, **block_kwargs - ) - skips = [ - block.out_channels for name, block in self.enc.items() if "aux" not in name - ] - - # Decoder. - self.dec = torch.nn.ModuleDict() - for level, mult in reversed(list(enumerate(channel_mult))): - res = img_resolution >> level - if level == len(channel_mult) - 1: - self.dec[f"{res}x{res}_in0"] = UNetBlock( - in_channels=cout, out_channels=cout, attention=True, **block_kwargs - ) - self.dec[f"{res}x{res}_in1"] = UNetBlock( - in_channels=cout, out_channels=cout, **block_kwargs - ) - else: - self.dec[f"{res}x{res}_up"] = UNetBlock( - in_channels=cout, out_channels=cout, up=True, **block_kwargs - ) - for idx in range(num_blocks + 1): - cin = cout + skips.pop() - cout = model_channels * mult - attn = idx == num_blocks and res in attn_resolutions - self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( - in_channels=cin, out_channels=cout, attention=attn, **block_kwargs - ) - if decoder_type == "skip" or level == 0: - if decoder_type == "skip" and level < len(channel_mult) - 1: - self.dec[f"{res}x{res}_aux_up"] = Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel=0, - up=True, - resample_filter=resample_filter, - ) - self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( - num_channels=cout, eps=1e-6 - ) - self.dec[f"{res}x{res}_aux_conv"] = Conv2d( - in_channels=cout, out_channels=out_channels, kernel=3, **init_zero - ) - - def forward(self, x, noise_labels, class_labels, augment_labels=None): - - emb = torch.zeros((noise_labels.shape[0], self.emb_channels), device=x.device) - - # Encoder. - skips = [] - aux = x - for name, block in self.enc.items(): - if "aux_down" in name: - aux = block(aux) - elif "aux_skip" in name: - x = skips[-1] = x + block(aux) - elif "aux_residual" in name: - x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) - else: - x = block(x, emb) if isinstance(block, UNetBlock) else block(x) - skips.append(x) - - # Decoder. - aux = None - tmp = None - for name, block in self.dec.items(): - if "aux_up" in name: - aux = block(aux) - elif "aux_norm" in name: - tmp = block(x) - elif "aux_conv" in name: - tmp = block(silu(tmp)) - aux = tmp if aux is None else tmp + aux - else: - if x.shape[1] != block.in_channels: - x = torch.cat([x, skips.pop()], dim=1) - x = block(x, emb) - - return aux - - -class SongUNet(torch.nn.Module): - """Reimplementation of the ADM architecture from the paper "Diffusion Models Beat GANS on Image Synthesis". - Equivalent to the original implementation by Dhariwal and Nichol, available at https://github.com/openai/guided-diffusion - Modified to include a learned additive position embedding at the input to the U-Net - """ - - def __init__( - self, - img_resolution, # Image resolution at input/output. - in_channels, # Number of color channels at input. - out_channels, # Number of color channels at output. - label_dim=0, # Number of class labels, 0 = unconditional. - augment_dim=0, # Augmentation label dimensionality, 0 = no augmentation. - model_channels=128, # Base multiplier for the number of channels. - channel_mult=[ - 1, - 2, - 2, - 2, - ], # Per-resolution multipliers for the number of channels. - channel_mult_emb=4, # Multiplier for the dimensionality of the embedding vector. - num_blocks=4, # Number of residual blocks per resolution. - attn_resolutions=[16], # List of resolutions with self-attention. - dropout=0.10, # Dropout probability of intermediate activations. - label_dropout=0, # Dropout probability of class labels for classifier-free guidance. - embedding_type="positional", # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. - channel_mult_noise=1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++. - encoder_type="standard", # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. - decoder_type="standard", # Decoder architecture: 'standard' for both DDPM++ and NCSN++. - resample_filter=[ - 1, - 1, - ], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - spatial_embedding=False, # True or False - hrrr_resolution=(512, 640), - ): - assert embedding_type in ["fourier", "positional"] - assert encoder_type in ["standard", "skip", "residual"] - assert decoder_type in ["standard", "skip"] - - super().__init__() - self.label_dropout = label_dropout - self.spatial_embedding = spatial_embedding - - if spatial_embedding: - self.spatial_emb = torch.nn.Parameter( - torch.randn(1, model_channels, *hrrr_resolution) - ) - torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) - - emb_channels = model_channels * channel_mult_emb - noise_channels = model_channels * channel_mult_noise - init = dict(init_mode="xavier_uniform") - init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) - init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) - block_kwargs = dict( - emb_channels=emb_channels, - num_heads=1, - dropout=dropout, - skip_scale=np.sqrt(0.5), - eps=1e-6, - resample_filter=resample_filter, - resample_proj=True, - adaptive_scale=False, - init=init, - init_zero=init_zero, - init_attn=init_attn, - ) - - # Mapping. - self.map_noise = ( - PositionalEmbedding(num_channels=noise_channels, endpoint=True) - if embedding_type == "positional" - else FourierEmbedding(num_channels=noise_channels) - ) - self.map_label = ( - Linear(in_features=label_dim, out_features=noise_channels, **init) - if label_dim - else None - ) - self.map_augment = ( - Linear( - in_features=augment_dim, out_features=noise_channels, bias=False, **init - ) - if augment_dim - else None - ) - self.map_layer0 = Linear( - in_features=noise_channels, out_features=emb_channels, **init - ) - self.map_layer1 = Linear( - in_features=emb_channels, out_features=emb_channels, **init - ) - - # Encoder. - self.enc = torch.nn.ModuleDict() - cout = in_channels - caux = in_channels - for level, mult in enumerate(channel_mult): - res = img_resolution >> level - if level == 0: - cin = cout - cout = model_channels - self.enc[f"{res}x{res}_conv"] = Conv2d( - in_channels=cin, out_channels=cout, kernel=3, **init - ) - else: - self.enc[f"{res}x{res}_down"] = UNetBlock( - in_channels=cout, out_channels=cout, down=True, **block_kwargs - ) - if encoder_type == "skip": - self.enc[f"{res}x{res}_aux_down"] = Conv2d( - in_channels=caux, - out_channels=caux, - kernel=0, - down=True, - resample_filter=resample_filter, - ) - self.enc[f"{res}x{res}_aux_skip"] = Conv2d( - in_channels=caux, out_channels=cout, kernel=1, **init - ) - if encoder_type == "residual": - self.enc[f"{res}x{res}_aux_residual"] = Conv2d( - in_channels=caux, - out_channels=cout, - kernel=3, - down=True, - resample_filter=resample_filter, - fused_resample=True, - **init, - ) - caux = cout - for idx in range(num_blocks): - cin = cout - cout = model_channels * mult - attn = res in attn_resolutions - self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( - in_channels=cin, out_channels=cout, attention=attn, **block_kwargs - ) - skips = [ - block.out_channels for name, block in self.enc.items() if "aux" not in name - ] - - # Decoder. - self.dec = torch.nn.ModuleDict() - for level, mult in reversed(list(enumerate(channel_mult))): - res = img_resolution >> level - if level == len(channel_mult) - 1: - self.dec[f"{res}x{res}_in0"] = UNetBlock( - in_channels=cout, out_channels=cout, attention=True, **block_kwargs - ) - self.dec[f"{res}x{res}_in1"] = UNetBlock( - in_channels=cout, out_channels=cout, **block_kwargs - ) - else: - self.dec[f"{res}x{res}_up"] = UNetBlock( - in_channels=cout, out_channels=cout, up=True, **block_kwargs - ) - for idx in range(num_blocks + 1): - cin = cout + skips.pop() - cout = model_channels * mult - attn = idx == num_blocks and res in attn_resolutions - self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( - in_channels=cin, out_channels=cout, attention=attn, **block_kwargs - ) - if decoder_type == "skip" or level == 0: - if decoder_type == "skip" and level < len(channel_mult) - 1: - self.dec[f"{res}x{res}_aux_up"] = Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel=0, - up=True, - resample_filter=resample_filter, - ) - self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( - num_channels=cout, eps=1e-6 - ) - self.dec[f"{res}x{res}_aux_conv"] = Conv2d( - in_channels=cout, out_channels=out_channels, kernel=3, **init_zero - ) - - def forward(self, x, noise_labels, class_labels, augment_labels=None): - - emb = self.map_noise(noise_labels) - emb = ( - emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) - ) # swap sin/cos - if self.map_label is not None: - tmp = class_labels - if self.training and self.label_dropout: - tmp = tmp * ( - torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout - ).to(tmp.dtype) - emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) - if self.map_augment is not None and augment_labels is not None: - emb = emb + self.map_augment(augment_labels) - emb = silu(self.map_layer0(emb)) - emb = silu(self.map_layer1(emb)) - - # Encoder. - skips = [] - aux = x - for name, block in self.enc.items(): - if "aux_down" in name: - aux = block(aux) - elif "aux_skip" in name: - x = skips[-1] = x + block(aux) - elif "aux_residual" in name: - x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) - elif "_conv" in name: - x = block(x) - if self.spatial_embedding: - x = x + self.spatial_emb.to(dtype=x.dtype) - skips.append(x) - else: - x = block(x, emb) if isinstance(block, UNetBlock) else block(x) - skips.append(x) - - # Decoder. - aux = None - tmp = None - for name, block in self.dec.items(): - if "aux_up" in name: - aux = block(aux) - elif "aux_norm" in name: - tmp = block(x) - elif "aux_conv" in name: - tmp = block(silu(tmp)) - aux = tmp if aux is None else tmp + aux - else: - if x.shape[1] != block.in_channels: - x = torch.cat([x, skips.pop()], dim=1) - x = block(x, emb) - - return aux - - -class EDMPrecond(torch.nn.Module): - """Preconditioning wrapper to implement diffusion model training as proposed in "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) - """ - - def __init__( - self, - model: torch.nn.Module, - img_resolution, # Image resolution. - label_dim=0, # Number of class labels, 0 = unconditional. - use_fp16=False, # Execute the underlying model at FP16 precision? - sigma_min=0, # Minimum supported noise level. - sigma_max=float("inf"), # Maximum supported noise level. - sigma_data=0.5, # Expected standard deviation of the training data. - ): - super().__init__() - self.img_resolution = img_resolution - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.model = model - - def forward( - self, - x, - sigma, - condition=None, - class_labels=None, - force_fp32=False, - **model_kwargs, - ): - """ - Args: - x: The latent data (to be denoised). shape [batch_size, target_channels, w, h] - class_labels: optional, shape [batch_size, label_dim] - condition: optional, the conditional inputs, - shape=[batch_size, condition_channel, w, h] - **model_kwargs: passed on to self.model. - - Returns: - D_x: shape [batch_size, out_channels, w, h] - - """ - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 - - arg = c_in * x - - if condition is not None: - arg = torch.cat([arg, condition], dim=1) - - F_x = self.model( - arg.to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - assert F_x.dtype == dtype - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - def round_sigma(self, sigma): - return torch.as_tensor(sigma) - - -class EasyRegressionV2(torch.nn.Module): - """Wrapper enabling the regression model to outwardly look like a straightforward regression model in inference""" - - def __init__( - self, - model, - ): - super().__init__() - - self.model = model - - def set_invariant(self, invariant_tensor): - - self.invariant_tensor = invariant_tensor - - def forward( - self, hrrr, era5, mask=None - ): # mask is just for compatibility. Doesn't do anything - - condition = torch.cat( - [hrrr, era5, self.invariant_tensor.repeat(hrrr.shape[0], 1, 1, 1)], dim=1 - ) - - sigma = torch.randn([condition.shape[0], 1, 1, 1], device=condition.device) - - return self.model(sigma=sigma, condition=condition) - - -class RegressionWrapperV2(torch.nn.Module): - """Wrapper class for training regression models - (allows re-using training code between regression and diffusion training) - """ - - def __init__( - self, - model: torch.nn.Module, - img_resolution, # Image resolution. - label_dim=0, # Number of class labels, 0 = unconditional. - use_fp16=False, # Execute the underlying model at FP16 precision? - ): - super().__init__() - self.img_resolution = img_resolution - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.model = model - - def forward( - self, - sigma, - condition=None, - class_labels=None, - force_fp32=False, - **model_kwargs, - ): - - condition = condition.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = None - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_noise = torch.zeros_like(sigma) - arg = condition - - F_x = self.model( - arg.to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - assert F_x.dtype == dtype - - D_x = F_x.to(torch.float32) - return D_x - - -# ---------------------------------------------------------------------------- - - -def get_preconditioned_architecture( - name: str, - resolution: int, - target_channels: int, - conditional_channels: int = 0, - label_dim: int = 0, - spatial_embedding: str = "add", - hrrr_resolution: tuple = (512, 640), - attn_resolutions: list = [], -): - """ - - Args: - name: 'regression' or 'diffusion' to select between either model type - resolution (int): _description_ - target_channels: The number of channels in the target - conditional_channels: The number of channels in the conditioning - label_dim: size of label data - sigma_min: Defaults to 0. - sigma_max: Defaults to float("inf"). - sigma_data: Defaults to 0.5. - - Returns: - EDMPrecond: a wrapped torch module net(x+n, sigma, condition, class_labels) -> x - """ - if name == "diffusion": - model = SongUNet( - img_resolution=resolution, - in_channels=target_channels + conditional_channels, - out_channels=target_channels, - label_dim=label_dim, - embedding_type="positional", - encoder_type="standard", - decoder_type="standard", - channel_mult_noise=1, - resample_filter=[1, 1], - model_channels=128, - channel_mult=[1, 2, 2, 2, 2], - attn_resolutions=attn_resolutions, - spatial_embedding=spatial_embedding, - hrrr_resolution=hrrr_resolution, - ) - return EDMPrecond( - model=model, - img_resolution=resolution, - label_dim=label_dim, - ) - elif name == "regression": - model = SongUNetRegression( - img_resolution=resolution, - in_channels=conditional_channels, - out_channels=target_channels, - label_dim=label_dim, - embedding_type="zero", - encoder_type="standard", - decoder_type="standard", - channel_mult_noise=1, - resample_filter=[1, 1], - model_channels=128, - channel_mult=[1, 2, 2, 2, 2], - attn_resolutions=[], - spatial_embedding=spatial_embedding, - hrrr_resolution=hrrr_resolution, - ) - return RegressionWrapperV2( - model=model, - img_resolution=resolution, - label_dim=label_dim, - ) - else: - raise ValueError(f'Invalid architecture name "{name}"') diff --git a/examples/generative/stormcast/utils/diffusions/run_edm.py b/examples/generative/stormcast/utils/diffusions/run_edm.py deleted file mode 100644 index 7b78f70305..0000000000 --- a/examples/generative/stormcast/utils/diffusions/run_edm.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from utils.diffusions.generate import edm_sampler -from utils.diffusions.networks import get_preconditioned_architecture -from utils.data_loader_hrrr_era5 import get_dataset - - -class EDMRunner: - """Wrapper class to handle loading a pretrained diffusion model and running inference with it""" - - def __init__(self, params, checkpoint_path, device=None, sampler_args={}): - - if device is None: - device = torch.device("cuda:0") - - self.sampler_args = sampler_args - dataset_obj = get_dataset(params, train=True) - - _, hrrr_channels = dataset_obj._get_hrrr_channel_names() - self.input_channels = ( - hrrr_channels if params.input_channels == "all" else params.input_channels - ) - self.input_channel_indices = list(range(len(hrrr_channels))) - self.diffusion_channels = ( - hrrr_channels - if params.diffusion_channels == "all" - else params.diffusion_channels - ) - self.diffusion_channel_indices = list(range(len(hrrr_channels))) - - invariant_array = dataset_obj._get_invariants() - self.invariant_tensor = torch.from_numpy(invariant_array).to(device) - self.invariant_tensor = self.invariant_tensor.unsqueeze(0) - - resolution = params.hrrr_img_size[0] - n_target_channels = len(self.diffusion_channel_indices) - n_input_channels = 2 * len(self.input_channel_indices) + len(params.invariants) - - self.net = get_preconditioned_architecture( - name="diffusion", - resolution=resolution, - target_channels=n_target_channels, - conditional_channels=n_input_channels, - label_dim=0, - spatial_embedding=params.spatial_pos_embed, - attn_resolutions=params.attn_resolutions, - ).requires_grad_(False) - - assert self.net.sigma_min < self.net.sigma_max - - # Load pretrained weights - chkpt = torch.load(checkpoint_path, weights_only=True) - self.net.load_state_dict(chkpt["net"], strict=True) - self.net = self.net.to(device) - self.params = params - - print("n target channels: ", n_target_channels) - - def run(self, hrrr_0): - - with torch.no_grad(): - - ensemble_size, c, h, w = hrrr_0[ - :, self.diffusion_channel_indices, :, : - ].shape - latents = torch.randn( - ensemble_size, c, h, w, device=hrrr_0.device, dtype=hrrr_0.dtype - ) - - if ensemble_size > 1 and self.invariant_tensor.shape[0] != ensemble_size: - self.invariant_tensor = self.invariant_tensor.expand( - ensemble_size, -1, -1, -1 - ) - condition = torch.cat((hrrr_0, self.invariant_tensor), dim=1) - output_images = edm_sampler( - self.net, latents=latents, condition=condition, **self.sampler_args - ) - - return output_images, self.diffusion_channels diff --git a/examples/generative/stormcast/utils/diffusions/training_loop.py b/examples/generative/stormcast/utils/diffusions/training_loop.py deleted file mode 100644 index 3b2bc6c185..0000000000 --- a/examples/generative/stormcast/utils/diffusions/training_loop.py +++ /dev/null @@ -1,557 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Main training loop.""" - -import os -import time -import copy -import json -from utils.YParams import YParams -import psutil -import numpy as np -import torch -from modulus.distributed import DistributedManager -from utils import training_stats -from utils import misc -from utils.misc import print0 -from utils.diffusions.generate import edm_sampler -from utils.diffusions.networks import get_preconditioned_architecture, EasyRegressionV2 -from utils.diffusions.losses import EDMLoss, RegressionLossV2 -from utils.data_loader_hrrr_era5 import get_dataset, worker_init -import matplotlib.pyplot as plt -import wandb -from utils.spectrum import compute_ps1d -from torch.nn.utils import clip_grad_norm_ - -# ---------------------------------------------------------------------------- - - -def get_pretrained_regression_net( - checkpoint_path, config_file, regression_config, target_channels, device -): - """ - Load a pretrained regression network as specified by a given config - """ - - hyperparams = YParams(config_file, regression_config) - resolution = hyperparams.hrrr_img_size[0] - - conditional_channels = ( - target_channels + len(hyperparams.invariants) + hyperparams.n_era5_channels - ) - - net = get_preconditioned_architecture( - name="regression", - resolution=resolution, - target_channels=target_channels, - conditional_channels=conditional_channels, - label_dim=0, - spatial_embedding=hyperparams.spatial_pos_embed, - attn_resolutions=hyperparams.attn_resolutions, - ) - - chkpt = torch.load(checkpoint_path, weights_only=True) - net.load_state_dict(chkpt["net"], strict=True) - net = EasyRegressionV2(net) - - return net.to(device) - - -def training_loop( - run_dir=".", # Output directory. - optimizer_kwargs={}, # Options for optimizer. - seed=0, # Global random seed. - lr_rampup_kimg=2000, # Learning rate ramp-up duration. - state_dump_ticks=50, # How often to dump training state, None = disable. - resume_state_dump=None, # Start from the given training state, None = reset training state. - resume_kimg=0, # Start from the given training progress. - cudnn_benchmark=True, # Enable torch.backends.cudnn.benchmark? - device=torch.device("cuda"), - config_file=None, - config_name=None, - log_to_wandb=False, -): - dist = DistributedManager() - params = YParams(config_file, config_name) - batch_size = params.batch_size - local_batch_size = batch_size // dist.world_size - optimizer_kwargs["lr"] = params.lr - img_per_tick = params.img_per_tick - use_regression_net = params.use_regression_net - previous_step_conditioning = params.previous_step_conditioning - loss_type = params.loss - if loss_type == "regression_v2": - train_regression_unet = True - net_name = "regression" - print0("Using regression_v2") - elif loss_type == "edm": - train_regression_unet = False - net_name = "diffusion" - - # Initialize. - start_time = time.time() - np.random.seed((seed * dist.world_size + dist.rank) % (1 << 31)) - torch.manual_seed(np.random.randint(1 << 31)) - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.backends.cudnn.allow_tf32 = False - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False - total_kimg = params.total_kimg - - if resume_state_dump is not None: - print0("Resuming from state dump:", resume_state_dump) - print0("Resuming from kimg:", resume_kimg) - - # Load dataset. - print0("Loading dataset...") - - dataset_train = get_dataset(params, train=True) - dataset_valid = get_dataset(params, train=False) - - _, hrrr_channels = dataset_train._get_hrrr_channel_names() - diffusion_channels = ( - hrrr_channels - if params.diffusion_channels == "all" - else params.diffusion_channels - ) - input_channels = ( - hrrr_channels if params.input_channels == "all" else params.input_channels - ) - input_channel_indices = [hrrr_channels.index(channel) for channel in input_channels] - diffusion_channel_indices = [ - hrrr_channels.index(channel) for channel in diffusion_channels - ] - - sampler = misc.InfiniteSampler( - dataset=dataset_train, - rank=dist.rank, - num_replicas=dist.world_size, - seed=seed, - ) - valid_sampler = misc.InfiniteSampler( - dataset=dataset_valid, - rank=dist.rank, - num_replicas=dist.world_size, - seed=seed, - ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset_train, - batch_size=local_batch_size, - num_workers=params.num_data_workers, - sampler=sampler, - worker_init_fn=worker_init, - drop_last=True, - pin_memory=torch.cuda.is_available(), - ) - - valid_data_loader = torch.utils.data.DataLoader( - dataset=dataset_valid, - batch_size=local_batch_size, - num_workers=params.num_data_workers, - sampler=valid_sampler, - drop_last=True, - pin_memory=torch.cuda.is_available(), - ) - - dataset_iterator = iter(data_loader) - valid_dataset_iterator = iter(valid_data_loader) - - # load pretrained regression net if training diffusion - if use_regression_net: - regression_net = get_pretrained_regression_net( - checkpoint_path=params.regression_weights, - config_file=config_file, - regression_config=params.regression_config, - target_channels=len(diffusion_channels), - device=device, - ) - - # Construct network - print0("Constructing network...") - resolution = 512 - target_channels = len(diffusion_channels) - if train_regression_unet: - conditional_channels = ( - len(input_channels) + 26 - ) # 26 is the number of era5 channels - else: - conditional_channels = ( - len(input_channels) - if not previous_step_conditioning - else 2 * len(input_channels) - ) - - conditional_channels += len(params.invariants) - invariant_array = dataset_train._get_invariants() - invariant_tensor = torch.from_numpy(invariant_array).to(device) - - if not train_regression_unet: - regression_net.set_invariant(invariant_tensor) - - print0("hrrr_channels", hrrr_channels) - print0("target_channels for diffusion", target_channels) - print0("conditional_channels for diffusion", conditional_channels) - - net = get_preconditioned_architecture( - name=net_name, - resolution=resolution, - target_channels=target_channels, - conditional_channels=conditional_channels, - label_dim=0, - spatial_embedding=params.spatial_pos_embed, - attn_resolutions=params.attn_resolutions, - ) - - if not params.loss in ["regression", "regression_v2"]: - assert net.sigma_min < net.sigma_max - net.train().requires_grad_(True).to(device) - - # Setup optimizer. - print0("Setting up optimizer...") - if params.loss == "regression_v2": - loss_fn = RegressionLossV2() - elif params.loss == "edm": - loss_fn = EDMLoss(P_mean=params.P_mean) - optimizer = torch.optim.Adam(net.parameters(), **optimizer_kwargs) - augment_pipe = None - ddp = torch.nn.parallel.DistributedDataParallel( - net, device_ids=[device], broadcast_buffers=False - ) - - total_steps = 0 - - # Resume training from previous snapshot. - if resume_state_dump: - print0(f'Loading training state from "{resume_state_dump}"...') - data = torch.load( - resume_state_dump, map_location=torch.device("cpu"), weights_only=True - ) - net.load_state_dict(data["net"]) - total_steps = data["total_steps"] - optimizer.load_state_dict(data["optimizer_state"]) - del data # conserve memory - - # Train. - print0(f"Training for {total_kimg} kimg...") - print0() - cur_nimg = resume_kimg * 1000 - cur_tick = 0 - tick_start_nimg = cur_nimg - tick_start_time = time.time() - maintenance_time = tick_start_time - start_time - stats_jsonl = None - wandb_logs = {} - - while True: - # Accumulate gradients. - optimizer.zero_grad(set_to_none=True) - batch = next(dataset_iterator) - hrrr_0 = batch["hrrr"][0].to(device).to(torch.float32) - hrrr_1 = batch["hrrr"][1].to(device).to(torch.float32) - - if use_regression_net: - era5 = batch["era5"][0].to(device).to(torch.float32) - - with torch.no_grad(): - reg_out = regression_net(hrrr_0, era5, mask=None) - hrrr_0 = torch.cat( - ( - hrrr_0[:, input_channel_indices, :, :], - reg_out[:, input_channel_indices, :, :], - ), - dim=1, - ) - hrrr_1 = hrrr_1 - reg_out - del reg_out - - elif train_regression_unet: - assert diffusion_channel_indices == input_channel_indices - - era5 = batch["era5"][0].to(device).to(torch.float32) - - hrrr_0 = torch.cat((hrrr_0[:, input_channel_indices, :, :], era5), dim=1) - - hrrr_1 = hrrr_1[ - :, diffusion_channel_indices, :, : - ] # targets of the diffusion model - - invariant_tensor_ = invariant_tensor.unsqueeze(0) - invariant_tensor_ = invariant_tensor.repeat(hrrr_0.shape[0], 1, 1, 1) - hrrr_0 = torch.cat((hrrr_0, invariant_tensor_), dim=1) - - loss = loss_fn(net=ddp, x=hrrr_1, condition=hrrr_0, augment_pipe=augment_pipe) - channelwise_loss = loss.mean(dim=(0, 2, 3)) - channelwise_loss_dict = { - f"ChLoss/{diffusion_channels[i]}": channelwise_loss[i].item() - for i in range(target_channels) - } - training_stats.report("Loss/loss", loss.mean()) - loss_value = loss.sum() / target_channels - if log_to_wandb: - wandb_logs["channelwise_loss"] = channelwise_loss_dict - - loss_value.backward() - - if params.clip_grad_norm is not None: - clip_grad_norm_(net.parameters(), params.clip_grad_norm) - - # Update weights. - for g in optimizer.param_groups: - g["lr"] = optimizer_kwargs["lr"] * min( - cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1 - ) - if log_to_wandb: - wandb_logs["lr"] = g["lr"] - for param in net.parameters(): - if param.grad is not None: - torch.nan_to_num( - param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad - ) - - optimizer.step() - - # Perform maintenance tasks once per tick. - effective_batch_size = (batch_size // local_batch_size) * hrrr_0.shape[0] - total_steps += 1 - cur_nimg += effective_batch_size - # done = (cur_nimg >= total_kimg * 1000) - done = cur_nimg >= 10 # TODO remove this line (testing only) - - if ( - (not done) - and (cur_tick != 0) - and (cur_nimg < tick_start_nimg + img_per_tick) - ): - continue - - # make inference - if cur_tick % params.validate_every == 0: - batch = next(valid_dataset_iterator) - - with torch.no_grad(): - # n = 1 - hrrr_0, hrrr_1 = batch["hrrr"] - hrrr_0 = hrrr_0.to(torch.float32).to(device) - hrrr_1 = hrrr_1.to(torch.float32).to(device) - - invariant_tensor_ = invariant_tensor.unsqueeze(0) - invariant_tensor_ = invariant_tensor.repeat(hrrr_0.shape[0], 1, 1, 1) - - if use_regression_net: - with torch.no_grad(): - era5 = batch["era5"][0].to(device).to(torch.float32) - reg_out = regression_net(hrrr_0, era5, mask=None) - hrrr_0 = torch.cat( - ( - hrrr_0[:, input_channel_indices, :, :], - reg_out[:, input_channel_indices, :, :], - ), - dim=1, - ) - latents = torch.randn_like( - hrrr_1[:, diffusion_channel_indices, :, :] - ) - loss_target = hrrr_1 - reg_out - output_images = edm_sampler( - net, - latents=latents, - condition=torch.cat((hrrr_0, invariant_tensor_), dim=1), - ) - valid_loss = loss_fn( - net=ddp, - x=loss_target[:, diffusion_channel_indices], - condition=torch.cat((hrrr_0, invariant_tensor_), dim=1), - augment_pipe=augment_pipe, - ) - output_images += reg_out[:, diffusion_channel_indices, :, :] - del reg_out - - elif train_regression_unet: - assert ( - use_regression_net == False - ), "use_regression_net must be False when training regression unet" - assert ( - input_channel_indices == diffusion_channel_indices - ), "input_channel_indices must be equal to diffusion_channel_indices when training regression unet" - condition = torch.cat( - ( - hrrr_0[:, input_channel_indices, :, :], - era5[:], - invariant_tensor_, - ), - dim=1, - ) - latents = torch.zeros_like( - hrrr_1[:, diffusion_channel_indices, :, :], - device=hrrr_1.device, - ) - rnd_normal = torch.randn( - [latents.shape[0], 1, 1, 1], device=latents.device - ) - sigma = ( - rnd_normal * 1.2 - 1.2 - ).exp() # this isn't used by the code - output_images = net(sigma=sigma, condition=condition) - valid_loss = loss_fn( - net=ddp, - x=hrrr_1[:, diffusion_channel_indices, :, :], - condition=condition, - augment_pipe=augment_pipe, - ) - channelwise_valid_loss = valid_loss.mean(dim=[0, 2, 3]) - channelwise_valid_loss_dict = { - f"ChLoss_valid/{diffusion_channels[i]}": channelwise_valid_loss[ - i - ].item() - for i in range(target_channels) - } - if log_to_wandb: - wandb_logs[ - "channelwise_valid_loss" - ] = channelwise_valid_loss_dict - - hrrr_1 = hrrr_1[:, diffusion_channel_indices, :, :] - - training_stats.report("Loss/valid_loss", valid_loss.mean()) - - # Print status line, accumulating the same information in training_stats. - tick_end_time = time.time() - fields = [] - fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] - fields += [ - f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}" - ] - fields += [ - f"time {misc.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}" - ] - fields += [ - f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}" - ] - fields += [ - f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}" - ] - fields += [ - f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}" - ] - fields += [ - f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - fields += [ - f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}" - ] - fields += [ - f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}" - ] - torch.cuda.reset_peak_memory_stats() - print0(" ".join(fields)) - - if cur_tick % params.validate_every == 0: - if log_to_wandb: - if dist.rank == 0: - print0("logging to wandb") - wandb.log(wandb_logs, step=cur_nimg) - - if dist.rank == 0: - # TODO: improve the image saving and run_dir setup for thread safe image saving from all ranks - - for i in range(output_images.shape[0]): - image = output_images[i].cpu().numpy() - # hrrr_channels = dataset_train.hrrr_channels - fields = ["u10m", "v10m", "t2m", "refc", "q1", "q5", "q10"] - - # Compute spectral metrics - figs, spec_ratios = compute_ps1d( - output_images[i], hrrr_1[i], fields, diffusion_channels - ) - if log_to_wandb: - wandb.log(spec_ratios, step=cur_nimg) - for figname, fig in figs.items(): - wandb.log({figname: wandb.Image(fig)}, step=cur_nimg) - - for f_ in fields: - f_index = diffusion_channels.index(f_) - image_dir = os.path.join(run_dir, "images", f_) - generated = image[f_index] - truth = hrrr_1[i, f_index].cpu().numpy() - - fig, (a, b) = plt.subplots(1, 2) - im = a.imshow(generated) - a.set_title("generated, {}.png".format(f_)) - plt.colorbar(im, fraction=0.046, pad=0.04) - im = b.imshow(truth) - b.set_title("truth") - plt.colorbar(im, fraction=0.046, pad=0.04) - os.makedirs(image_dir, exist_ok=True) - plt.savefig(os.path.join(image_dir, f"{cur_tick}_{i}_{f_}.png")) - plt.close("all") - - specfig = "PS1D_" + f_ - figs[specfig].savefig( - os.path.join(image_dir, f"{cur_tick}{i}{f_}_spec.png") - ) - plt.close(figs[specfig]) - - # log the images to wandb - if log_to_wandb: - # log fig to wandb - wandb.log({f"generated_{f_}": fig}, step=cur_nimg) - - # Save full dump of the training state. - if ( - (state_dump_ticks is not None) - and (done or cur_tick % state_dump_ticks == 0) - and cur_tick != 0 - and dist.rank == 0 - ): - torch.save( - dict( - net=net.state_dict(), - optimizer_state=optimizer.state_dict(), - total_steps=total_steps, - ), - os.path.join(run_dir, f"training-state-{cur_nimg//1000:06d}.pt"), - ) - - # Update logs. - training_stats.default_collector.update() - if dist.rank == 0: - if stats_jsonl is None: - stats_jsonl = open(os.path.join(run_dir, "stats.jsonl"), "at") - - stats_dict = dict( - training_stats.default_collector.as_dict(), timestamp=time.time() - ) - if True: - wandb_logs["loss"] = stats_dict["Loss/loss"]["mean"] - wandb_logs["valid_loss"] = stats_dict["Loss/valid_loss"]["mean"] - print0("loss: ", wandb_logs["loss"]) - print0("valid_loss: ", wandb_logs["valid_loss"]) - stats_jsonl.write(json.dumps(stats_dict) + "\n") - stats_jsonl.flush() - - # Update state. - cur_tick += 1 - tick_start_nimg = cur_nimg - tick_start_time = time.time() - maintenance_time = tick_start_time - tick_end_time - if done: - break - - # Done. - torch.distributed.barrier() - print0() - print0("Exiting...") diff --git a/examples/generative/stormcast/utils/misc.py b/examples/generative/stormcast/utils/misc.py deleted file mode 100644 index 96c0d5bace..0000000000 --- a/examples/generative/stormcast/utils/misc.py +++ /dev/null @@ -1,241 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -import re -import numpy as np -import torch -from typing import Any, Union, Optional - -from modulus.distributed import DistributedManager - - -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - - def nan_to_num( - input, nan=0.0, posinf=None, neginf=None, *, out=None - ): # pylint: disable=redefined-builtin - assert isinstance(input, torch.Tensor) - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - assert nan == 0 - return torch.clamp( - input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out - ) - - -class EasyDict(dict): - """Convenience class that behaves like a dict but allows access with the attribute syntax.""" - - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError: - raise AttributeError(name) - - def __setattr__(self, name: str, value: Any) -> None: - self[name] = value - - def __delattr__(self, name: str) -> None: - del self[name] - - -def format_time(seconds: Union[int, float]) -> str: - """Convert the seconds to human readable string with days, hours, minutes and seconds.""" - s = int(np.rint(seconds)) - - if s < 60: - return "{0}s".format(s) - elif s < 60 * 60: - return "{0}m {1:02}s".format(s // 60, s % 60) - elif s < 24 * 60 * 60: - return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) - else: - return "{0}d {1:02}h {2:02}m".format( - s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 - ) - - -def print0(*args, **kwargs): - if DistributedManager().rank == 0: - print(*args, **kwargs) - - -class Logger(object): - """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" - - def __init__( - self, - file_name: Optional[str] = None, - file_mode: str = "w", - should_flush: bool = True, - ): - self.file = None - - if file_name is not None: - self.file = open(file_name, file_mode) - - self.should_flush = should_flush - self.stdout = sys.stdout - self.stderr = sys.stderr - - sys.stdout = self - sys.stderr = self - - def __enter__(self) -> "Logger": - return self - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - self.close() - - def write(self, text: Union[str, bytes]) -> None: - """Write text to stdout (and a file) and optionally flush.""" - if isinstance(text, bytes): - text = text.decode() - if ( - len(text) == 0 - ): # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash - return - - if self.file is not None: - self.file.write(text) - - self.stdout.write(text) - - if self.should_flush: - self.flush() - - def flush(self) -> None: - """Flush written text to both stdout and a file, if open.""" - if self.file is not None: - self.file.flush() - - self.stdout.flush() - - def close(self) -> None: - """Flush, close possible files, and remove stdout/stderr mirroring.""" - self.flush() - - # if using multiple loggers, prevent closing in wrong order - if sys.stdout is self: - sys.stdout = self.stdout - if sys.stderr is self: - sys.stderr = self.stderr - - if self.file is not None: - self.file.close() - self.file = None - - -# ---------------------------------------------------------------------------- -# Symbolic assert. - -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -# ---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). - - -def profiled_function(fn): - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - - decorator.__name__ = fn.__name__ - return decorator - - -class InfiniteSampler(torch.utils.data.Sampler): - """Sampler for torch.utils.data.DataLoader that loops over the dataset - indefinitely, shuffling items as it goes.""" - - def __init__( - self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 - ): - assert len(dataset) > 0 - assert num_replicas > 0 - assert 0 <= rank < num_replicas - assert 0 <= window_size <= 1 - super().__init__(dataset) - self.dataset = dataset - self.rank = rank - self.num_replicas = num_replicas - self.shuffle = shuffle - self.seed = seed - self.window_size = window_size - - def __iter__(self): - order = np.arange(len(self.dataset)) - rnd = None - window = 0 - if self.shuffle: - rnd = np.random.RandomState(self.seed) - rnd.shuffle(order) - window = int(np.rint(order.size * self.window_size)) - - idx = 0 - while True: - i = idx % order.size - if idx % self.num_replicas == self.rank: - yield order[i] - if window >= 2: - j = (i - rnd.randint(window)) % order.size - order[i], order[j] = order[j], order[i] - idx += 1 - - -def named_params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.named_parameters()) + list(module.named_buffers()) - - -# ---------------------------------------------------------------------------- -# Check DistributedDataParallel consistency across processes. - - -def check_ddp_consistency(module, ignore_regex=None): - assert isinstance(module, torch.nn.Module) - for name, tensor in named_params_and_buffers(module): - fullname = type(module).__name__ + "." + name - if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): - continue - tensor = tensor.detach() - if tensor.is_floating_point(): - tensor = nan_to_num(tensor) - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - assert (tensor == other).all(), fullname - - -@torch.no_grad() -def copy_params_and_buffers(src_module, dst_module, require_all=False): - assert isinstance(src_module, torch.nn.Module) - assert isinstance(dst_module, torch.nn.Module) - src_tensors = dict(named_params_and_buffers(src_module)) - for name, tensor in named_params_and_buffers(dst_module): - assert (name in src_tensors) or (not require_all) - if name in src_tensors: - tensor.copy_(src_tensors[name]) - - -# ---------------------------------------------------------------------------- -# Print summary table of module hierarchy. diff --git a/examples/generative/stormcast/utils/nn.py b/examples/generative/stormcast/utils/nn.py new file mode 100644 index 0000000000..3fc9b7a885 --- /dev/null +++ b/examples/generative/stormcast/utils/nn.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from modulus.models import Module +from modulus.models.diffusion import EDMPrecond, StormCastUNet +from modulus.utils.generative import deterministic_sampler + + +def get_preconditioned_architecture( + name: str, + target_channels: int, + conditional_channels: int = 0, + spatial_embedding: bool = True, + hrrr_resolution: tuple = (512, 640), + attn_resolutions: list = [], +): + """ + + Args: + name: 'regression' or 'diffusion' to select between either model type + target_channels: The number of channels in the target + conditional_channels: The number of channels in the conditioning + spatial_embedding: whether or not to use the additive spatial embedding in the U-Net + hrrr_resolution: resolution of HRRR data (U-Net inputs/outputs) + attn_resolutions: resolution of internal U-Net stages to use self-attention + Returns: + EDMPrecond or StormCastUNet: a wrapped torch module net(x+n, sigma, condition, class_labels) -> x + """ + if name == "diffusion": + return EDMPrecond( + img_resolution=hrrr_resolution, + img_channels=target_channels + conditional_channels, + img_out_channels=target_channels, + model_type="SongUNet", + channel_mult=[1, 2, 2, 2, 2], + attn_resolutions=attn_resolutions, + additive_pos_embed=spatial_embedding, + ) + + elif name == "regression": + return StormCastUNet( + img_resolution=hrrr_resolution, + img_in_channels=conditional_channels, + img_out_channels=target_channels, + model_type="SongUNet", + embedding_type="zero", + channel_mult=[1, 2, 2, 2, 2], + attn_resolutions=attn_resolutions, + additive_pos_embed=spatial_embedding, + ) + + +def diffusion_model_forward( + model, hrrr_0, diffusion_channel_indices, invariant_tensor, sampler_args={} +): + """Helper function to run diffusion model sampling""" + + b, c, h, w = hrrr_0[:, diffusion_channel_indices, :, :].shape + + latents = torch.randn(b, c, h, w, device=hrrr_0.device, dtype=hrrr_0.dtype) + + if b > 1 and invariant_tensor.shape[0] != b: + invariant_tensor = invariant_tensor.expand(b, -1, -1, -1) + condition = torch.cat((hrrr_0, invariant_tensor), dim=1) + + output_images = deterministic_sampler( + model, latents=latents, img_lr=condition, **sampler_args + ) + + return output_images + + +def regression_model_forward(model, hrrr, era5, invariant_tensor): + """Helper function to run regression model forward pass in inference""" + + x = torch.cat([hrrr, era5, invariant_tensor], dim=1) + + return model(x) + + +def regression_loss_fn( + net: Module, + images, + condition, + class_labels=None, + augment_pipe=None, + return_model_outputs=False, +): + """Helper function for training the StormCast regression model, so that it has a similar call signature as + the EDMLoss and the same training loop can be used to train both regression and diffusion models + + Args: + net: modulus.models.diffusion.StormCastUNet + images: Target data, shape [batch_size, target_channels, w, h] + condition: input to the model, shape=[batch_size, condition_channel, w, h] + class_labels: unused (applied to match EDMLoss signature) + augment_pipe: optional data augmentation pipe + return_model_outputs: If True, will return the generated outputs + Returns: + out: loss function with shape [batch_size, target_channels, w, h] + This should be averaged to get the mean loss for gradient descent. + """ + + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + + D_yn = net(x=condition) + loss = (D_yn - y) ** 2 + if return_model_outputs: + return loss, D_yn + else: + return loss diff --git a/examples/generative/stormcast/utils/spectrum.py b/examples/generative/stormcast/utils/spectrum.py index aef14544a9..23b7f13adc 100644 --- a/examples/generative/stormcast/utils/spectrum.py +++ b/examples/generative/stormcast/utils/spectrum.py @@ -17,79 +17,21 @@ import torch import numpy as np import matplotlib.pyplot as plt +from modulus.metrics.general.power_spectrum import power_spectrum -def batch_histogram(data_tensor, num_classes=-1, weights=None): - """ - From. https://github.com/pytorch/pytorch/issues/99719#issuecomment-1760112194 - Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram). - Arguments: - data_tensor: a D1 x ... x D_n torch.LongTensor - num_classes (optional): the number of classes present in data. - If not provided, tensor.max() + 1 is used (an error is thrown if tensor is empty). - Returns: - A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor, - containing histograms of the last dimension D_n of tensor, - that is, result[d_1,...,d_{n-1}, c] = number of times c appears in tensor[d_1,...,d_{n-1}]. - """ - maxd = data_tensor.max() - nc = (maxd + 1) if num_classes <= 0 else num_classes - hist = torch.zeros( - (*data_tensor.shape[:-1], nc), - dtype=data_tensor.dtype, - device=data_tensor.device, - ) - if weights is not None: - wts = weights - else: - wts = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand( - data_tensor.shape - ) - hist.scatter_add_(-1, ((data_tensor * nc) // (maxd + 1)).long(), wts) - return hist - - -def powerspect(x): - c, h, w = x.shape - - # 2D power - pwr = torch.fft.fftn(x, dim=(-2, -1), norm="ortho").abs() ** 2 - pwr = torch.fft.fftshift(pwr, dim=(-2, -1)).to(torch.float32) - - # Azimuthal average - xx, yy = torch.meshgrid( - torch.arange(h, device=pwr.device), - torch.arange(w, device=pwr.device), - indexing="ij", - ) - k = torch.hypot(xx - h // 2, yy - w / 2).to(torch.float32) - - sort = torch.argsort(k.flatten()) - k_sort = k.flatten()[sort] - pwr_sort = pwr.reshape(c, -1)[:, sort] - - nbins = min(h // 2, w // 2) - k_bins = torch.linspace(0, k_sort.max() + 1, nbins) - k_bin_centers = 0.5 * (k_bins[1:] + k_bins[:-1]) - k_sort_stack = torch.tile(k_sort, dims=(c, 1)) - - pwr_binned = batch_histogram(k_sort_stack, weights=pwr_sort, num_classes=nbins - 1) - count_binned = batch_histogram(k_sort_stack, num_classes=nbins - 1) - - return ( - k_bin_centers.detach().cpu().numpy(), - (pwr_binned / count_binned).detach().cpu().numpy(), - ) - - -def compute_ps1d(generated, target, fields, diffusion_channels): +def ps1d_plots(generated, target, fields, diffusion_channels): assert generated.shape == target.shape # Comppute PS1D, all channels with torch.no_grad(): - k, Pk_gen = powerspect(generated) - _, Pk_tar = powerspect(target) + k, Pk_gen = power_spectrum(generated) + _, Pk_tar = power_spectrum(target) + + k = k.detach().cpu().numpy() + Pk_gen = Pk_gen.detach().cpu().numpy() + Pk_tar = Pk_tar.detach().cpu().numpy() # Make plots and save metrics figs = {} diff --git a/examples/generative/stormcast/utils/trainer.py b/examples/generative/stormcast/utils/trainer.py new file mode 100644 index 0000000000..4390fd823d --- /dev/null +++ b/examples/generative/stormcast/utils/trainer.py @@ -0,0 +1,472 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main training loop.""" + +import os +import time +import numpy as np +import torch +import psutil +from modulus.models import Module +from modulus.distributed import DistributedManager +from modulus.metrics.diffusion import EDMLoss +from modulus.utils.generative import InfiniteSampler + +from modulus.launch.utils import save_checkpoint, load_checkpoint +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper +from utils.nn import ( + regression_model_forward, + diffusion_model_forward, + regression_loss_fn, + get_preconditioned_architecture, +) +from utils.data_loader_hrrr_era5 import HrrrEra5Dataset, worker_init +import matplotlib.pyplot as plt +import wandb +from utils.spectrum import ps1d_plots +from torch.nn.utils import clip_grad_norm_ + + +logger = PythonLogger("train") + + +def training_loop(cfg): + + # Initialize. + start_time = time.time() + dist = DistributedManager() + device = dist.device + logger0 = RankZeroLoggingWrapper(logger, dist) + + # Shorthand for config items + batch_size = cfg.training.batch_size + local_batch_size = batch_size // dist.world_size + use_regression_net = cfg.model.use_regression_net + previous_step_conditioning = cfg.model.previous_step_conditioning + resume_checkpoint = cfg.training.resume_checkpoint + log_to_wandb = cfg.training.log_to_wandb + + loss_type = cfg.training.loss + if loss_type == "regression": + train_regression_unet = True + net_name = "regression" + elif loss_type == "edm": + train_regression_unet = False + net_name = "diffusion" + + # Seed and Performance settings + np.random.seed((cfg.training.seed * dist.world_size + dist.rank) % (1 << 31)) + torch.manual_seed(cfg.training.seed) + torch.backends.cudnn.benchmark = cfg.training.cudnn_benchmark + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + total_train_steps = cfg.training.total_train_steps + + # Load dataset. + logger0.info("Loading dataset...") + + dataset_train = HrrrEra5Dataset(cfg.dataset, train=True) + dataset_valid = HrrrEra5Dataset(cfg.dataset, train=False) + + _, hrrr_channels = dataset_train._get_hrrr_channel_names() + diffusion_channels = ( + hrrr_channels + if cfg.dataset.diffusion_channels == "all" + else cfg.dataset.diffusion_channels + ) + input_channels = ( + hrrr_channels + if cfg.dataset.input_channels == "all" + else cfg.dataset.input_channels + ) + input_channel_indices = [hrrr_channels.index(channel) for channel in input_channels] + diffusion_channel_indices = [ + hrrr_channels.index(channel) for channel in diffusion_channels + ] + + sampler = InfiniteSampler( + dataset=dataset_train, + rank=dist.rank, + num_replicas=dist.world_size, + seed=cfg.training.seed, + ) + valid_sampler = InfiniteSampler( + dataset=dataset_valid, + rank=dist.rank, + num_replicas=dist.world_size, + seed=cfg.training.seed, + ) + data_loader = torch.utils.data.DataLoader( + dataset=dataset_train, + batch_size=local_batch_size, + num_workers=cfg.training.num_data_workers, + sampler=sampler, + worker_init_fn=worker_init, + drop_last=True, + pin_memory=torch.cuda.is_available(), + ) + + valid_data_loader = torch.utils.data.DataLoader( + dataset=dataset_valid, + batch_size=local_batch_size, + num_workers=cfg.training.num_data_workers, + sampler=valid_sampler, + drop_last=True, + pin_memory=torch.cuda.is_available(), + ) + + dataset_iterator = iter(data_loader) + valid_dataset_iterator = iter(valid_data_loader) + + # load pretrained regression net if training diffusion + if use_regression_net: + regression_net = Module.from_checkpoint(cfg.model.regression_weights) + regression_net = regression_net.to(device) + + # Construct network + logger0.info("Constructing network...") + target_channels = len(diffusion_channels) + if train_regression_unet: + conditional_channels = ( + len(input_channels) + 26 + ) # 26 is the number of era5 channels + else: + conditional_channels = ( + len(input_channels) + if not previous_step_conditioning + else 2 * len(input_channels) + ) + + conditional_channels += len(cfg.dataset.invariants) + invariant_array = dataset_train._get_invariants() + invariant_tensor = torch.from_numpy(invariant_array).to(device) + invariant_tensor = invariant_tensor.unsqueeze(0) + invariant_tensor = invariant_tensor.repeat(local_batch_size, 1, 1, 1) + + logger0.info(f"hrrr_channels {hrrr_channels}") + logger0.info(f"target_channels for diffusion {target_channels}") + logger0.info(f"conditional_channels for diffusion {conditional_channels}") + + net = get_preconditioned_architecture( + name=net_name, + hrrr_resolution=tuple(cfg.dataset.hrrr_img_size), + target_channels=target_channels, + conditional_channels=conditional_channels, + spatial_embedding=cfg.model.spatial_pos_embed, + attn_resolutions=list(cfg.model.attn_resolutions), + ) + + net.train().requires_grad_(True).to(device) + + # Setup optimizer. + logger0.info("Setting up optimizer...") + if cfg.training.loss == "regression": + loss_fn = regression_loss_fn + elif cfg.training.loss == "edm": + loss_fn = EDMLoss(P_mean=cfg.model.P_mean) + optimizer = torch.optim.Adam(net.parameters(), lr=cfg.training.lr) + augment_pipe = None + ddp = torch.nn.parallel.DistributedDataParallel( + net, device_ids=[device], broadcast_buffers=False + ) + + # Resume training from previous snapshot. + total_steps = 0 + if resume_checkpoint is not None: + logger0.info(f'Resuming training state from "{resume_checkpoint}"...') + + total_steps = load_checkpoint( + path=os.path.join(cfg.training.rundir, "checkpoints"), + models=net, + optimizer=optimizer, + ) + + # Train. + logger0.info( + f"Training up to {total_train_steps} steps starting from step {total_steps}..." + ) + stats_jsonl = None + wandb_logs = {} + done = total_steps >= total_train_steps + + train_start = time.time() + avg_train_loss = 0 + train_steps = 0 + while not done: + # Accumulate gradients. + optimizer.zero_grad(set_to_none=True) + batch = next(dataset_iterator) + hrrr_0 = batch["hrrr"][0].to(device).to(torch.float32) + hrrr_1 = batch["hrrr"][1].to(device).to(torch.float32) + + if use_regression_net: + era5 = batch["era5"][0].to(device).to(torch.float32) + + with torch.no_grad(): + reg_out = regression_model_forward( + regression_net, hrrr_0, era5, invariant_tensor + ) + hrrr_0 = torch.cat( + ( + hrrr_0[:, input_channel_indices, :, :], + reg_out[:, input_channel_indices, :, :], + ), + dim=1, + ) + hrrr_1 = hrrr_1 - reg_out + del reg_out + + elif train_regression_unet: + assert diffusion_channel_indices == input_channel_indices + + era5 = batch["era5"][0].to(device).to(torch.float32) + + hrrr_0 = torch.cat((hrrr_0[:, input_channel_indices, :, :], era5), dim=1) + + hrrr_1 = hrrr_1[ + :, diffusion_channel_indices, :, : + ] # targets of the diffusion model + + hrrr_0 = torch.cat((hrrr_0, invariant_tensor), dim=1) + + loss = loss_fn( + net=ddp, images=hrrr_1, condition=hrrr_0, augment_pipe=augment_pipe + ) + channelwise_loss = loss.mean(dim=(0, 2, 3)) + channelwise_loss_dict = { + f"ChLoss/{diffusion_channels[i]}": channelwise_loss[i].item() + for i in range(target_channels) + } + if log_to_wandb: + wandb_logs["channelwise_loss"] = channelwise_loss_dict + + loss_value = loss.sum() / target_channels + loss_value.backward() + + if cfg.training.clip_grad_norm > 0: + clip_grad_norm_(net.parameters(), cfg.training.clip_grad_norm) + + # Update weights. + for g in optimizer.param_groups: + g["lr"] = cfg.training.lr * min( + total_steps / max(cfg.training.lr_rampup_steps, 1e-8), 1 + ) + if log_to_wandb: + wandb_logs["lr"] = g["lr"] + for param in net.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + optimizer.step() + + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + + avg_train_loss += loss.mean().cpu().item() + train_steps += 1 + if log_to_wandb: + wandb_logs["loss"] = loss.mean().cpu().item() / train_steps + + # Perform maintenance tasks once per tick. + total_steps += 1 + done = total_steps >= total_train_steps + + # Perform validation step + if total_steps % cfg.training.validation_freq == 0: + valid_start = time.time() + batch = next(valid_dataset_iterator) + + with torch.no_grad(): + + hrrr_0, hrrr_1 = batch["hrrr"] + hrrr_0 = hrrr_0.to(torch.float32).to(device) + hrrr_1 = hrrr_1.to(torch.float32).to(device) + + if use_regression_net: + with torch.no_grad(): + era5 = batch["era5"][0].to(device).to(torch.float32) + reg_out = regression_model_forward( + regression_net, hrrr_0, era5, invariant_tensor + ) + hrrr_0 = torch.cat( + ( + hrrr_0[:, input_channel_indices, :, :], + reg_out[:, input_channel_indices, :, :], + ), + dim=1, + ) + + loss_target = hrrr_1 - reg_out + output_images = diffusion_model_forward( + net, + hrrr_0, + diffusion_channel_indices, + invariant_tensor, + sampler_args=dict(cfg.sampler.args), + ) + + valid_loss = loss_fn( + net=ddp, + images=loss_target[:, diffusion_channel_indices], + condition=torch.cat((hrrr_0, invariant_tensor), dim=1), + augment_pipe=augment_pipe, + ) + output_images += reg_out[:, diffusion_channel_indices, :, :] + del reg_out + + elif train_regression_unet: + assert ( + use_regression_net == False + ), "use_regression_net must be False when training regression unet" + assert ( + input_channel_indices == diffusion_channel_indices + ), "input_channel_indices must be equal to diffusion_channel_indices when training regression unet" + condition = torch.cat( + ( + hrrr_0[:, input_channel_indices, :, :], + era5[:], + invariant_tensor, + ), + dim=1, + ) + valid_loss, output_images = loss_fn( + net=ddp, + images=hrrr_1[:, diffusion_channel_indices, :, :], + condition=condition, + augment_pipe=augment_pipe, + return_model_outputs=True, + ) + + if log_to_wandb: + channelwise_valid_loss = valid_loss.mean(dim=[0, 2, 3]) + channelwise_valid_loss_dict = { + f"ChLoss_valid/{diffusion_channels[i]}": channelwise_valid_loss[ + i + ].item() + for i in range(target_channels) + } + wandb_logs[ + "channelwise_valid_loss" + ] = channelwise_valid_loss_dict + + hrrr_1 = hrrr_1[:, diffusion_channel_indices, :, :] + + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss, op=torch.distributed.ReduceOp.AVG + ) + val_loss = valid_loss.mean().cpu().item() + if log_to_wandb: + wandb_logs["valid_loss"] = val_loss + + # Save plots locally (and optionally to wandb) + if dist.rank == 0: + + for i in range(output_images.shape[0]): + image = output_images[i].cpu().numpy() + fields = ["u10m", "v10m", "t2m", "refc", "q1", "q5", "q10"] + + # Compute spectral metrics + figs, spec_ratios = ps1d_plots( + output_images[i], hrrr_1[i], fields, diffusion_channels + ) + + for f_ in fields: + f_index = diffusion_channels.index(f_) + image_dir = os.path.join(cfg.training.rundir, "images", f_) + generated = image[f_index] + truth = hrrr_1[i, f_index].cpu().numpy() + + fig, (a, b) = plt.subplots(1, 2) + im = a.imshow(generated) + a.set_title("generated, {}.png".format(f_)) + plt.colorbar(im, fraction=0.046, pad=0.04) + im = b.imshow(truth) + b.set_title("truth") + plt.colorbar(im, fraction=0.046, pad=0.04) + os.makedirs(image_dir, exist_ok=True) + plt.savefig( + os.path.join(image_dir, f"{total_steps}_{i}_{f_}.png") + ) + plt.close("all") + + specfig = "PS1D_" + f_ + figs[specfig].savefig( + os.path.join(image_dir, f"{total_steps}{i}{f_}_spec.png") + ) + plt.close(figs[specfig]) + if log_to_wandb: + # Save plots as wandb Images + for figname, plot in figs.items(): + wandb_logs[figname] = wandb.Image(plot) + wandb_logs.update({f"generated_{f_}": wandb.Image(fig)}) + + if log_to_wandb: + wandb_logs.update(spec_ratios) + wandb.log(wandb_logs, step=total_steps) + + valid_time = time.time() - valid_start + + # Print training stats + current_time = time.time() + if total_steps % cfg.training.print_progress_freq == 0: + fields = [] + fields += [f"steps {total_steps:<5d}"] + fields += [f"samples {total_steps*batch_size}"] + fields += [f"tot_time {current_time - start_time: .2f}"] + fields += [ + f"step_time {(current_time - train_start - valid_time) / train_steps : .2f}" + ] + fields += [f"valid_time {valid_time: .2f}"] + fields += [ + f"cpumem {psutil.Process(os.getpid()).memory_info().rss / 2**30:<6.2f}" + ] + fields += [ + f"gpumem {torch.cuda.max_memory_allocated(device) / 2**30:<6.2f}" + ] + fields += [f"train_loss {avg_train_loss/train_steps:<6.3f}"] + fields += [f"val_loss {val_loss:<6.3f}"] + logger0.info(" ".join(fields)) + + # Reset counters + train_steps = 0 + train_start = time.time() + avg_train_loss = 0 + torch.cuda.reset_peak_memory_stats() + + # Save full dump of the training state. + if ( + (done or total_steps % cfg.training.checkpoint_freq == 0) + and total_steps != 0 + and dist.rank == 0 + ): + + save_checkpoint( + path=os.path.join(cfg.training.rundir, "checkpoints"), + models=net, + optimizer=optimizer, + epoch=total_steps, + ) + + # Done. + torch.distributed.barrier() + logger0.info("\nExiting...") diff --git a/examples/generative/stormcast/utils/training_stats.py b/examples/generative/stormcast/utils/training_stats.py deleted file mode 100644 index 6fe70f6aff..0000000000 --- a/examples/generative/stormcast/utils/training_stats.py +++ /dev/null @@ -1,304 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Facilities for reporting and collecting training statistics across -multiple processes and devices. The interface is designed to minimize -synchronization overhead as well as the amount of boilerplate in user -code.""" - -import re -import numpy as np -import torch -from utils.misc import EasyDict - -from . import misc - -# ---------------------------------------------------------------------------- - -_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] -_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. -_counter_dtype = torch.float64 # Data type to use for the internal counters. -_rank = 0 # Rank of the current process. -_sync_device = ( - None # Device to use for multiprocess communication. None = single-process. -) -_sync_called = False # Has _sync() been called yet? -_counters = ( - dict() -) # Running counters on each device, updated by report(): name => device => torch.Tensor -_cumulative = ( - dict() -) # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor - -# ---------------------------------------------------------------------------- - - -def init_multiprocessing(rank, sync_device): - r"""Initializes `utils.training_stats` for collecting statistics - across multiple processes. - - This function must be called after - `torch.distributed.init_process_group()` and before `Collector.update()`. - The call is not necessary if multi-process collection is not needed. - - Args: - rank: Rank of the current process. - sync_device: PyTorch device to use for inter-process - communication, or None to disable multi-process - collection. Typically `torch.device('cuda', rank)`. - """ - global _rank, _sync_device - assert not _sync_called - _rank = rank - _sync_device = sync_device - - -# ---------------------------------------------------------------------------- - - -@misc.profiled_function -def report(name, value): - r"""Broadcasts the given set of scalars to all interested instances of - `Collector`, across device and process boundaries. - - This function is expected to be extremely cheap and can be safely - called from anywhere in the training loop, loss function, or inside a - `torch.nn.Module`. - - Warning: The current implementation expects the set of unique names to - be consistent across processes. Please make sure that `report()` is - called at least once for each unique name by each process, and in the - same order. If a given process has no scalars to broadcast, it can do - `report(name, [])` (empty list). - - Args: - name: Arbitrary string specifying the name of the statistic. - Averages are accumulated separately for each unique name. - value: Arbitrary set of scalars. Can be a list, tuple, - NumPy array, PyTorch tensor, or Python scalar. - - Returns: - The same `value` that was passed in. - """ - if name not in _counters: - _counters[name] = dict() - - elems = torch.as_tensor(value) - if elems.numel() == 0: - return value - - elems = elems.detach().flatten().to(_reduce_dtype) - moments = torch.stack( - [ - torch.ones_like(elems).sum(), - elems.sum(), - elems.square().sum(), - ] - ) - assert moments.ndim == 1 and moments.shape[0] == _num_moments - moments = moments.to(_counter_dtype) - - device = moments.device - if device not in _counters[name]: - _counters[name][device] = torch.zeros_like(moments) - _counters[name][device].add_(moments) - return value - - -# ---------------------------------------------------------------------------- - - -def report0(name, value): - r"""Broadcasts the given set of scalars by the first process (`rank = 0`), - but ignores any scalars provided by the other processes. - See `report()` for further details. - """ - report(name, value if _rank == 0 else []) - return value - - -# ---------------------------------------------------------------------------- - - -class Collector: - r"""Collects the scalars broadcasted by `report()` and `report0()` and - computes their long-term averages (mean and standard deviation) over - user-defined periods of time. - - The averages are first collected into internal counters that are not - directly visible to the user. They are then copied to the user-visible - state as a result of calling `update()` and can then be queried using - `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the - internal counters for the next round, so that the user-visible state - effectively reflects averages collected between the last two calls to - `update()`. - - Args: - regex: Regular expression defining which statistics to - collect. The default is to collect everything. - keep_previous: Whether to retain the previous averages if no - scalars were collected on a given round - (default: True). - """ - - def __init__(self, regex=".*", keep_previous=True): - self._regex = re.compile(regex) - self._keep_previous = keep_previous - self._cumulative = dict() - self._moments = dict() - self.update() - self._moments.clear() - - def names(self): - r"""Returns the names of all statistics broadcasted so far that - match the regular expression specified at construction time. - """ - return [name for name in _counters if self._regex.fullmatch(name)] - - def update(self): - r"""Copies current values of the internal counters to the - user-visible state and resets them for the next round. - - If `keep_previous=True` was specified at construction time, the - operation is skipped for statistics that have received no scalars - since the last update, retaining their previous averages. - - This method performs a number of GPU-to-CPU transfers and one - `torch.distributed.all_reduce()`. It is intended to be called - periodically in the main training loop, typically once every - N training steps. - """ - if not self._keep_previous: - self._moments.clear() - for name, cumulative in _sync(self.names()): - if name not in self._cumulative: - self._cumulative[name] = torch.zeros( - [_num_moments], dtype=_counter_dtype - ) - delta = cumulative - self._cumulative[name] - self._cumulative[name].copy_(cumulative) - if float(delta[0]) != 0: - self._moments[name] = delta - - def _get_delta(self, name): - r"""Returns the raw moments that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - assert self._regex.fullmatch(name) - if name not in self._moments: - self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - return self._moments[name] - - def num(self, name): - r"""Returns the number of scalars that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - delta = self._get_delta(name) - return int(delta[0]) - - def mean(self, name): - r"""Returns the mean of the scalars that were accumulated for the - given statistic between the last two calls to `update()`, or NaN if - no scalars were collected. - """ - delta = self._get_delta(name) - if int(delta[0]) == 0: - return float("nan") - return float(delta[1] / delta[0]) - - def std(self, name): - r"""Returns the standard deviation of the scalars that were - accumulated for the given statistic between the last two calls to - `update()`, or NaN if no scalars were collected. - """ - delta = self._get_delta(name) - if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): - return float("nan") - if int(delta[0]) == 1: - return float(0) - mean = float(delta[1] / delta[0]) - raw_var = float(delta[2] / delta[0]) - return np.sqrt(max(raw_var - np.square(mean), 0)) - - def as_dict(self): - r"""Returns the averages accumulated between the last two calls to - `update()` as an `EasyDict`. The contents are as follows: - - EasyDict( - NAME = EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), - ... - ) - """ - stats = EasyDict() - for name in self.names(): - stats[name] = EasyDict( - num=self.num(name), mean=self.mean(name), std=self.std(name) - ) - return stats - - def __getitem__(self, name): - r"""Convenience getter. - `collector[name]` is a synonym for `collector.mean(name)`. - """ - return self.mean(name) - - -# ---------------------------------------------------------------------------- - - -def _sync(names): - r"""Synchronize the global cumulative counters across devices and - processes. Called internally by `Collector.update()`. - """ - if len(names) == 0: - return [] - global _sync_called - _sync_called = True - - # Collect deltas within current rank. - deltas = [] - device = _sync_device if _sync_device is not None else torch.device("cpu") - for name in names: - delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) - for counter in _counters[name].values(): - delta.add_(counter.to(device)) - counter.copy_(torch.zeros_like(counter)) - deltas.append(delta) - deltas = torch.stack(deltas) - - # Sum deltas across ranks. - if _sync_device is not None: - torch.distributed.all_reduce(deltas) - - # Update cumulative values. - deltas = deltas.cpu() - for idx, name in enumerate(names): - if name not in _cumulative: - _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - _cumulative[name].add_(deltas[idx]) - - # Return name-value pairs. - return [(name, _cumulative[name]) for name in names] - - -# ---------------------------------------------------------------------------- -# Convenience. - -default_collector = Collector() - -# ---------------------------------------------------------------------------- diff --git a/modulus/metrics/diffusion/loss.py b/modulus/metrics/diffusion/loss.py index 12166eb500..9a60f04265 100644 --- a/modulus/metrics/diffusion/loss.py +++ b/modulus/metrics/diffusion/loss.py @@ -218,7 +218,7 @@ def __init__( self.P_std = P_std self.sigma_data = sigma_data - def __call__(self, net, images, labels=None, augment_pipe=None): + def __call__(self, net, images, condition=None, labels=None, augment_pipe=None): """ Calculate and return the loss corresponding to the EDM formulation. @@ -256,7 +256,16 @@ def __call__(self, net, images, labels=None, augment_pipe=None): augment_pipe(images) if augment_pipe is not None else (images, None) ) n = torch.randn_like(y) * sigma - D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + if condition is not None: + D_yn = net( + y + n, + sigma, + condition=condition, + class_labels=labels, + augment_labels=augment_labels, + ) + else: + D_yn = net(y + n, sigma, labels=labels, augment_labels=augment_labels) loss = weight * ((D_yn - y) ** 2) return loss diff --git a/modulus/metrics/general/power_spectrum.py b/modulus/metrics/general/power_spectrum.py new file mode 100644 index 0000000000..bf56ef2aa7 --- /dev/null +++ b/modulus/metrics/general/power_spectrum.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch + + +def _batch_weighted_histogram( + data_tensor: torch.Tensor, num_classes: int = -1, weights: torch.Tensor = None +) -> torch.Tensor: + """ + Computes (optionally weighted) histogram of values in a Tensor, preserving the leading dimensions + + Args: + data_tensor: torch.Tensor + a D1 x ... x D_n torch.LongTensor + num_classes: int, optional + The number of classes/bins present in data. + If not provided (set to -1), tensor.max() + 1 is used + weights: torch.Tensor, optional + If provided, use values in weights to produce a weighted histogram + + Returns: + hist: torch.Tensor + A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor, + containing (weighted) histograms of the last dimension D_n of tensor, + """ + maxd = data_tensor.max() + nc = (maxd + 1) if num_classes <= 0 else num_classes + hist = torch.zeros( + (*data_tensor.shape[:-1], nc), + dtype=data_tensor.dtype, + device=data_tensor.device, + ) + if weights is not None: + wts = weights + else: + wts = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand( + data_tensor.shape + ) + hist.scatter_add_(-1, ((data_tensor * nc) // (maxd + 1)).long(), wts) + return hist + + +def power_spectrum(x: torch.Tensor) -> Tuple[torch.Tensor]: + """Compute the wavenumber-averaged power spectrum of an input tensor x, + preserving the leading D - 2 dimensions for an input with D dimensions. + This routine will compute the 2D power from FFT coefficients, then perform + azimuthal averaging to get the 1D power spectrum as a function of total + wavenumber. + + Args: + x: torch.Tensor + Input tensor with at least three dimensions; the final two dims are + assumed to be the height and width of a regular 2D spatial domain + Shape: D1 x D2 x ... x h x w + + Returns: + k: torch.Tensor + Centers of the total wavenumber bins after azimuthal averaging + Number of bins is min(h//2, w//2) - 1, linearly spaced + power: torch.Tensor + Azimuthally averaged 1D power spectrum + Shape: D1 x ... x D_n-2 x min(h//2, w//2) - 1 + """ + + leading, (h, w) = x.shape[:-2], x.shape[-2:] + x = x.reshape(-1, h, w) + batch = x.shape[0] + + # 2D power + pwr = torch.fft.fftn(x, dim=(-2, -1), norm="ortho").abs() ** 2 + pwr = torch.fft.fftshift(pwr, dim=(-2, -1)).to(torch.float32) + + # Azimuthal average + xx, yy = torch.meshgrid( + torch.arange(h, device=pwr.device), + torch.arange(w, device=pwr.device), + indexing="ij", + ) + k = torch.hypot(xx - h // 2, yy - w / 2).to(torch.float32) + + sort = torch.argsort(k.flatten()) + k_sort = k.flatten()[sort] + pwr_sort = pwr.reshape(batch, -1)[:, sort] + + nbins = min(h // 2, w // 2) + k_bins = torch.linspace(0, k_sort.max() + 1, nbins) + k_bin_centers = 0.5 * (k_bins[1:] + k_bins[:-1]) + k_sort_stack = torch.tile(k_sort, dims=(batch, 1)) + + pwr_binned = _batch_weighted_histogram( + k_sort_stack, weights=pwr_sort, num_classes=nbins - 1 + ) + count_binned = _batch_weighted_histogram(k_sort_stack, num_classes=nbins - 1) + + power = pwr_binned / count_binned + k = k_bin_centers + + power = power.reshape(*leading, nbins - 1) + + return k, power diff --git a/modulus/models/diffusion/__init__.py b/modulus/models/diffusion/__init__.py index d3bab58dcc..3984bffd42 100644 --- a/modulus/models/diffusion/__init__.py +++ b/modulus/models/diffusion/__init__.py @@ -26,7 +26,7 @@ ) from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd from .dhariwal_unet import DhariwalUNet -from .unet import UNet +from .unet import UNet, StormCastUNet from .preconditioning import ( EDMPrecond, EDMPrecondSR, diff --git a/modulus/models/diffusion/preconditioning.py b/modulus/models/diffusion/preconditioning.py index 2106fc1efd..13822cd8a9 100644 --- a/modulus/models/diffusion/preconditioning.py +++ b/modulus/models/diffusion/preconditioning.py @@ -551,7 +551,10 @@ class EDMPrecond(Module): img_resolution : int Image resolution. img_channels : int - Number of color channels. + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args label_dim : int Number of class labels, 0 = unconditional, by default 0. use_fp16 : bool @@ -564,6 +567,13 @@ class EDMPrecond(Module): Expected standard deviation of the training data, by default 0.5. model_type :str Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output **model_kwargs : dict Keyword arguments for the underlying model. @@ -584,11 +594,20 @@ def __init__( sigma_max=float("inf"), sigma_data=0.5, model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, **model_kwargs, ): super().__init__(meta=EDMPrecondMetaData) self.img_resolution = img_resolution - self.img_channels = img_channels + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 @@ -599,13 +618,21 @@ def __init__( model_class = getattr(network_module, model_type) self.model = model_class( img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels, + in_channels=img_in_channels, + out_channels=img_out_channels, label_dim=label_dim, **model_kwargs, ) # TODO needs better handling - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): x = x.to(torch.float32) sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) class_labels = ( @@ -626,8 +653,13 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() c_noise = sigma.log() / 4 + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + F_x = self.model( - (c_in * x).to(dtype), + arg.to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs, diff --git a/modulus/models/diffusion/song_unet.py b/modulus/models/diffusion/song_unet.py index e99e4da3a7..41f048b93c 100644 --- a/modulus/models/diffusion/song_unet.py +++ b/modulus/models/diffusion/song_unet.py @@ -61,7 +61,7 @@ class MetaData(ModelMetaData): class SongUNet(Module): """ Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with - optional self-attention,embeddings, and encoder-decoder components. + optional self-attention, embeddings, and encoder-decoder components. This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder @@ -95,7 +95,7 @@ class SongUNet(Module): label_dropout : float, optional Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional - Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. @@ -109,6 +109,8 @@ class SongUNet(Module): Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. checkpoint_level : int, optional (default=0) How many layers should use gradient checkpointing, 0 is None + additive_pos_embed: bool = False, + Set to True to add a learned position embedding after the first conv (used in StormCast) Reference @@ -153,6 +155,7 @@ def __init__( decoder_type: str = "standard", resample_filter: List[int] = [1, 1], checkpoint_level: int = 0, + additive_pos_embed: bool = False, ): valid_embedding_types = ["fourier", "positional", "zero"] if embedding_type not in valid_embedding_types: @@ -206,6 +209,14 @@ def __init__( # set the threshold for checkpointing based on image resolution self.checkpoint_threshold = (self.img_shape_y >> checkpoint_level) + 1 + # Optional additive learned positition embed after the first conv + self.additive_pos_embed = additive_pos_embed + if self.additive_pos_embed: + self.spatial_emb = torch.nn.Parameter( + torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x) + ) + torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) + # Mapping. if self.embedding_type != "zero": self.map_noise = ( @@ -358,6 +369,11 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): x = skips[-1] = x + block(aux) elif "aux_residual" in name: x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) else: # For UNetBlocks check if we should use gradient checkpointing if isinstance(block, UNetBlock): diff --git a/modulus/models/diffusion/unet.py b/modulus/models/diffusion/unet.py index ef15afb793..d1a2aeca6b 100644 --- a/modulus/models/diffusion/unet.py +++ b/modulus/models/diffusion/unet.py @@ -164,3 +164,104 @@ def round_sigma(self, sigma): The tensor representation of the provided sigma value(s). """ return torch.as_tensor(sigma) + + +class StormCastUNet(Module): + """ + U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. + + Parameters + ----------- + img_resolution : int or List[int] + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + """ + + def __init__( + self, + img_resolution, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNet", + **model_kwargs, + ): + super().__init__(meta=MetaData("StormCastUNet")) + + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, force_fp32=False, **model_kwargs): + """Run a forward pass of the StormCast regression U-Net. + + Args: + x (torch.Tensor): input to the U-Net + force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False. + + Raises: + ValueError: If input data type is a mismatch with provided options + + Returns: + D_x (torch.Tensor): Output (prediction) of the U-Net + """ + + x = x.to(torch.float32) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), + torch.zeros(x.shape[0], dtype=x.dtype, device=x.device), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = F_x.to(torch.float32) + return D_x diff --git a/modulus/utils/generative/deterministic_sampler.py b/modulus/utils/generative/deterministic_sampler.py index de7c3f0f17..a118900ea9 100644 --- a/modulus/utils/generative/deterministic_sampler.py +++ b/modulus/utils/generative/deterministic_sampler.py @@ -19,6 +19,8 @@ import nvtx import torch +from modulus.models.diffusion import EDMPrecond + # ruff: noqa: E731 @@ -185,9 +187,18 @@ def deterministic_sampler( # Euler step. h = t_next - t_hat - denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( - torch.float64 - ) + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_hat / s(t_hat), + sigma(t_hat), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( + torch.float64 + ) d_cur = ( sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised @@ -198,9 +209,18 @@ def deterministic_sampler( if solver == "euler" or i == num_steps - 1: x_next = x_hat + h * d_cur else: - denoised = net(x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels).to( - torch.float64 - ) + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_prime / s(t_prime), + sigma(t_prime), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net( + x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels + ).to(torch.float64) d_prime = ( sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised diff --git a/test/metrics/diffusion/test_losses.py b/test/metrics/diffusion/test_losses.py index dfdf01dc30..3372c6f4ca 100644 --- a/test/metrics/diffusion/test_losses.py +++ b/test/metrics/diffusion/test_losses.py @@ -57,6 +57,10 @@ def fake_net(y, sigma, labels, augment_labels=None): return torch.tensor([1.0]) +def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=None): + return torch.tensor([1.0]) + + def test_call_method_vp(): loss_func = VPLoss() @@ -131,10 +135,15 @@ def test_call_method_edm(): img = torch.tensor([[[[1.0]]]]) labels = None - # Without augmentation + # Without augmentation or conditioning loss_value = loss_func(fake_net, img, labels) assert isinstance(loss_value, torch.Tensor) + # With conditioning + condition = torch.tensor([[[[0.0]]]]) + loss_value = loss_func(fake_condition_net, img, condition=condition, labels=labels) + assert isinstance(loss_value, torch.Tensor) + # With augmentation def mock_augment_pipe(imgs): return imgs, None diff --git a/test/metrics/test_metrics_general.py b/test/metrics/test_metrics_general.py index d36fa527dc..b3c573a54a 100644 --- a/test/metrics/test_metrics_general.py +++ b/test/metrics/test_metrics_general.py @@ -25,6 +25,7 @@ import modulus.metrics.general.ensemble_metrics as em import modulus.metrics.general.entropy as ent import modulus.metrics.general.histogram as hist +import modulus.metrics.general.power_spectrum as ps import modulus.metrics.general.wasserstein as w from modulus.distributed.manager import DistributedManager @@ -842,3 +843,27 @@ def test_entropy(device, rtol: float = 1e-2, atol: float = 1e-2): torch.zeros((1,) + x_bin_counts.shape[1:], device=device), bin_edges, ) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_power_spectrum(device): + """Test the 2D power spectrum routine for correctness using a sine wave""" + h, w = 64, 64 + kx, ky = 4, 4 + amplitude = 1.0 + + # Create input sine wave + x = torch.arange(w).view(1, -1).repeat(h, 1).float() + y = torch.arange(h).view(-1, 1).repeat(1, w).float() + signal = amplitude * torch.sin(2 * np.pi * kx * x / w + 2 * np.pi * ky * y / h) + + # Compute the power spectrum (added batch/channel dims) + k, power = ps.power_spectrum(signal.unsqueeze(0).unsqueeze(0)) + + # Assert that the power at expected wavenumber is dominant + k_total = np.sqrt(kx**2 + ky**2) + k_index = (torch.abs(k - k_total)).argmin() + assert power[0, 0, k_index] > 0.9 * power[0, 0].max() # Dominant peak + assert (power[0, 0] < 1e-6).sum() > ( + power[0, 0].numel() * 0.9 + ) # Most bins are zero diff --git a/test/models/diffusion/test_preconditioning.py b/test/models/diffusion/test_preconditioning.py index 282e7bce03..30956d26dd 100644 --- a/test/models/diffusion/test_preconditioning.py +++ b/test/models/diffusion/test_preconditioning.py @@ -19,6 +19,7 @@ from modulus.launch.utils import load_checkpoint, save_checkpoint from modulus.models.diffusion.preconditioning import ( + EDMPrecond, EDMPrecondSR, VEPrecond_dfsr, VEPrecond_dfsr_cond, @@ -68,6 +69,43 @@ def test_EDMPrecondSR_serialization(tmp_path): assert epoch == 1 +@pytest.mark.parametrize("channels", [[0, 4], [3, 8], [3, 5]]) +def test_EDMPrecond_forward(channels): + res = [32, 64] + cond_ch, out_ch = channels + b = 1 + + # Create an instance of the preconditioner + model = EDMPrecond( + img_resolution=res, + img_channels=99, # dummy value, should be overwritten by following args + img_in_channels=out_ch + cond_ch, + img_out_channels=out_ch, + model_type="SongUNet", + ) + + latents = torch.randn(b, out_ch, *res) + sigma = torch.tensor([10.0]) + + if cond_ch > 0: + # Forward pass with conditioning + condition = torch.randn(b, cond_ch, *res) + output = model( + x=latents, + condition=condition, + sigma=sigma, + ) + else: + # Forward pass without conditioning + output = model( + x=latents, + sigma=sigma, + ) + + # Assert the output shape is correct + assert output.shape == (b, out_ch, *res) + + def test_VEPrecond_dfsr(): b, c, x, y = 1, 3, 256, 256 diff --git a/test/models/diffusion/test_song_unet.py b/test/models/diffusion/test_song_unet.py index 2cb5842214..dd3c6f96c7 100644 --- a/test/models/diffusion/test_song_unet.py +++ b/test/models/diffusion/test_song_unet.py @@ -83,6 +83,26 @@ def test_song_unet_constructor(device): output_image = model(input_image, noise_labels, class_labels) assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + # DDM++ with additive pos embed + model_channels = 64 + model = UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + model_channels=model_channels, + additive_pos_embed=True, + ).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + output_image = model(input_image, noise_labels, class_labels) + assert model.spatial_emb.shape == ( + 1, + model_channels, + img_resolution, + img_resolution, + ) + # NCSN++ model = UNet( img_resolution=img_resolution, diff --git a/test/models/diffusion/test_unet_wrappers.py b/test/models/diffusion/test_unet_wrappers.py new file mode 100644 index 0000000000..5ae9c6af05 --- /dev/null +++ b/test/models/diffusion/test_unet_wrappers.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from modulus.models.diffusion import StormCastUNet, UNet + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_unet_forwards(device): + """Test forward passes of UNet wrappers""" + + # Construct the UNet model + res, inc, outc = 64, 2, 3 + model = UNet( + img_resolution=res, + img_channels=inc, + img_in_channels=inc, + img_out_channels=outc, + model_type="SongUNet", + ).to(device) + input_image = torch.ones([1, inc, res, res]).to(device) + lr_image = torch.randn([1, outc, res, res]).to(device) + sigma = torch.randn([1]).to(device) + output = model(x=input_image, img_lr=lr_image, sigma=sigma) + assert output.shape == (1, outc, res, res) + + # Construct the StormCastUNet model + model = StormCastUNet( + img_resolution=res, img_in_channels=inc, img_out_channels=outc + ).to(device) + input_image = torch.ones([1, inc, res, res]).to(device) + output = model(x=input_image) + assert output.shape == (1, outc, res, res) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_unet_optims(device): + """Test optimizations of U-Net wrappers""" + + res, inc, outc = 64, 2, 3 + + def setup_model(): + + model = UNet( + img_resolution=res, + img_channels=inc, + img_in_channels=inc, + img_out_channels=outc, + model_type="SongUNet", + ).to(device) + input_image = torch.ones([1, inc, res, res]).to(device) + lr_image = torch.randn([1, outc, res, res]).to(device) + sigma = torch.randn([1]).to(device) + + return model, [input_image, lr_image, sigma] + + # Check AMP + model, invar = setup_model() + assert common.validate_amp(model, (*invar,)) + + def setup_model(): + model = StormCastUNet( + img_resolution=res, img_in_channels=inc, img_out_channels=outc + ).to(device) + input_image = torch.ones([1, inc, res, res]).to(device) + + return model, [input_image] + + # Check AMP + model, invar = setup_model() + assert common.validate_amp(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_unet_checkpoint(device): + """Test UNet wrapper checkpoint save/load""" + # Construct UNet models + res, inc, outc = 64, 2, 3 + model_1 = UNet( + img_resolution=res, + img_channels=inc, + img_in_channels=inc, + img_out_channels=outc, + model_type="SongUNet", + ).to(device) + model_2 = UNet( + img_resolution=res, + img_channels=inc, + img_in_channels=inc, + img_out_channels=outc, + model_type="SongUNet", + ).to(device) + + input_image = torch.ones([1, inc, res, res]).to(device) + lr_image = torch.randn([1, outc, res, res]).to(device) + sigma = torch.randn([1]).to(device) + assert common.validate_checkpoint( + model_1, model_2, (*[input_image, lr_image, sigma],) + ) + + # Construct StormCastUNet models + res, inc, outc = 64, 2, 3 + model_1 = StormCastUNet( + img_resolution=res, img_in_channels=inc, img_out_channels=outc + ).to(device) + model_2 = StormCastUNet( + img_resolution=res, img_in_channels=inc, img_out_channels=outc + ).to(device) + + input_image = torch.ones([1, inc, res, res]).to(device) + assert common.validate_checkpoint(model_1, model_2, (input_image,))