diff --git a/docs/img/sfm_model.png b/docs/img/sfm_model.png
new file mode 100755
index 0000000000..81b22634e5
Binary files /dev/null and b/docs/img/sfm_model.png differ
diff --git a/examples/generative/corrdiff_plus_plus/README.md b/examples/generative/corrdiff_plus_plus/README.md
new file mode 100644
index 0000000000..b6b18db18a
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/README.md
@@ -0,0 +1,149 @@
+
+# CorrDiff++: Adaptive Flow Matching for Resolving Small-Scale Physics
+
+## Problem overview
+
+Conditional diffusion and flow-based models have shown promise for super-resolving fine-scale structure in natural images. However, applying them to weather and physical domains introduces unique challenges:
+(i) spatial misalignment between input and target fields due to differing PDE resolutions,
+(ii) mismatched and heterogeneous input-output channels (i.e., channel synthesis), and
+(iii) channel-dependent stochasticity.
+
+**CorrDiff** was proposed to address these issues but suffers from poor generalization—its regression-based residuals can overfit training data, leading to degraded performance on out-of-distribution inputs.
+
+To overcome these limitations, **CorrDiff++** was introduced at NVIDIA. It relies on adaptive flow matching (AFM), which improves upon CorrDiff in generalization, calibration, and efficiency through several key innovations:
+
+- A joint encoder–generator architecture trained via flow matching
+- A latent base distribution that reconstructs the large-scale, deterministic signal
+- An adaptive noise scaling mechanism informed by the encoder’s RMSE, used to inject calibrated uncertainty
+- A final flow matching step to refine latent samples and synthesize fine-scale physical details
+
+AFM outperforms previous methods across both real-world (e.g., 25 → 2 km super-resolution in Taiwan) and synthetic (Kolmogorov flow) benchmarks—especially for highly stochastic output channels.
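+
+As a rough, illustrative sketch of the adaptive noise scaling idea listed above (this is not the CorrDiff++ implementation; all names and shapes are placeholders): the encoder's per-channel RMSE against the target sets the scale of the Gaussian noise injected around the encoder output, which forms the latent base sample that the flow matching step then refines.
+
+```python
+import torch
+
+
+def adaptive_noise_scale(encoder_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    # Per-channel RMSE of the encoder prediction, used as the noise scale sigma.
+    # Tensors have shape (batch, channels, height, width).
+    err = encoder_pred - target
+    return err.pow(2).mean(dim=(0, 2, 3)).sqrt()  # shape: (channels,)
+
+
+def perturb_latent(encoder_pred: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+    # Inject calibrated Gaussian noise around the encoder output; the result is the
+    # latent base sample that the flow matching step refines into fine-scale detail.
+    noise = torch.randn_like(encoder_pred)
+    return encoder_pred + sigma.view(1, -1, 1, 1) * noise
+
+
+# Toy usage with random tensors standing in for the encoder output and the ground truth
+pred = torch.randn(2, 4, 64, 64)
+truth = pred + 0.1 * torch.randn_like(pred)
+sigma = adaptive_noise_scale(pred, truth)
+latent = perturb_latent(pred, sigma)
+print(sigma.shape, latent.shape)
+```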
+
+📄 For details, see the [SFM paper (arXiv:2410.19814)](https://arxiv.org/abs/2410.19814).
+
+## Getting started
+
+To build custom CorrDiff++ versions, a good starting point is the "Mini" version of CorrDiff++, which uses smaller training samples and a smaller network to reduce training costs from thousands of GPU hours to around 10 hours on A100 GPUs while still producing reasonable results. It also includes a simple data loader that can be used as a baseline for training CorrDiff++ on custom datasets.
+
+### Preliminaries
+Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/corrdiff_plus_plus`) to a system with a GPU available. Also download the CorrDiff++ dataset from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini).
+
+### Configuration basics
+
+Similar to its earlier version, CorrDiff++ training is handled by `train.py` and controlled by YAML configuration files managed by [Hydra](https://hydra.cc/docs/intro/). Prebuilt configuration files are found in the `conf` directory. You can choose the configuration file using the `--config-name` option. The main configuration file specifies the training dataset, the model configuration, and the training options; the details are given in the corresponding configuration files. To change a configuration option, you can either edit the configuration files or use Hydra command-line overrides. For example, the training batch size is controlled by the option `training.hp.total_batch_size`. We can override this from the command line with the `++` syntax: `python train.py ++training.hp.total_batch_size=64` would run the training with the batch size set to 64.
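+
+If you prefer to inspect the composed configuration from Python, Hydra's compose API can be used. The following is a small sketch (assuming it is run from this example directory, next to the `conf` folder):
+
+```python
+from hydra import compose, initialize
+from omegaconf import OmegaConf
+
+# Compose the training config and apply the same batch size override as above.
+with initialize(version_base=None, config_path="conf"):
+    cfg = compose(
+        config_name="config_training",
+        overrides=["++training.hp.total_batch_size=64"],
+    )
+
+# Print the resolved training hyperparameters.
+print(OmegaConf.to_yaml(cfg.training.hp))
+```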
+
+### Joint training of encoder and generator via flow matching
+
+To train CorrDiff++, use the main configuration file [config_training.yaml](conf/config_training.yaml), which composes the dataset, model, training, and validation settings from the corresponding files in the `conf` subdirectories.
+To start the training, run:
+```bash
+python train.py --config-name=config_training.yaml ++dataset.data_path=<path_to_dataset>/hrrr_mini_train.nc ++dataset.stats_path=<path_to_dataset>/stats.json
+```
+
+The training will require a few hours on a single A100 GPU. If training is interrupted, it will automatically continue from the latest checkpoint when restarted. Multi-GPU and multi-node training are supported and will launch automatically when the training is run in a `torchrun` or MPI environment.
+
+The results, including logs and checkpoints, are saved by default to `outputs/<job name>/` (`outputs/sfm/` with the default training configuration). You can direct the checkpoints to be saved elsewhere by setting `++training.io.checkpoint_dir=<checkpoint_dir>`.
+
+> **_Out of memory?_** CorrDiff++ trains by default with a total batch size of 256 (set by `training.hp.total_batch_size`). If you're using a single GPU, especially one with a smaller amount of memory, you might see an out-of-memory error. If that happens, set a smaller batch size per GPU, e.g.: `++training.hp.batch_size_per_gpu=16`. CorrDiff++ training will then automatically use gradient accumulation to train with an effective batch size of `training.hp.total_batch_size`.
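+
+As a rough sketch of what gradient accumulation does (illustrative only, not the loop used by `train.py`): gradients from several small micro-batches are summed before a single optimizer step, so the update corresponds to the full effective batch size.
+
+```python
+import torch
+
+total_batch_size = 256      # training.hp.total_batch_size
+batch_size_per_gpu = 16     # training.hp.batch_size_per_gpu
+accumulation_steps = total_batch_size // batch_size_per_gpu
+
+model = torch.nn.Linear(8, 8)
+optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
+
+optimizer.zero_grad()
+for _ in range(accumulation_steps):
+    x = torch.randn(batch_size_per_gpu, 8)
+    # Scale each micro-batch loss so the accumulated gradient matches the
+    # gradient of the full effective batch.
+    loss = model(x).pow(2).mean() / accumulation_steps
+    loss.backward()  # gradients accumulate in .grad across micro-batches
+optimizer.step()     # one update with the effective (total) batch size
+```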
+
+
+### Generation
+
+Use the `generate.py` script to generate samples with the trained networks:
+```bash
+python generate.py --config-name="config_generate.yaml" ++generation.io.encoder_ckpt_filename=<encoder_checkpoint> ++generation.io.denoiser_ckpt_filename=<denoiser_checkpoint> ++generation.io.output_filename=<output_file>
+```
+where `<encoder_checkpoint>` and `<denoiser_checkpoint>` should point to the encoder and diffusion model checkpoints, respectively, and `<output_file>` indicates the output NetCDF4 file.
+
+You can open the output file with, e.g., the Python `netCDF4` library. The inputs are saved in the `input` group of the file, the ground truth data in the `truth` group, and the CorrDiff++ prediction in the `prediction` group.
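+
+For example, a short snippet using the `netCDF4` package (the group layout follows the description above; the file name and the variable names depend on your run configuration):
+
+```python
+import netCDF4
+
+# Path set via generation.io.output_filename (illustrative name)
+with netCDF4.Dataset("corrdiff_output.nc", "r") as ds:
+    print("input variables:     ", list(ds["input"].variables))
+    print("truth variables:     ", list(ds["truth"].variables))
+    print("prediction variables:", list(ds["prediction"].variables))
+
+    # Read all data for one predicted variable (dimensions depend on generate.py)
+    var_name = list(ds["prediction"].variables)[0]
+    data = ds["prediction"][var_name][:]
+    print(var_name, data.shape)
+```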
+
+## Configs
+
+The `conf` directory contains the configuration files for the model, data,
+training, etc. The configs are given in YAML format and use the `omegaconf`
+library to manage them. Several example configs are provided for training the
+different model variants: encoder only (`sfm_encoder`), flow matching with a
+pretrained encoder (`sfm_two_stage`), and joint training of the encoder and the
+flow matching model (`sfm`).
+The default configs are set to train the encoder jointly with the flow matching model.
+To train the other models, please adjust `conf/config_training.yaml`
+according to the comments. Alternatively, you can create a new config file
+and specify it using the `--config-name` option.
+
+
+## Dataset & Datapipe
+
+In this example, CorrDiff++ training is demonstrated on the Taiwan dataset,
+conditioned on the [ERA5 dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5).
+We have made this dataset available for non-commercial use under the
+[CC BY-NC-ND 4.0 license](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en).
+It can be downloaded from [https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa)
+with `ngc registry resource download-version "nvidia/modulus/modulus_datasets_cwa:v1"`.
+The datapipe in this example is tailored specifically for the Taiwan dataset.
+A lightweight datapipe for the HRRR dataset is also available and can be used
+with the CorrDiff++ model.
+For other datasets, you will need to create a custom datapipe.
+You can use the lightweight HRRR datapipe as a starting point for developing your new one.
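+
+As an illustration of the interface a custom datapipe needs to implement, here is a minimal in-memory sketch based on the `DownscalingDataset` base class in `datasets/base.py` (all names and shapes are placeholders; a real datapipe would read data from disk and implement normalization):
+
+```python
+from typing import List, Tuple
+
+import numpy as np
+
+from datasets.base import ChannelMetadata, DownscalingDataset
+
+
+class RandomDataset(DownscalingDataset):
+    """Toy datapipe returning random (target, input, label) samples."""
+
+    def __init__(self, n_samples: int = 8, shape: Tuple[int, int] = (64, 64)):
+        self.n_samples = n_samples
+        self.shape = shape
+
+    def __getitem__(self, idx):
+        target = np.random.randn(4, *self.shape).astype(np.float32)
+        input = np.random.randn(12, *self.shape).astype(np.float32)
+        return target, input, idx
+
+    def __len__(self):
+        return self.n_samples
+
+    def longitude(self) -> np.ndarray:
+        return np.zeros(self.shape, dtype=np.float32)
+
+    def latitude(self) -> np.ndarray:
+        return np.zeros(self.shape, dtype=np.float32)
+
+    def input_channels(self) -> List[ChannelMetadata]:
+        return [ChannelMetadata(name=f"in_{i}") for i in range(12)]
+
+    def output_channels(self) -> List[ChannelMetadata]:
+        return [ChannelMetadata(name=f"out_{i}") for i in range(4)]
+
+    def time(self) -> List:
+        return list(range(self.n_samples))
+
+    def image_shape(self) -> Tuple[int, int]:
+        return self.shape
+```
+
+To use such a datapipe for training, it also needs to be registered in the `known_datasets` mapping in `datasets/dataset.py`.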
+
+
+
+## Sampling and Model Evaluation
+
+Model evaluation is split into two components. `generate.py` creates a netCDF file
+for the generated outputs, and `score_samples.py` computes deterministic and probabilistic
+scores.
+
+To generate samples and save output in a netCDF file, run:
+
+```bash
+python generate.py
+```
+This will use the base configs specified in the `conf/config_generate.yaml` file.
+
+Next, to score the generated samples, run:
+
+```bash
+python score_samples.py path=<samples_netcdf_file> output=<scores_output_file>
+```
+
+Some legacy plotting scripts are also available in the `inference` directory.
+You can also bring your checkpoints to [earth2studio](https://github.com/NVIDIA/earth2studio)
+for further analysis and visualizations.
+
+## Logging
+
+We use TensorBoard for logging training and validation losses, as well as
+the learning rate during training. To visualize TensorBoard running in a
+Docker container on a remote server from your local desktop, follow these steps:
+
+1. **Expose the Port in Docker:**
+ Expose port 6006 in the Docker container by including
+ `-p 6006:6006` in your docker run command.
+
+2. **Launch TensorBoard:**
+ Start TensorBoard within the Docker container:
+ ```bash
+ tensorboard --logdir=/path/to/logdir --port=6006
+ ```
+
+3. **Set Up SSH Tunneling:**
+ Create an SSH tunnel to forward port 6006 from the remote server to your local machine:
+ ```bash
+   ssh -L 6006:localhost:6006 <user>@<remote-server-ip>
+ ```
+   Replace `<user>` with your SSH username and `<remote-server-ip>` with the IP address
+ of your remote server. You can use a different port if necessary.
+
+4. **Access TensorBoard:**
+ Open your web browser and navigate to `http://localhost:6006` to view TensorBoard.
+
+**Note:** Ensure the remote server’s firewall allows connections on port `6006`
+and that your local machine’s firewall allows outgoing connections.
+
+
+## References
+
+- [Adaptive Flow Matching for Resolving Small-Scale Physics](https://openreview.net/forum?id=YJ1My9ttEN)
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml
new file mode 100644
index 0000000000..30a0f38fe1
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: generation
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_generate
+
+ # Generation
+ - generation/sfm
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml
new file mode 100644
index 0000000000..05e69f12c4
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: generation
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_generate
+
+ # Generation
+ - generation/sfm_encoder
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml
new file mode 100644
index 0000000000..5b1661ff56
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: generation
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_generate
+
+ # Generation
+ - generation/sfm_two_stage
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training.yaml
new file mode 100644
index 0000000000..3fa07a5e08
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_training.yaml
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"]
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_train
+
+ # Model
+ #- model/sfm_encoder
+ - model/sfm
+ #- model/sfm_two_stage
+
+ # Training
+ #- training/sfm_encoder
+ - training/sfm
+ #- training/sfm_two_stage
+
+ # Validation (comment out to disable validation)
+ - validation/cwb
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml
new file mode 100644
index 0000000000..73dd814b39
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"]
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_train
+
+ # Model
+ - model/sfm_encoder
+ #- model/sfm
+ #- model/sfm_two_stage
+
+ # Training
+ - training/sfm_encoder
+ #- training/sfm
+ #- training/sfm_two_stage
+
+ # Validation (comment out to disable validation)
+ - validation/cwb
diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml
new file mode 100644
index 0000000000..bf8d07273e
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+ job:
+ chdir: true
+ name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"]
+ run:
+ dir: ./outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/cwb_train
+
+ # Model
+ #- model/sfm_encoder
+ #- model/sfm
+ - model/sfm_two_stage
+
+ # Training
+ #- training/sfm_encoder
+ #- training/sfm
+ - training/sfm_two_stage
+
+ # Validation (comment out to disable validation)
+ - validation/cwb
diff --git a/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml
new file mode 100644
index 0000000000..787e1c9847
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml
@@ -0,0 +1,31 @@
+
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+type: cwb
+data_path: /code/2023-01-24-cwb-4years.zarr
+in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
+out_channels: [0, 17, 18, 19]
+img_shape_x: 448
+img_shape_y: 448
+add_grid: true
+ds_factor: 4
+min_path: null
+max_path: null
+global_means_path: null
+global_stds_path: null
+train: False
+all_times: True
\ No newline at end of file
diff --git a/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml
new file mode 100644
index 0000000000..5bcd741031
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml
@@ -0,0 +1,29 @@
+
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+type: cwb
+data_path: /code/2023-01-24-cwb-4years.zarr
+in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
+out_channels: [0, 17, 18, 19]
+img_shape_x: 448
+img_shape_y: 448
+add_grid: true
+ds_factor: 4
+min_path: null
+max_path: null
+global_means_path: null
+global_stds_path: null
diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml
new file mode 100644
index 0000000000..f5d6f7e28d
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+num_ensembles: 64
+ # Number of ensembles to generate per input
+seed_batch_size: 4
+ # Size of the batched inference
+inference_mode: sfm
+ # Choose between "sfm", "sfm_encoder" or "sfm_two_stage"
+learnable_sigma: True
+ # whether the model uses a learnable sigma
+gridtype: "sinusoidal"
+
+sampler:
+ rho: 7
+ # shape of noise schedule curve
+ num_steps: 50
+ # Number of sampling steps
+ sigma_min: 0.01
+ # Lowest noise level
+ t_min: 0.002
+ # clamp time step to this value
+
+N_grid_channels: 4
+times_range: null
+times:
+ - 2021-02-02T00:00:00
+ - 2021-03-02T00:00:00
+ - 2021-04-02T00:00:00
+ # hurricane
+ - 2021-09-12T00:00:00
+ - 2021-09-12T12:00:00
+
+perf:
+ force_fp16: false
+ # Whether to force fp16 precision for the model. If false, it'll use the precision
+ # specified upon training.
+ use_torch_compile: false
+ # whether to use torch.compile on the diffusion model
+ # this will make the first time stamp generation very slow due to compilation overheads
+ # but will significantly speed up subsequent inference runs
+ num_writer_workers: 1
+ # number of workers to use for writing file
+ # To support multiple workers a threadsafe version of the netCDF library must be used
+
+io:
+ encoder_ckpt_filename: Conv2dSerializable.0.895232.mdlus
+    # Checkpoint filename for the encoder model
+ denoiser_ckpt_filename: EDMPrecondSR.0.895232.mdlus
+    # Checkpoint filename for the diffusion (denoiser) model
diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml
new file mode 100644
index 0000000000..07a0663ef8
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+num_ensembles: 64
+ # Number of ensembles to generate per input
+seed_batch_size: 4
+ # Size of the batched inference
+inference_mode: sfm_encoder
+ # Choose between "sfm", "sfm_encoder" or "sfm_two_stage"
+gridtype: "sinusoidal"
+
+sampler: null
+ # no config options for the encoder sampler
+
+N_grid_channels: 4
+times_range: null
+times:
+ - 2021-02-02T00:00:00
+ - 2021-03-02T00:00:00
+ - 2021-04-02T00:00:00
+ # hurricane
+ - 2021-09-12T00:00:00
+ - 2021-09-12T12:00:00
+
+perf:
+ force_fp16: false
+ # Whether to force fp16 precision for the model. If false, it'll use the precision
+ # specified upon training.
+ use_torch_compile: false
+ # whether to use torch.compile on the diffusion model
+ # this will make the first time stamp generation very slow due to compilation overheads
+ # but will significantly speed up subsequent inference runs
+ num_writer_workers: 1
+ # number of workers to use for writing file
+ # To support multiple workers a threadsafe version of the netCDF library must be used
+
+io:
+ encoder_ckpt_filename: outputs/sfm/checkpoints_sfm_encoder/Conv2dSerializable.0.5120.mdlus
+ # Checkpoint filename for the encoder model
diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml
new file mode 100644
index 0000000000..14d586a9b4
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+defaults:
+ - sfm
+
+io:
+ encoder_ckpt_filename: outputs/sfm/checkpoints_sfm_two_stage/Conv2dSerializable.0.5120.mdlus
+ # Checkpoint filename for the encoder model
+ denoiser_ckpt_filename: outputs/sfm/checkpoints_sfm_two_stage/SFMPrecondSR.0.5120.mdlus
+ # Checkpoint filename for the diffusion model
diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml
new file mode 100644
index 0000000000..60a0661b44
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: sfm
+# Name of the model
+
+encoder_type: 1x1conv
+# Different encoder types for SFM: 1x1conv, songunet_s, songunet_xs, songunet_2xs, songunet_3xs
+encoder_loss_type: l2
+# Regularizer for the SFM model [null, l2, l1]
+encoder_loss_weight: 0.25
+# Regularizer weight for the SFM model [null, float]
+
+sigma_min: [0.002, 0.002, 0.002, 0.002] # for sampling
+# Minimum value of the noise sigma
+
+model_args:
+ dropout: 0.13
+ # Dropout probability
+ # Maximum value of the noise sigma
+ sigma_max:
+ initial_values: [1.0, 1.0, 1.0, 1.0]
+ learnable: True
+ # For learnable sigma_max
+ min_values: [0.05, 0.05, 0.05, 0.05]
+ # Minimum value of sigma_max distribution (only if learnable)
+ ema_weight: 0.99
+ # EMA weight for sigma_max (only if learnable)
+ use_x_low_conditioning: False
+ # Conditioning for the diffusion model on the upsampled ERA data [True, False]
diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml
new file mode 100644
index 0000000000..205538a592
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml
@@ -0,0 +1,23 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: sfm_encoder
+# Name of the model
+
+encoder_type: 1x1conv
+# Different encoder types for SFM: 1x1conv, songunet_s, songunet_xs, songunet_2xs, songunet_3xs
+encoder_loss_type: l2
+# Regularizer for the SFM model [null, l2, l1]
diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml
new file mode 100644
index 0000000000..09801532e3
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: sfm_two_stage
+# Name of the model
+
+encoder_loss_type: l2
+# Regularizer for the SFM model [null, l2, l1]
+encoder_loss_weight: 0.25
+# Regularizer weight for the SFM model [null, float]
+
+sigma_min: [0.002, 0.002, 0.002, 0.002] # for sampling
+# Minimum value of the noise sigma
+
+model_args:
+ dropout: 0.13
+ # Dropout probability
+ # Maximum value of the noise sigma
+ sigma_max:
+ initial_values: [1.0, 1.0, 1.0, 1.0]
+ learnable: False
+ # For learnable sigma_max
+ min_values: [0.05, 0.05, 0.05, 0.05]
+ # Minimum value of sigma_max distribution (only if learnable)
+ ema_weight: 0.99
+ # EMA weight for sigma_max (only if learnable)
+ use_x_low_conditioning: False
+ # Conditioning for the diffusion model on the upsampled ERA data [True, False]
diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml
new file mode 100644
index 0000000000..00a1a767a6
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml
@@ -0,0 +1,59 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Hyperparameters
+hp:
+ training_duration: 8000000
+ # Training duration based on the number of processed samples
+ total_batch_size: 256
+ # Total batch size
+ batch_size_per_gpu: 4
+ # Batch size per GPU
+ lr: 0.0002
+ # Learning rate
+ grad_clip_threshold: null
+    # no gradient clipping for default non-patch-based training
+ lr_decay: 1
+ # LR decay rate
+ lr_rampup: 0
+ # Rampup for learning rate, in number of samples
+ ema: 0.5
+ # EMA half-life
+ ema_rampup_ratio: 0.05
+ # EMA ramp-up coefficient, None = no rampup.
+
+# Performance
+perf:
+ fp_optimizations: amp-bf16
+ # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+ # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+ dataloader_workers: 1
+ # DataLoader worker processes
+ songunet_checkpoint_level: 0 # 0 means no checkpointing
+ # Gradient checkpointing level, value is number of layers to checkpoint
+
+# I/O
+io:
+ # encoder_checkpoint_path: checkpoints/sfm_encoder.mdlus
+  # Where to load the encoder checkpoint
+ print_progress_freq: 10000
+ # How often to print progress
+ save_checkpoint_freq: 5000
+ # How often to save the checkpoints, measured in number of processed samples
+ validation_freq: 5000
+ # how often to record the validation loss, measured in number of processed samples
+ validation_steps: 10
+ # how many loss evaluations are used to compute the validation loss per checkpoint
diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml
new file mode 100644
index 0000000000..7a88566287
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+defaults:
+ - sfm
+
+# Hyperparameters
+hp:
+ training_duration: 2000000
+ # Training duration based on the number of processed samples
+ total_batch_size: 256
+ # Total batch size
+ batch_size_per_gpu: 32
+ # Batch size per GPU
+ lr: 0.0001
+ # Learning rate
+ grad_clip_threshold: null
+    # no gradient clipping for default non-patch-based training
diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml
new file mode 100644
index 0000000000..dec4ae4983
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+defaults:
+ - sfm
+
+# Hyperparameters
+hp:
+ training_duration: 4000000
+ # Training duration based on the number of processed samples
+ batch_size_per_gpu: 4
+ # Batch size per GPU
+ lr: 0.0001
+ # Learning rate
+
+# I/O
+io:
+ encoder_checkpoint_path: outputs/sfm/checkpoints_sfm_encoder/Conv2dSerializable.0.5120.mdlus
+ # Where to load the encoder checkpoint
diff --git a/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml b/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml
new file mode 100644
index 0000000000..f29f412a7c
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Validation dataset options
+# (need to set dataset.train_test_split == true to have an effect)
+train: false
+all_times: false
diff --git a/examples/generative/corrdiff_plus_plus/datasets/base.py b/examples/generative/corrdiff_plus_plus/datasets/base.py
new file mode 100644
index 0000000000..22b00d252c
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/base.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+
+@dataclass
+class ChannelMetadata:
+ """Metadata describing a data channel."""
+
+ name: str
+ level: str = ""
+ auxiliary: bool = False
+
+
+class DownscalingDataset(torch.utils.data.Dataset, ABC):
+ """An abstract class that defines the interface for downscaling datasets."""
+
+ @abstractmethod
+ def longitude(self) -> np.ndarray:
+ """Get longitude values from the dataset."""
+ pass
+
+ @abstractmethod
+ def latitude(self) -> np.ndarray:
+ """Get latitude values from the dataset."""
+ pass
+
+ @abstractmethod
+ def input_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+ pass
+
+ @abstractmethod
+ def output_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
+ pass
+
+ @abstractmethod
+ def time(self) -> List:
+ """Get time values from the dataset."""
+ pass
+
+ @abstractmethod
+ def image_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the data (same for input and output)."""
+ pass
+
+ def normalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from physical units to normalized data."""
+ return x
+
+ def denormalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from normalized data to physical units."""
+ return x
+
+ def normalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from physical units to normalized data."""
+ return x
+
+ def denormalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from normalized data to physical units."""
+ return x
+
+ def info(self) -> dict:
+ """Get information about the dataset."""
+ return {}
diff --git a/examples/generative/corrdiff_plus_plus/datasets/cwb.py b/examples/generative/corrdiff_plus_plus/datasets/cwb.py
new file mode 100644
index 0000000000..91f469633c
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/cwb.py
@@ -0,0 +1,531 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Streaming images and labels from datasets created with dataset_tool.py."""
+
+import logging
+import random
+
+import cftime
+import cv2
+from hydra.utils import to_absolute_path
+import numpy as np
+import zarr
+
+from .base import ChannelMetadata, DownscalingDataset
+from .img_utils import reshape_fields
+from .norm import denormalize, normalize
+
+logger = logging.getLogger(__file__)
+
+
+def get_target_normalizations_v1(group):
+ """Get target normalizations using center and scale values from the 'group'."""
+ return group["cwb_center"][:], group["cwb_scale"][:]
+
+
+def get_target_normalizations_v2(group):
+ """Change the normalizations of the non-gaussian output variables"""
+ center = group["cwb_center"]
+ scale = group["cwb_scale"]
+ variable = group["cwb_variable"]
+
+ center = np.where(variable == "maximum_radar_reflectivity", 25.0, center)
+ center = np.where(variable == "eastward_wind_10m", 0.0, center)
+ center = np.where(variable == "northward_wind_10m", 0, center)
+
+ scale = np.where(variable == "maximum_radar_reflectivity", 25.0, scale)
+ scale = np.where(variable == "eastward_wind_10m", 20.0, scale)
+ scale = np.where(variable == "northward_wind_10m", 20.0, scale)
+ return center, scale
+
+
+class _ZarrDataset(DownscalingDataset):
+ """A Dataset for loading paired training data from a Zarr-file
+
+ This dataset should not be modified to add image processing contributions.
+ """
+
+ path: str
+
+ def __init__(
+ self, path: str, get_target_normalization=get_target_normalizations_v1
+ ):
+ self.path = path
+ self.group = zarr.open_consolidated(path)
+ self.get_target_normalization = get_target_normalization
+
+ # valid indices
+ cwb_valid = self.group["cwb_valid"]
+ era5_valid = self.group["era5_valid"]
+ if not (
+ era5_valid.ndim == 2
+ and cwb_valid.ndim == 1
+ and cwb_valid.shape[0] == era5_valid.shape[0]
+ ):
+ raise ValueError("Invalid dataset shape")
+ era5_all_channels_valid = np.all(era5_valid, axis=-1)
+ valid_times = cwb_valid & era5_all_channels_valid
+        # need to cast to bool since cwb_valid is stored as an int8 type in zarr.
+ self.valid_times = valid_times != 0
+
+ logger.info("Number of valid times: %d", len(self))
+ logger.info("input_channels:%s", self.input_channels())
+ logger.info("output_channels:%s", self.output_channels())
+
+ def _get_valid_time_index(self, idx):
+ time_indexes = np.arange(self.group["time"].size)
+ if not self.valid_times.dtype == np.bool_:
+ raise ValueError("valid_times must be a boolean array")
+ valid_time_indexes = time_indexes[self.valid_times]
+ return valid_time_indexes[idx]
+
+ def __getitem__(self, idx):
+ idx_to_load = self._get_valid_time_index(idx)
+ target = self.group["cwb"][idx_to_load]
+ input = self.group["era5"][idx_to_load]
+ label = 0
+
+ target = self.normalize_output(target[None, ...])[0]
+ input = self.normalize_input(input[None, ...])[0]
+
+ return target, input, label
+
+ def longitude(self):
+ """The longitude. useful for plotting"""
+ return self.group["XLONG"]
+
+ def latitude(self):
+ """The latitude. useful for plotting"""
+ return self.group["XLAT"]
+
+ def _get_channel_meta(self, variable, level):
+ if np.isnan(level):
+ level = ""
+ return ChannelMetadata(name=variable, level=str(level))
+
+ def input_channels(self):
+ """Metadata for the input channels. A list of dictionaries, one for each channel"""
+ variable = self.group["era5_variable"]
+ level = self.group["era5_pressure"]
+ return [self._get_channel_meta(*v) for v in zip(variable, level)]
+
+ def output_channels(self):
+ """Metadata for the output channels. A list of dictionaries, one for each channel"""
+ variable = self.group["cwb_variable"]
+ level = self.group["cwb_pressure"]
+ return [self._get_channel_meta(*v) for v in zip(variable, level)]
+
+ def _read_time(self):
+ """The vector of time coordinate has length (self)"""
+
+ return cftime.num2date(
+ self.group["time"], units=self.group["time"].attrs["units"]
+ )
+
+ def time(self):
+ """The vector of time coordinate has length (self)"""
+ time = self._read_time()
+ return time[self.valid_times].tolist()
+
+ def image_shape(self):
+ """Get the shape of the image (same for input and output)."""
+ return self.group["cwb"].shape[-2:]
+
+ def _select_norm_channels(self, means, stds, channels):
+ if channels is not None:
+ means = means[channels]
+ stds = stds[channels]
+ return (means, stds)
+
+ def normalize_input(self, x, channels=None):
+ """Convert input from physical units to normalized data."""
+ norm = self._select_norm_channels(
+ self.group["era5_center"], self.group["era5_scale"], channels
+ )
+ return normalize(x, *norm)
+
+ def denormalize_input(self, x, channels=None):
+ """Convert input from normalized data to physical units."""
+ norm = self._select_norm_channels(
+ self.group["era5_center"], self.group["era5_scale"], channels
+ )
+ return denormalize(x, *norm)
+
+ def normalize_output(self, x, channels=None):
+ """Convert output from physical units to normalized data."""
+ norm = self.get_target_normalization(self.group)
+ norm = self._select_norm_channels(*norm, channels)
+ return normalize(x, *norm)
+
+ def denormalize_output(self, x, channels=None):
+ """Convert output from normalized data to physical units."""
+ norm = self.get_target_normalization(self.group)
+ norm = self._select_norm_channels(*norm, channels)
+ return denormalize(x, *norm)
+
+ def info(self):
+ return {
+ "target_normalization": self.get_target_normalization(self.group),
+ "input_normalization": (
+ self.group["era5_center"][:],
+ self.group["era5_scale"][:],
+ ),
+ }
+
+ def __len__(self):
+ return self.valid_times.sum()
+
+
+class FilterTime(DownscalingDataset):
+ """Filter a time dependent dataset"""
+
+ def __init__(self, dataset, filter_fn):
+ """
+ Args:
+ filter_fn: if filter_fn(time) is True then return point
+ """
+ self._dataset = dataset
+ self._filter_fn = filter_fn
+ self._indices = [i for i, t in enumerate(self._dataset.time()) if filter_fn(t)]
+
+ def longitude(self):
+ """Get longitude values from the dataset."""
+ return self._dataset.longitude()
+
+ def latitude(self):
+ """Get latitude values from the dataset."""
+ return self._dataset.latitude()
+
+ def input_channels(self):
+ """Metadata for the input channels. A list of dictionaries, one for each channel"""
+ return self._dataset.input_channels()
+
+ def output_channels(self):
+ """Metadata for the output channels. A list of dictionaries, one for each channel"""
+ return self._dataset.output_channels()
+
+ def time(self):
+ """Get time values from the dataset."""
+ time = self._dataset.time()
+ return [time[i] for i in self._indices]
+
+ def info(self):
+ """Get information about the dataset."""
+ return self._dataset.info()
+
+ def image_shape(self):
+ """Get the shape of the image (same for input and output)."""
+ return self._dataset.image_shape()
+
+ def normalize_input(self, x, channels=None):
+ """Convert input from physical units to normalized data."""
+ return self._dataset.normalize_input(x, channels=channels)
+
+ def denormalize_input(self, x, channels=None):
+ """Convert input from normalized data to physical units."""
+ return self._dataset.denormalize_input(x, channels=channels)
+
+ def normalize_output(self, x, channels=None):
+ """Convert output from physical units to normalized data."""
+ return self._dataset.normalize_output(x, channels=channels)
+
+ def denormalize_output(self, x, channels=None):
+ """Convert output from normalized data to physical units."""
+ return self._dataset.denormalize_output(x, channels=channels)
+
+ def __getitem__(self, idx):
+ return self._dataset[self._indices[idx]]
+
+ def __len__(self):
+ return len(self._indices)
+
+
+def is_2021(time):
+ """Check if the given time is in the year 2021."""
+ return time.year == 2021
+
+
+def is_not_2021(time):
+ """Check if the given time is not in the year 2021."""
+ return not is_2021(time)
+
+
+class ZarrDataset(DownscalingDataset):
+ """A Dataset for loading paired training data from a Zarr-file with the
+ following schema::
+
+ xarray.Dataset {
+ dimensions:
+ south_north = 450 ;
+ west_east = 450 ;
+ west_east_stag = 451 ;
+ south_north_stag = 451 ;
+ time = 8760 ;
+ cwb_channel = 20 ;
+ era5_channel = 20 ;
+
+ variables:
+ float32 XLAT(south_north, west_east) ;
+ XLAT:FieldType = 104 ;
+ XLAT:MemoryOrder = XY ;
+ XLAT:description = LATITUDE, SOUTH IS NEGATIVE ;
+ XLAT:stagger = ;
+ XLAT:units = degree_north ;
+ float32 XLAT_U(south_north, west_east_stag) ;
+ XLAT_U:FieldType = 104 ;
+ XLAT_U:MemoryOrder = XY ;
+ XLAT_U:description = LATITUDE, SOUTH IS NEGATIVE ;
+ XLAT_U:stagger = X ;
+ XLAT_U:units = degree_north ;
+ float32 XLAT_V(south_north_stag, west_east) ;
+ XLAT_V:FieldType = 104 ;
+ XLAT_V:MemoryOrder = XY ;
+ XLAT_V:description = LATITUDE, SOUTH IS NEGATIVE ;
+ XLAT_V:stagger = Y ;
+ XLAT_V:units = degree_north ;
+ float32 XLONG(south_north, west_east) ;
+ XLONG:FieldType = 104 ;
+ XLONG:MemoryOrder = XY ;
+ XLONG:description = LONGITUDE, WEST IS NEGATIVE ;
+ XLONG:stagger = ;
+ XLONG:units = degree_east ;
+ float32 XLONG_U(south_north, west_east_stag) ;
+ XLONG_U:FieldType = 104 ;
+ XLONG_U:MemoryOrder = XY ;
+ XLONG_U:description = LONGITUDE, WEST IS NEGATIVE ;
+ XLONG_U:stagger = X ;
+ XLONG_U:units = degree_east ;
+ float32 XLONG_V(south_north_stag, west_east) ;
+ XLONG_V:FieldType = 104 ;
+ XLONG_V:MemoryOrder = XY ;
+ XLONG_V:description = LONGITUDE, WEST IS NEGATIVE ;
+ XLONG_V:stagger = Y ;
+ XLONG_V:units = degree_east ;
+ datetime64[ns] XTIME() ;
+ XTIME:FieldType = 104 ;
+ XTIME:MemoryOrder = 0 ;
+ XTIME:description = minutes since 2022-12-18 13:00:00 ;
+ XTIME:stagger = ;
+ float32 cwb(time, cwb_channel, south_north, west_east) ;
+ float32 cwb_center(cwb_channel) ;
+ float64 cwb_pressure(cwb_channel) ;
+ float32 cwb_scale(cwb_channel) ;
+ bool cwb_valid(time) ;
+        if self.ds_factor > 1:
+ target = self._create_lowres_(target, factor=self.ds_factor)
+
+ reshape_args = (
+ y_roll,
+ self.train,
+ self.n_history,
+ self.in_channels,
+ self.out_channels,
+ self.img_shape_x,
+ self.img_shape_y,
+ self.min_path,
+ self.max_path,
+ self.global_means_path,
+ self.global_stds_path,
+ self.normalization,
+ self.roll,
+ )
+ # SR
+ input = reshape_fields(
+ input,
+ "inp",
+ *reshape_args,
+ normalize=False,
+ ) # 3x720x1440
+ target = reshape_fields(
+ target, "tar", *reshape_args, normalize=False
+ ) # 3x720x1440
+
+ return target, input, idx
+
+ def input_channels(self):
+ """Metadata for the input channels. A list of dictionaries, one for each channel"""
+ in_channels = self._dataset.input_channels()
+ in_channels = [in_channels[i] for i in self.in_channels]
+ return in_channels
+
+ def output_channels(self):
+ """Metadata for the output channels. A list of dictionaries, one for each channel"""
+ out_channels = self._dataset.output_channels()
+ return [out_channels[i] for i in self.out_channels]
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def longitude(self):
+ """Get longitude values from the dataset."""
+ lon = self._dataset.longitude()
+ return lon if self.train else lon[..., : self.img_shape_y, : self.img_shape_x]
+
+ def latitude(self):
+ """Get latitude values from the dataset."""
+ lat = self._dataset.latitude()
+ return lat if self.train else lat[..., : self.img_shape_y, : self.img_shape_x]
+
+ def time(self):
+ """Get time values from the dataset."""
+ return self._dataset.time()
+
+ def image_shape(self):
+ """Get the shape of the image (same for input and output)."""
+ return (self.img_shape_x, self.img_shape_y)
+
+ def normalize_input(self, x):
+ """Convert input from physical units to normalized data."""
+ x_norm = self._dataset.normalize_input(
+ x[:, : len(self.in_channels)], channels=self.in_channels
+ )
+        return np.concatenate((x_norm, x[:, len(self.in_channels) :]), axis=1)
+
+ def denormalize_input(self, x):
+ """Convert input from normalized data to physical units."""
+ x_denorm = self._dataset.denormalize_input(
+ x[:, : len(self.in_channels)], channels=self.in_channels
+ )
+ return np.concatenate((x_denorm, x[:, len(self.in_channels) :]), axis=1)
+
+ def normalize_output(self, x):
+ """Convert output from physical units to normalized data."""
+ return self._dataset.normalize_output(x, channels=self.out_channels)
+
+ def denormalize_output(self, x):
+ """Convert output from normalized data to physical units."""
+ return self._dataset.denormalize_output(x, channels=self.out_channels)
+
+ def _create_highres_(self, x, shape):
+        # move channels to the last axis for cv2.resize
+ x = x.transpose(1, 2, 0)
+ # upsample with bicubic interpolation to bring the image to the nominal size
+ x = cv2.resize(
+ x, (shape[0], shape[1]), interpolation=cv2.INTER_CUBIC
+ ) # 32x32x3
+ x = x.transpose(2, 0, 1) # 3x32x32
+ return x
+
+ def _create_lowres_(self, x, factor=4):
+        # downsample the high-res image
+ x = x.transpose(1, 2, 0)
+ x = x[::factor, ::factor, :] # 8x8x3 #subsample
+ # upsample with bicubic interpolation to bring the image to the nominal size
+ x = cv2.resize(
+ x, (x.shape[1] * factor, x.shape[0] * factor), interpolation=cv2.INTER_CUBIC
+ ) # 32x32x3
+ x = x.transpose(2, 0, 1) # 3x32x32
+ return x
+
+
+def get_zarr_dataset(*, data_path, normalization="v1", all_times=False, **kwargs):
+ """Get a Zarr dataset for training or evaluation."""
+ data_path = to_absolute_path(data_path)
+ get_target_normalization = {
+ "v1": get_target_normalizations_v1,
+ "v2": get_target_normalizations_v2,
+ }[normalization]
+ logger.info(f"Normalization: {normalization}")
+ zdataset = _ZarrDataset(
+ data_path, get_target_normalization=get_target_normalization
+ )
+ return ZarrDataset(
+ dataset=zdataset, normalization=normalization, all_times=all_times, **kwargs
+ )
diff --git a/examples/generative/corrdiff_plus_plus/datasets/dataset.py b/examples/generative/corrdiff_plus_plus/datasets/dataset.py
new file mode 100644
index 0000000000..d55f6038af
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/dataset.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Tuple, Union
+import copy
+import torch
+
+from physicsnemo.utils.generative import InfiniteSampler
+from physicsnemo.distributed import DistributedManager
+
+from . import base, cwb, hrrrmini
+
+
+# this maps all known dataset types to the corresponding init function
+known_datasets = {"cwb": cwb.get_zarr_dataset, "hrrr_mini": hrrrmini.HRRRMiniDataset}
+
+
+def init_train_valid_datasets_from_config(
+ dataset_cfg: dict,
+ dataloader_cfg: Union[dict, None] = None,
+ batch_size: int = 1,
+ seed: int = 0,
+ validation_dataset_cfg: Union[dict, None] = None,
+ train_test_split: bool = True,
+) -> Tuple[
+ base.DownscalingDataset,
+ Iterable,
+ Union[base.DownscalingDataset, None],
+ Union[Iterable, None],
+]:
+ """
+ A wrapper function for managing the train-test split for the CWB dataset.
+
+ Parameters:
+ - dataset_cfg (dict): Configuration for the dataset.
+ - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None.
+ - batch_size (int): The number of samples in each batch of data. Defaults to 1.
+ - seed (int): The random seed for dataset shuffling. Defaults to 0.
+ - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True.
+
+ Returns:
+ - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True.
+ """
+
+ config = copy.deepcopy(dataset_cfg)
+ (dataset, dataset_iter) = init_dataset_from_config(
+ config, dataloader_cfg, batch_size=batch_size, seed=seed
+ )
+ if train_test_split:
+ valid_dataset_cfg = copy.deepcopy(config)
+ if validation_dataset_cfg:
+ valid_dataset_cfg.update(validation_dataset_cfg)
+ (valid_dataset, valid_dataset_iter) = init_dataset_from_config(
+ valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed
+ )
+ else:
+ valid_dataset = valid_dataset_iter = None
+
+ return dataset, dataset_iter, valid_dataset, valid_dataset_iter
+
+
+def init_dataset_from_config(
+ dataset_cfg: dict,
+ dataloader_cfg: Union[dict, None] = None,
+ batch_size: int = 1,
+ seed: int = 0,
+) -> Tuple[base.DownscalingDataset, Iterable]:
+    """Initialize a dataset and an infinite, distributed-aware data iterator from a configuration dict."""
+ dataset_cfg = copy.deepcopy(dataset_cfg)
+ dataset_type = dataset_cfg.pop("type", "cwb")
+ if "train_test_split" in dataset_cfg:
+ # handled by init_train_valid_datasets_from_config
+ del dataset_cfg["train_test_split"]
+ dataset_init_func = known_datasets[dataset_type]
+
+ dataset_obj = dataset_init_func(**dataset_cfg)
+ if dataloader_cfg is None:
+ dataloader_cfg = {}
+
+ dist = DistributedManager()
+ dataset_sampler = InfiniteSampler(
+ dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed
+ )
+
+ dataset_iterator = iter(
+ torch.utils.data.DataLoader(
+ dataset=dataset_obj,
+ sampler=dataset_sampler,
+ batch_size=batch_size,
+ worker_init_fn=None,
+ **dataloader_cfg,
+ )
+ )
+
+ return (dataset_obj, dataset_iterator)
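+
+
+# Illustrative usage (hypothetical config values, not part of the shipped configs):
+# build the HRRR-mini training dataset and draw one batch from the infinite iterator.
+# Assumes DistributedManager.initialize() has already been called and that data and
+# stats files exist at the placeholder paths below.
+#
+#   dataset_cfg = {
+#       "type": "hrrr_mini",
+#       "data_path": "./data/hrrr_mini_train.nc",
+#       "stats_path": "./data/stats.json",
+#   }
+#   dataset, dataset_iter = init_dataset_from_config(dataset_cfg, batch_size=4)
+#   img_hr, img_lr, labels = next(dataset_iter)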
diff --git a/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py b/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py
new file mode 100644
index 0000000000..4537e0780a
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import math
+from typing import List, Tuple, Union
+
+import numpy as np
+from numba import jit, prange
+import xarray as xr
+
+from physicsnemo.utils.generative import convert_datetime_to_cftime
+
+from .base import ChannelMetadata, DownscalingDataset
+
+
+class HRRRMiniDataset(DownscalingDataset):
+ """Reader for reduced-size HRRR dataset used for CorrDiff-mini."""
+
+ def __init__(
+ self,
+ data_path: str,
+ stats_path: str,
+ input_variables: Union[List[str], None] = None,
+ output_variables: Union[List[str], None] = None,
+ invariant_variables: Union[List[str], None] = ("elev_mean", "lsm_mean"),
+ ):
+ # load data
+ (self.input, self.input_variables) = _load_dataset(
+ data_path, "input", input_variables
+ )
+ (self.output, self.output_variables) = _load_dataset(
+ data_path, "output", output_variables
+ )
+ (self.invariants, self.invariant_variables) = _load_dataset(
+ data_path, "invariant", invariant_variables, stack_axis=0
+ )
+
+ # load temporal and spatial coordinates
+ with xr.open_dataset(data_path) as ds:
+ self.times = np.array(ds["time"])
+ self.coords = np.array(ds["coord"])
+
+ self.img_shape = self.output.shape[-2:]
+ self.upsample_factor = self.output.shape[-1] // self.input.shape[-1]
+
+ # load normalization stats
+ with open(stats_path, "r") as f:
+ stats = json.load(f)
+ (input_mean, input_std) = _load_stats(stats, self.input_variables, "input")
+ (inv_mean, inv_std) = _load_stats(stats, self.invariant_variables, "invariant")
+ self.input_mean = np.concatenate([input_mean, inv_mean], axis=0)
+ self.input_std = np.concatenate([input_std, inv_std], axis=0)
+ (self.output_mean, self.output_std) = _load_stats(
+ stats, self.output_variables, "output"
+ )
+
+ def __getitem__(self, idx):
+ """Return the data sample (output, input, 0) at index idx."""
+ x = self.upsample(self.input[idx].copy())
+
+ # add invariants to input
+ (i, j) = self.coords[idx]
+ inv = self.invariants[:, i : i + self.img_shape[0], j : j + self.img_shape[1]]
+ x = np.concatenate([x, inv], axis=0)
+
+ y = self.output[idx]
+
+ x = self.normalize_input(x)
+ y = self.normalize_output(y)
+ return (y, x, 0)
+
+ def __len__(self):
+ return self.input.shape[0]
+
+ def longitude(self) -> np.ndarray:
+ """Get longitude values from the dataset."""
+ return np.full(self.img_shape, np.nan)
+
+ def latitude(self) -> np.ndarray:
+ """Get latitude values from the dataset."""
+ return np.full(self.img_shape, np.nan)
+
+ def input_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+ inputs = [ChannelMetadata(name=v) for v in self.input_variables]
+ invariants = [
+ ChannelMetadata(name=v, auxiliary=True) for v in self.invariant_variables
+ ]
+ return inputs + invariants
+
+ def output_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
+ return [ChannelMetadata(name=v) for v in self.output_variables]
+
+ def time(self) -> List:
+ """Get time values from the dataset."""
+ datetimes = (
+ datetime.datetime.utcfromtimestamp(t.tolist() / 1e9) for t in self.times
+ )
+ return [convert_datetime_to_cftime(t) for t in datetimes]
+
+ def image_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the data (same for input and output)."""
+ return self.img_shape
+
+ def normalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from physical units to normalized data."""
+ return (x - self.input_mean) / self.input_std
+
+ def denormalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from normalized data to physical units."""
+ return x * self.input_std + self.input_mean
+
+ def normalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from physical units to normalized data."""
+ return (x - self.output_mean) / self.output_std
+
+ def denormalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from normalized data to physical units."""
+ return x * self.output_std + self.output_mean
+
+    def upsample(self, x):
+        """Upsample x by self.upsample_factor using bilinear interpolation with
+        linear extrapolation at the edges."""
+ y_shape = (
+ x.shape[0],
+ x.shape[1] * self.upsample_factor,
+ x.shape[2] * self.upsample_factor,
+ )
+ y = np.empty(y_shape, dtype=np.float32)
+ _zoom_extrapolate(x, y, self.upsample_factor)
+ return y
+
+
+def _load_dataset(data_path, group, variables=None, stack_axis=1):
+ with xr.open_dataset(data_path, group=group) as ds:
+ if variables is None:
+ variables = list(ds.keys())
+ data = np.stack([ds[v] for v in variables], axis=stack_axis)
+ return (data, variables)
+
+
+def _load_stats(stats, variables, group):
+ mean = np.array([stats[group][v]["mean"] for v in variables])[:, None, None].astype(
+ np.float32
+ )
+ std = np.array([stats[group][v]["std"] for v in variables])[:, None, None].astype(
+ np.float32
+ )
+ return (mean, std)
+
+
+@jit(nopython=True)
+def _zoom_extrapolate(x, y, factor):
+ """Bilinear zoom with extrapolation.
+ Use a numba function here because numpy/scipy options are rather slow.
+ """
+ s = 1 / factor
+ for k in prange(y.shape[0]):
+ for iy in range(y.shape[1]):
+ ix = (iy + 0.5) * s - 0.5
+ ix0 = int(math.floor(ix))
+ ix0 = max(0, min(ix0, x.shape[1] - 2))
+ ix1 = ix0 + 1
+ for jy in range(y.shape[2]):
+ jx = (jy + 0.5) * s - 0.5
+ jx0 = int(math.floor(jx))
+ jx0 = max(0, min(jx0, x.shape[2] - 2))
+ jx1 = jx0 + 1
+
+ x00 = x[k, ix0, jx0]
+ x01 = x[k, ix0, jx1]
+ x10 = x[k, ix1, jx0]
+ x11 = x[k, ix1, jx1]
+ djx = jx - jx0
+ x0 = x00 + djx * (x01 - x00)
+ x1 = x10 + djx * (x11 - x10)
+ y[k, iy, jy] = x0 + (ix - ix0) * (x1 - x0)
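+
+
+# Illustrative sketch (not part of the dataset pipeline): upsampling a (C, 8, 8)
+# field by a factor of 4 yields a (C, 32, 32) field via bilinear interpolation with
+# linear extrapolation at the edges.
+#
+#   x = np.random.rand(3, 8, 8).astype(np.float32)
+#   y = np.empty((3, 32, 32), dtype=np.float32)
+#   _zoom_extrapolate(x, y, 4)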
diff --git a/examples/generative/corrdiff_plus_plus/datasets/img_utils.py b/examples/generative/corrdiff_plus_plus/datasets/img_utils.py
new file mode 100644
index 0000000000..c354985c81
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/img_utils.py
@@ -0,0 +1,80 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch
+
+
+def reshape_fields(
+ img,
+ inp_or_tar,
+ y_roll,
+ train,
+ n_history,
+ in_channels,
+ out_channels,
+ img_shape_x,
+ img_shape_y,
+ min_path,
+ max_path,
+ global_means_path,
+ global_stds_path,
+ normalization,
+ roll,
+ normalize=True,
+):
+ """
+    Takes in a np array of size (n_history+1, c, h, w) and returns a torch tensor
+    of size (n_channels*(n_history+1), img_shape_x, img_shape_y).
+ """
+
+ if len(np.shape(img)) == 3:
+ img = np.expand_dims(img, 0)
+
+    if img.shape[2] > 720:
+        img = img[:, :, 0:720]  # remove the extra latitude row present in ERA5 data
+
+
+ n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels
+ channels = in_channels if inp_or_tar == "inp" else out_channels
+
+ if normalize and train:
+ mins = np.load(min_path)[:, channels]
+ maxs = np.load(max_path)[:, channels]
+ means = np.load(global_means_path)[:, channels]
+ stds = np.load(global_stds_path)[:, channels]
+
+ img = img[:, :, :img_shape_x, :img_shape_y]
+
+ if normalize and train:
+ if normalization == "minmax":
+ img -= mins
+ img /= maxs - mins
+ elif normalization == "zscore":
+ img -= means
+ img /= stds
+
+ if roll:
+ img = np.roll(img, y_roll, axis=-1)
+
+ if inp_or_tar == "inp":
+ img = np.reshape(img, (n_channels * (n_history + 1), img_shape_x, img_shape_y))
+ elif inp_or_tar == "tar":
+ img = np.reshape(img, (n_channels, img_shape_x, img_shape_y))
+
+ return torch.as_tensor(img)
diff --git a/examples/generative/corrdiff_plus_plus/datasets/norm.py b/examples/generative/corrdiff_plus_plus/datasets/norm.py
new file mode 100644
index 0000000000..f1c50d13a3
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/datasets/norm.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def normalize(x, center, scale):
+ """Normalize input data 'x' using center and scale values."""
+ center = np.asarray(center)
+ scale = np.asarray(scale)
+ if not (center.ndim == 1 and scale.ndim == 1):
+ raise ValueError("center and scale must be 1D arrays")
+ return (x - center[np.newaxis, :, np.newaxis, np.newaxis]) / scale[
+ np.newaxis, :, np.newaxis, np.newaxis
+ ]
+
+
+def denormalize(x, center, scale):
+ """Denormalize input data 'x' using center and scale values."""
+ center = np.asarray(center)
+ scale = np.asarray(scale)
+ if not (center.ndim == 1 and scale.ndim == 1):
+ raise ValueError("center and scale must be 1D arrays")
+ return (
+ x * scale[np.newaxis, :, np.newaxis, np.newaxis]
+ + center[np.newaxis, :, np.newaxis, np.newaxis]
+ )
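+
+
+# Illustrative round trip (hypothetical values): for an (N, C, H, W) array,
+# denormalize(normalize(x, center, scale), center, scale) recovers x.
+#
+#   x = np.random.rand(2, 3, 4, 4)
+#   center, scale = np.array([0.1, 0.2, 0.3]), np.array([1.0, 2.0, 3.0])
+#   assert np.allclose(denormalize(normalize(x, center, scale), center, scale), x)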
diff --git a/examples/generative/corrdiff_plus_plus/generate.py b/examples/generative/corrdiff_plus_plus/generate.py
new file mode 100644
index 0000000000..a758255ff2
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/generate.py
@@ -0,0 +1,297 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hydra
+from omegaconf import OmegaConf, DictConfig
+import torch
+import torch._dynamo
+import nvtx
+import numpy as np
+import netCDF4 as nc
+from physicsnemo.distributed import DistributedManager
+from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
+from physicsnemo import Module
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from einops import rearrange
+from torch.distributed import gather
+import tqdm
+
+
+from hydra.utils import to_absolute_path
+from physicsnemo.utils.generative import (
+ SFM_Euler_sampler,
+ SFM_Euler_sampler_Adaptive_Sigma,
+ StackedRandomGenerator,
+ SFM_encoder_sampler,
+)
+from physicsnemo.utils.corrdiff import (
+ NetCDFWriter,
+ get_time_from_range,
+)
+
+
+from helpers.generate_helpers import (
+ get_dataset_and_sampler,
+ save_images,
+)
+from helpers.train_helpers import set_patch_shape
+
+
+@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate")
+def main(cfg: DictConfig) -> None:
+ """Generate random images using the techniques described in the paper
+ "Elucidating the Design Space of Diffusion-Based Generative Models".
+ """
+
+ # Initialize distributed manager
+ DistributedManager.initialize()
+ dist = DistributedManager()
+ device = dist.device
+
+ # Initialize logger
+ logger = PythonLogger("generate") # General python logger
+ logger0 = RankZeroLoggingWrapper(logger, dist)
+ logger.file_logging("generate.log")
+
+ # Handle the batch size
+ seeds = list(np.arange(cfg.generation.num_ensembles))
+ num_batches = (
+ (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1
+ ) * dist.world_size
+ all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
+ rank_batches = all_batches[dist.rank :: dist.world_size]
+
+ # Synchronize
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+
+ # Parse the inference input times
+ if cfg.generation.times_range and cfg.generation.times:
+ raise ValueError("Either times_range or times must be provided, but not both")
+ if cfg.generation.times_range:
+ times = get_time_from_range(cfg.generation.times_range)
+ else:
+ times = cfg.generation.times
+
+ # Create dataset object
+ dataset_cfg = OmegaConf.to_container(cfg.dataset)
+ dataset, sampler = get_dataset_and_sampler(dataset_cfg=dataset_cfg, times=times)
+ img_shape = dataset.image_shape()
+ img_out_channels = len(dataset.output_channels())
+
+    # patching is not supported for generation; use the full image shape
+ patch_shape = (None, None)
+ img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
+
+ # Parse the inference mode
+ if cfg.generation.inference_mode not in ["sfm", "sfm_encoder", "sfm_two_stage"]:
+ raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}")
+
+ # Load networks, move to device, change precision
+ encoder_ckpt_filename = cfg.generation.io.encoder_ckpt_filename
+ logger0.info(f'Loading encoder network from "{encoder_ckpt_filename}"...')
+ encoder_net = Module.from_checkpoint(to_absolute_path(encoder_ckpt_filename))
+ encoder_net = encoder_net.eval().to(device).to(memory_format=torch.channels_last)
+
+ if cfg.generation.inference_mode in ["sfm", "sfm_two_stage"]:
+ denoiser_ckpt_filename = cfg.generation.io.denoiser_ckpt_filename
+        logger0.info(f'Loading denoiser network from "{denoiser_ckpt_filename}"...')
+ denoiser_net = Module.from_checkpoint(to_absolute_path(denoiser_ckpt_filename))
+ denoiser_net = (
+ denoiser_net.eval().to(device).to(memory_format=torch.channels_last)
+ )
+ else:
+ denoiser_net = None
+
+    if cfg.generation.perf.force_fp16:
+        encoder_net.use_fp16 = True
+        if denoiser_net is not None:
+            denoiser_net.use_fp16 = True
+
+ # Reset since we are using a different mode.
+ if cfg.generation.perf.use_torch_compile:
+ torch._dynamo.reset()
+ encoder_net = torch.compile(encoder_net, mode="reduce-overhead")
+ if denoiser_net:
+ denoiser_net = torch.compile(denoiser_net, mode="reduce-overhead")
+ networks = {"denoiser_net": denoiser_net, "encoder_net": encoder_net}
+
+ # Partially instantiate the sampler based on the configs
+ if cfg.generation.inference_mode in ["sfm", "sfm_two_stage"]:
+ if cfg.generation.learnable_sigma:
+ sampler_fn = SFM_Euler_sampler_Adaptive_Sigma
+ else:
+ sampler_fn = SFM_Euler_sampler
+ elif cfg.generation.inference_mode == "sfm_encoder":
+ sampler_fn = SFM_encoder_sampler
+ else:
+ raise ValueError(f"Unknown sampling method {cfg.generation.inference_mode}")
+
+ # Main generation definition
+ def generate_fn():
+ img_shape_y, img_shape_x = img_shape
+ with nvtx.annotate("generate_fn", color="green"):
+ all_images = []
+ for batch_seeds in tqdm.tqdm(
+ rank_batches, unit="batch", disable=(dist.rank != 0)
+ ):
+ batch_size = len(batch_seeds)
+ if batch_size == 0:
+ continue
+ rnd = StackedRandomGenerator(device, batch_seeds)
+ with nvtx.annotate(
+ f"{cfg.generation.inference_mode} model", color="rapids"
+ ):
+ with torch.inference_mode():
+ images = sampler_fn(
+ networks=networks,
+ img_lr=image_lr,
+ randn_like=rnd.randn_like,
+ cfg=cfg.generation.sampler,
+ )
+ all_images.append(images)
+ image_out = torch.cat(all_images)
+
+ # Gather tensors on rank 0
+ if dist.world_size > 1:
+ if dist.rank == 0:
+ gathered_tensors = [
+ torch.zeros_like(
+ image_out, dtype=image_out.dtype, device=image_out.device
+ )
+ for _ in range(dist.world_size)
+ ]
+ else:
+ gathered_tensors = None
+
+ torch.distributed.barrier()
+ gather(
+ image_out,
+ gather_list=gathered_tensors if dist.rank == 0 else None,
+ dst=0,
+ )
+
+ if dist.rank == 0:
+ return torch.cat(gathered_tensors)
+ else:
+ return None
+ else:
+ return image_out
+
+ # generate images
+ output_path = getattr(cfg.generation.io, "output_filename", "corrdiff_output.nc")
+ logger0.info(f"Generating images, saving results to {output_path}...")
+ batch_size = 1
+ warmup_steps = min(len(times) - 1, 2)
+    # Generate model predictions from the input data using the specified
+    # `generate_fn` and save the predictions to the provided NetCDF file. The loop
+    # below iterates through the dataset using a data loader, computes predictions,
+    # and saves them along with associated metadata.
+ if dist.rank == 0:
+ f = nc.Dataset(output_path, "w")
+ # add attributes
+ f.cfg = str(cfg)
+
+ with torch.cuda.profiler.profile():
+ with torch.autograd.profiler.emit_nvtx():
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True
+ )
+ time_index = -1
+ if dist.rank == 0:
+ writer = NetCDFWriter(
+ f,
+ lat=dataset.latitude(),
+ lon=dataset.longitude(),
+ input_channels=dataset.input_channels(),
+ output_channels=dataset.output_channels(),
+ )
+
+ # Initialize threadpool for writers
+ writer_executor = ThreadPoolExecutor(
+ max_workers=cfg.generation.perf.num_writer_workers
+ )
+ writer_threads = []
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ times = dataset.time()
+ for image_tar, image_lr, index in iter(data_loader):
+ time_index += 1
+ if dist.rank == 0:
+ logger0.info(f"starting index: {time_index}")
+
+ if time_index == warmup_steps:
+ start.record()
+
+ # continue
+ image_lr = (
+ image_lr.to(device=device)
+ .to(torch.float32)
+ .to(memory_format=torch.channels_last)
+ )
+ # expand to batch size
+ image_lr = image_lr.expand(
+ cfg.generation.seed_batch_size, -1, -1, -1
+ ).to(memory_format=torch.channels_last)
+ image_tar = image_tar.to(device=device).to(torch.float32)
+ image_out = generate_fn()
+
+ if dist.rank == 0:
+ batch_size = image_out.shape[0]
+                    # write out data in a separate thread so we don't hold up inference
+ writer_threads.append(
+ writer_executor.submit(
+ save_images,
+ writer,
+ dataset,
+ list(times),
+ image_out.cpu(),
+ image_tar.cpu(),
+ image_lr.cpu(),
+ time_index,
+ index[0],
+ )
+ )
+                end.record()
+ end.synchronize()
+ elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s
+ timed_steps = time_index + 1 - warmup_steps
+ if dist.rank == 0:
+ average_time_per_batch_element = elapsed_time / timed_steps / batch_size
+ logger.info(
+ f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s"
+ )
+ logger.info(
+ f"Average time per batch element = {average_time_per_batch_element} s"
+ )
+
+ # make sure all the workers are done writing
+ if dist.rank == 0:
+ for thread in list(writer_threads):
+ thread.result()
+ writer_threads.remove(thread)
+ writer_executor.shutdown()
+
+ if dist.rank == 0:
+ f.close()
+ logger0.info("Generation Completed.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py b/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py
new file mode 100644
index 0000000000..755cb95759
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py
@@ -0,0 +1,110 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+from datasets.base import DownscalingDataset
+from datasets.dataset import init_dataset_from_config
+from physicsnemo.utils.generative import convert_datetime_to_cftime
+
+
+def get_dataset_and_sampler(dataset_cfg, times):
+ """
+ Get a dataset and sampler for generation.
+ """
+ (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1)
+ plot_times = [
+ convert_datetime_to_cftime(
+ datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S")
+ )
+ for time in times
+ ]
+ all_times = dataset.time()
+ time_indices = [all_times.index(t) for t in plot_times]
+ sampler = time_indices
+
+ return dataset, sampler
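+
+
+# Illustrative call (hypothetical paths and dates): entries in `times` must be
+# strings in "%Y-%m-%dT%H:%M:%S" format and must exist in the dataset's time
+# coordinate, e.g.
+#
+#   dataset, sampler = get_dataset_and_sampler(
+#       dataset_cfg={"type": "cwb", "data_path": "./data/cwb.zarr"},
+#       times=["2021-02-02T00:00:00"],
+#   )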
+
+
+def save_images(
+ writer,
+ dataset: DownscalingDataset,
+ times,
+ image_out,
+ image_tar,
+ image_lr,
+ time_index,
+ t_index,
+):
+ """
+ Saves inferencing result along with the baseline
+
+ Parameters
+ ----------
+
+ writer (NetCDFWriter): Where the data is being written
+ in_channels (List): List of the input channels being used
+ input_channel_info (Dict): Description of the input channels
+ out_channels (List): List of the output channels being used
+ output_channel_info (Dict): Description of the output channels
+ input_norm (Tuple): Normalization data for input
+ target_norm (Tuple): Normalization data for the target
+ image_out (torch.Tensor): Generated output data
+ image_tar (torch.Tensor): Ground truth data
+ image_lr (torch.Tensor): Low resolution input data
+ time_index (int): Epoch number
+ t_index (int): index where times are located
+ """
+    # denormalize the low-resolution input and the ground truth
+ image_lr2 = image_lr[0].unsqueeze(0)
+ image_lr2 = image_lr2.cpu().numpy()
+ image_lr2 = dataset.denormalize_input(image_lr2)
+
+ image_tar2 = image_tar[0].unsqueeze(0)
+ image_tar2 = image_tar2.cpu().numpy()
+ image_tar2 = dataset.denormalize_output(image_tar2)
+
+ # some runtime assertions
+ if image_tar2.ndim != 4:
+ raise ValueError("image_tar2 must be 4-dimensional")
+
+ for idx in range(image_out.shape[0]):
+ image_out2 = image_out[idx].unsqueeze(0)
+ if image_out2.ndim != 4:
+ raise ValueError("image_out2 must be 4-dimensional")
+
+ # Denormalize the input and outputs
+ image_out2 = image_out2.cpu().numpy()
+ image_out2 = dataset.denormalize_output(image_out2)
+
+ time = times[t_index]
+ writer.write_time(time_index, time)
+ for channel_idx in range(image_out2.shape[1]):
+ info = dataset.output_channels()[channel_idx]
+ channel_name = info.name + info.level
+ truth = image_tar2[0, channel_idx]
+
+ writer.write_truth(channel_name, time_index, truth)
+ writer.write_prediction(
+ channel_name, time_index, idx, image_out2[0, channel_idx]
+ )
+
+ input_channel_info = dataset.input_channels()
+ for channel_idx in range(len(input_channel_info)):
+ info = input_channel_info[channel_idx]
+ channel_name = info.name + info.level
+ writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])
+ if channel_idx == image_lr2.shape[1] - 1:
+ break
diff --git a/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py b/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py
new file mode 100644
index 0000000000..99243ff4dd
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py
@@ -0,0 +1,76 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+from physicsnemo.models.diffusion import Conv2dSerializable
+import torch
+from omegaconf import DictConfig
+
+
+def get_encoder(cfg: DictConfig):
+ """
+    Helper that instantiates the encoder network specified in the configuration.
+
+    Parameters
+    ----------
+    cfg: DictConfig
+        configuration for the encoder
+
+    Returns
+    -------
+    torch.nn.Module: The encoder
+ """
+ in_channels = len(cfg.dataset["in_channels"])
+ out_channels = len(cfg.dataset["out_channels"])
+ encoder_type = cfg.model["encoder_type"]
+
+ if encoder_type == "1x1conv":
+ encoder = Conv2dSerializable(in_channels, out_channels, kernel_size=1)
+ elif "songunet" in encoder_type:
+ model_channels_dict = {
+ "songunet_s": 32, # 11.60M
+ "songunet_xs": 16, # 2.90M
+ "songunet_2xs": 8, # 0.74M
+ "songunet_3xs": 4, # 0.19M
+ }
+ if hasattr(cfg.model, "songunet_checkpoint_level"):
+ songunet_checkpoint_level = cfg.model.songunet_checkpoint_level
+ else:
+ songunet_checkpoint_level = None
+
+ songunet_kwargs = {
+ "embedding_type": "zero",
+ "label_dim": 0,
+ "encoder_type": "standard",
+ "decoder_type": "standard",
+ "channel_mult_noise": 1,
+ "resample_filter": [1, 1],
+ "channel_mult": [1, 2, 2, 4, 4],
+ "attn_resolutions": [28],
+ "N_grid_channels": 0,
+ "dropout": cfg.model.dropout,
+ "checkpoint_level": songunet_checkpoint_level,
+ "model_channels": model_channels_dict[encoder_type],
+ }
+ encoder = SongUNetPosEmbd(
+ img_resolution=cfg.dataset["img_shape_x"],
+ in_channels=in_channels,
+ out_channels=out_channels,
+ **songunet_kwargs,
+ )
+ else:
+ raise ValueError(f"Unknown encoder type: {encoder_type}")
+
+ return encoder
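+
+
+# Illustrative config fragment (hypothetical values, not the shipped defaults):
+# select the ~2.9M-parameter "songunet_xs" encoder and instantiate it.
+#
+#   from omegaconf import OmegaConf
+#   cfg = OmegaConf.create({
+#       "dataset": {"in_channels": [0, 1, 2], "out_channels": [0, 1], "img_shape_x": 448},
+#       "model": {"encoder_type": "songunet_xs", "dropout": 0.0, "songunet_checkpoint_level": 0},
+#   })
+#   encoder = get_encoder(cfg)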
diff --git a/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py b/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py
new file mode 100644
index 0000000000..d4529ac821
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+from omegaconf import ListConfig
+
+
+def set_patch_shape(img_shape, patch_shape):
+    """Clamp the patch shape to the image shape and validate the patch constraints."""
+    img_shape_y, img_shape_x = img_shape
+ patch_shape_y, patch_shape_x = patch_shape
+ if (patch_shape_x is None) or (patch_shape_x > img_shape_x):
+ patch_shape_x = img_shape_x
+ if (patch_shape_y is None) or (patch_shape_y > img_shape_y):
+ patch_shape_y = img_shape_y
+ if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y:
+ if patch_shape_x != patch_shape_y:
+ raise NotImplementedError("Rectangular patch not supported yet")
+ if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0:
+ raise ValueError("Patch shape needs to be a multiple of 32")
+ return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x)
+
+
+def set_seed(rank):
+ """
+ Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings
+ """
+ np.random.seed(rank % (1 << 31))
+ torch.manual_seed(np.random.randint(1 << 31))
+
+
+def configure_cuda_for_consistent_precision():
+ """
+ Configures CUDA and cuDNN settings to ensure consistent precision by
+ disabling TensorFloat-32 (TF32) and reduced precision settings.
+ """
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.allow_tf32 = False
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+
+def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size):
+ """
+    Calculate the per-GPU batch size in a distributed setting, determine the number of
+    gradient-accumulation rounds, and validate that the global batch size equals
+    batch_size_per_gpu * num_accumulation_rounds * world_size.
+ """
+    batch_gpu_total = total_batch_size // world_size
+ if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total:
+ batch_size_per_gpu = batch_gpu_total
+ num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu
+ if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size:
+ raise ValueError(
+ "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size"
+ )
+ return batch_gpu_total, num_accumulation_rounds
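+
+
+# Worked example (hypothetical values): with total_batch_size=64, world_size=4 and
+# batch_size_per_gpu=8, each rank processes batch_gpu_total=16 samples per step,
+# split into num_accumulation_rounds=2 gradient-accumulation rounds (8 * 2 * 4 == 64).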
+
+
+def handle_and_clip_gradients(model, grad_clip_threshold=None):
+ """
+ Handles NaNs and infinities in the gradients and optionally clips the gradients.
+
+ Parameters:
+ - model (torch.nn.Module): The model whose gradients need to be processed.
+ - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed.
+ """
+ # Replace NaNs and infinities in gradients
+ for param in model.parameters():
+ if param.grad is not None:
+ torch.nan_to_num(
+ param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad
+ )
+
+ # Clip gradients if a threshold is provided
+ if grad_clip_threshold is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold)
+
+
+def parse_model_args(args):
+ """Convert ListConfig values in args to tuples."""
+ return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()}
+
+
+def is_time_for_periodic_task(
+ cur_nimg, freq, done, batch_size, rank, rank_0_only=False
+):
+ """Should we perform a task that is done every `freq` samples?"""
+ if rank_0_only and rank != 0:
+ return False
+ elif done: # Run periodic tasks also at the end of training
+ return True
+ else:
+ return cur_nimg % freq < batch_size
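+
+
+# Illustrative check (hypothetical values): with freq=1000 and batch_size=64, the task
+# fires on the first step after each multiple of 1000 samples, e.g.
+#   is_time_for_periodic_task(1024, 1000, False, 64, rank=0)  # 1024 % 1000 = 24 < 64 -> True
+#   is_time_for_periodic_task(1088, 1000, False, 64, rank=0)  # 1088 % 1000 = 88 >= 64 -> False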
diff --git a/examples/generative/corrdiff_plus_plus/inference/concat.py b/examples/generative/corrdiff_plus_plus/inference/concat.py
new file mode 100644
index 0000000000..ee44633d59
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/concat.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import dask.diagnostics
+import xarray
+
+# Usage: python concat.py <input_0.nc> [<input_1.nc> ...] <output.zarr>
+# Concatenates multiple generation output files along the ensemble dimension into a
+# single Zarr store, copying the input, truth and root groups from the first file.
+base = sys.argv[1:-1]
+out = sys.argv[-1]
+
+with dask.diagnostics.ProgressBar():
+ t = xarray.open_mfdataset(
+ base,
+ group="prediction",
+ concat_dim="ensemble",
+ combine="nested",
+ chunks={"time": 1, "ensemble": 10},
+ )
+ t.to_zarr(out, group="prediction")
+
+ t = xarray.open_dataset(base[0], group="input", chunks={"time": 1})
+ t.to_zarr(out, group="input", mode="a")
+
+ t = xarray.open_dataset(base[0], group="truth", chunks={"time": 1})
+ t.to_zarr(out, group="truth", mode="a")
+
+ t = xarray.open_dataset(base[0])
+ t.to_zarr(out, mode="a")
diff --git a/examples/generative/corrdiff_plus_plus/inference/matplotlibrc b/examples/generative/corrdiff_plus_plus/inference/matplotlibrc
new file mode 100644
index 0000000000..54e097e836
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/matplotlibrc
@@ -0,0 +1,770 @@
+#### MATPLOTLIBRC FORMAT
+
+## NOTE FOR END USERS: DO NOT EDIT THIS FILE!
+##
+## This is a sample Matplotlib configuration file - you can find a copy
+## of it on your system in site-packages/matplotlib/mpl-data/matplotlibrc
+## (relative to your Python installation location).
+##
+## You should find a copy of it on your system at
+## site-packages/matplotlib/mpl-data/matplotlibrc (relative to your Python
+## installation location). DO NOT EDIT IT!
+##
+## If you wish to change your default style, copy this file to one of the
+## following locations:
+## Unix/Linux:
+## $HOME/.config/matplotlib/matplotlibrc OR
+## $XDG_CONFIG_HOME/matplotlib/matplotlibrc (if $XDG_CONFIG_HOME is set)
+## Other platforms:
+## $HOME/.matplotlib/matplotlibrc
+## and edit that copy.
+##
+## See https://matplotlib.org/users/customizing.html#the-matplotlibrc-file
+## for more details on the paths which are checked for the configuration file.
+##
+## Blank lines, or lines starting with a comment symbol, are ignored, as are
+## trailing comments. Other lines must have the format:
+## key: val # optional comment
+##
+## Formatting: Use PEP8-like style (as enforced in the rest of the codebase).
+## All lines start with an additional '#', so that removing all leading '#'s
+## yields a valid style file.
+##
+## Colors: for the color values below, you can either use
+## - a Matplotlib color string, such as r, k, or b
+## - an RGB tuple, such as (1.0, 0.5, 0.0)
+## - a hex string, such as ff00ff
+## - a scalar grayscale intensity such as 0.75
+## - a legal html color name, e.g., red, blue, darkslategray
+##
+## Matplotlib configuration are currently divided into following parts:
+## - BACKENDS
+## - LINES
+## - PATCHES
+## - HATCHES
+## - BOXPLOT
+## - FONT
+## - TEXT
+## - LaTeX
+## - AXES
+## - DATES
+## - TICKS
+## - GRIDS
+## - LEGEND
+## - FIGURE
+## - IMAGES
+## - CONTOUR PLOTS
+## - ERRORBAR PLOTS
+## - HISTOGRAM PLOTS
+## - SCATTER PLOTS
+## - AGG RENDERING
+## - PATHS
+## - SAVING FIGURES
+## - INTERACTIVE KEYMAPS
+## - ANIMATION
+
+##### CONFIGURATION BEGINS HERE
+
+
+## ***************************************************************************
+## * BACKENDS *
+## ***************************************************************************
+## The default backend. If you omit this parameter, the first working
+## backend from the following list is used:
+## MacOSX QtAgg Gtk4Agg Gtk3Agg TkAgg WxAgg Agg
+## Other choices include:
+## QtCairo GTK4Cairo GTK3Cairo TkCairo WxCairo Cairo
+## Qt5Agg Qt5Cairo Wx # deprecated.
+## PS PDF SVG Template
+## You can also deploy your own backend outside of Matplotlib by referring to
+## the module name (which must be in the PYTHONPATH) as 'module://my_backend'.
+##backend: Agg
+
+## The port to use for the web server in the WebAgg backend.
+#webagg.port: 8988
+
+## The address on which the WebAgg web server should be reachable
+#webagg.address: 127.0.0.1
+
+## If webagg.port is unavailable, a number of other random ports will
+## be tried until one that is available is found.
+#webagg.port_retries: 50
+
+## When True, open the web browser to the plot that is shown
+#webagg.open_in_browser: True
+
+## If you are running pyplot inside a GUI and your backend choice
+## conflicts, we will automatically try to find a compatible one for
+## you if backend_fallback is True
+#backend_fallback: True
+
+#interactive: False
+#toolbar: toolbar2 # {None, toolbar2, toolmanager}
+#timezone: UTC # a pytz timezone string, e.g., US/Central or Europe/Paris
+
+
+## ***************************************************************************
+## * LINES *
+## ***************************************************************************
+## See https://matplotlib.org/api/artist_api.html#module-matplotlib.lines
+## for more information on line properties.
+#lines.linewidth: 1.5 # line width in points
+#lines.linestyle: - # solid line
+#lines.color: C0 # has no affect on plot(); see axes.prop_cycle
+#lines.marker: None # the default marker
+#lines.markerfacecolor: auto # the default marker face color
+#lines.markeredgecolor: auto # the default marker edge color
+#lines.markeredgewidth: 1.0 # the line width around the marker symbol
+#lines.markersize: 6 # marker size, in points
+#lines.dash_joinstyle: round # {miter, round, bevel}
+#lines.dash_capstyle: butt # {butt, round, projecting}
+#lines.solid_joinstyle: round # {miter, round, bevel}
+#lines.solid_capstyle: projecting # {butt, round, projecting}
+#lines.antialiased: True # render lines in antialiased (no jaggies)
+
+## The three standard dash patterns. These are scaled by the linewidth.
+#lines.dashed_pattern: 3.7, 1.6
+#lines.dashdot_pattern: 6.4, 1.6, 1, 1.6
+#lines.dotted_pattern: 1, 1.65
+#lines.scale_dashes: True
+
+#markers.fillstyle: full # {full, left, right, bottom, top, none}
+
+#pcolor.shading: auto
+#pcolormesh.snap: True # Whether to snap the mesh to pixel boundaries. This is
+ # provided solely to allow old test images to remain
+ # unchanged. Set to False to obtain the previous behavior.
+
+## ***************************************************************************
+## * PATCHES *
+## ***************************************************************************
+## Patches are graphical objects that fill 2D space, like polygons or circles.
+## See https://matplotlib.org/api/artist_api.html#module-matplotlib.patches
+## for more information on patch properties.
+#patch.linewidth: 1 # edge width in points.
+#patch.facecolor: C0
+#patch.edgecolor: black # if forced, or patch is not filled
+#patch.force_edgecolor: False # True to always use edgecolor
+#patch.antialiased: True # render patches in antialiased (no jaggies)
+
+
+## ***************************************************************************
+## * HATCHES *
+## ***************************************************************************
+#hatch.color: black
+#hatch.linewidth: 1.0
+
+
+## ***************************************************************************
+## * BOXPLOT *
+## ***************************************************************************
+#boxplot.notch: False
+#boxplot.vertical: True
+#boxplot.whiskers: 1.5
+#boxplot.bootstrap: None
+#boxplot.patchartist: False
+#boxplot.showmeans: False
+#boxplot.showcaps: True
+#boxplot.showbox: True
+#boxplot.showfliers: True
+#boxplot.meanline: False
+
+#boxplot.flierprops.color: black
+#boxplot.flierprops.marker: o
+#boxplot.flierprops.markerfacecolor: none
+#boxplot.flierprops.markeredgecolor: black
+#boxplot.flierprops.markeredgewidth: 1.0
+#boxplot.flierprops.markersize: 6
+#boxplot.flierprops.linestyle: none
+#boxplot.flierprops.linewidth: 1.0
+
+#boxplot.boxprops.color: black
+#boxplot.boxprops.linewidth: 1.0
+#boxplot.boxprops.linestyle: -
+
+#boxplot.whiskerprops.color: black
+#boxplot.whiskerprops.linewidth: 1.0
+#boxplot.whiskerprops.linestyle: -
+
+#boxplot.capprops.color: black
+#boxplot.capprops.linewidth: 1.0
+#boxplot.capprops.linestyle: -
+
+#boxplot.medianprops.color: C1
+#boxplot.medianprops.linewidth: 1.0
+#boxplot.medianprops.linestyle: -
+
+#boxplot.meanprops.color: C2
+#boxplot.meanprops.marker: ^
+#boxplot.meanprops.markerfacecolor: C2
+#boxplot.meanprops.markeredgecolor: C2
+#boxplot.meanprops.markersize: 6
+#boxplot.meanprops.linestyle: --
+#boxplot.meanprops.linewidth: 1.0
+
+
+## ***************************************************************************
+## * FONT *
+## ***************************************************************************
+## The font properties used by `text.Text`.
+## See https://matplotlib.org/api/font_manager_api.html for more information
+## on font properties. The 6 font properties used for font matching are
+## given below with their default values.
+##
+## The font.family property can take either a concrete font name (not supported
+## when rendering text with usetex), or one of the following five generic
+## values:
+## - 'serif' (e.g., Times),
+## - 'sans-serif' (e.g., Helvetica),
+## - 'cursive' (e.g., Zapf-Chancery),
+## - 'fantasy' (e.g., Western), and
+## - 'monospace' (e.g., Courier).
+## Each of these values has a corresponding default list of font names
+## (font.serif, etc.); the first available font in the list is used. Note that
+## for font.serif, font.sans-serif, and font.monospace, the first element of
+## the list (a DejaVu font) will always be used because DejaVu is shipped with
+## Matplotlib and is thus guaranteed to be available; the other entries are
+## left as examples of other possible values.
+##
+## The font.style property has three values: normal (or roman), italic
+## or oblique. The oblique style will be used for italic, if it is not
+## present.
+##
+## The font.variant property has two values: normal or small-caps. For
+## TrueType fonts, which are scalable fonts, small-caps is equivalent
+## to using a font size of 'smaller', or about 83%% of the current font
+## size.
+##
+## The font.weight property has effectively 13 values: normal, bold,
+## bolder, lighter, 100, 200, 300, ..., 900. Normal is the same as
+## 400, and bold is 700. bolder and lighter are relative values with
+## respect to the current weight.
+##
+## The font.stretch property has 11 values: ultra-condensed,
+## extra-condensed, condensed, semi-condensed, normal, semi-expanded,
+## expanded, extra-expanded, ultra-expanded, wider, and narrower. This
+## property is not currently implemented.
+##
+## The font.size property is the default font size for text, given in points.
+## 10 pt is the standard value.
+##
+## Note that font.size controls default text sizes. To configure
+## special text sizes tick labels, axes, labels, title, etc., see the rc
+## settings for axes and ticks. Special text sizes can be defined
+## relative to font.size, using the following values: xx-small, x-small,
+## small, medium, large, x-large, xx-large, larger, or smaller
+
+#font.family: sans-serif
+#font.style: normal
+#font.variant: normal
+#font.weight: normal
+#font.stretch: normal
+#font.size: 10.0
+
+#font.serif: DejaVu Serif, Bitstream Vera Serif, Computer Modern Roman, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif
+#font.sans-serif: DejaVu Sans, Bitstream Vera Sans, Computer Modern Sans Serif, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif
+#font.cursive: Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, Comic Neue, Comic Sans MS, cursive
+#font.fantasy: Chicago, Charcoal, Impact, Western, Humor Sans, xkcd, fantasy
+#font.monospace: DejaVu Sans Mono, Bitstream Vera Sans Mono, Computer Modern Typewriter, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace
+
+
+## ***************************************************************************
+## * TEXT *
+## ***************************************************************************
+## The text properties used by `text.Text`.
+## See https://matplotlib.org/api/artist_api.html#module-matplotlib.text
+## for more information on text properties
+#text.color: black
+
+## FreeType hinting flag ("foo" corresponds to FT_LOAD_FOO); may be one of the
+## following (Proprietary Matplotlib-specific synonyms are given in parentheses,
+## but their use is discouraged):
+## - default: Use the font's native hinter if possible, else FreeType's auto-hinter.
+## ("either" is a synonym).
+## - no_autohint: Use the font's native hinter if possible, else don't hint.
+## ("native" is a synonym.)
+## - force_autohint: Use FreeType's auto-hinter. ("auto" is a synonym.)
+## - no_hinting: Disable hinting. ("none" is a synonym.)
+#text.hinting: force_autohint
+
+#text.hinting_factor: 8 # Specifies the amount of softness for hinting in the
+ # horizontal direction. A value of 1 will hint to full
+ # pixels. A value of 2 will hint to half pixels etc.
+#text.kerning_factor: 0 # Specifies the scaling factor for kerning values. This
+ # is provided solely to allow old test images to remain
+ # unchanged. Set to 6 to obtain previous behavior.
+ # Values other than 0 or 6 have no defined meaning.
+#text.antialiased: True # If True (default), the text will be antialiased.
+ # This only affects raster outputs.
+
+
+## ***************************************************************************
+## * LaTeX *
+## ***************************************************************************
+## For more information on LaTeX properties, see
+## https://matplotlib.org/tutorials/text/usetex.html
+#text.usetex: False # use latex for all text handling. The following fonts
+ # are supported through the usual rc parameter settings:
+ # new century schoolbook, bookman, times, palatino,
+ # zapf chancery, charter, serif, sans-serif, helvetica,
+ # avant garde, courier, monospace, computer modern roman,
+ # computer modern sans serif, computer modern typewriter
+#text.latex.preamble: # IMPROPER USE OF THIS FEATURE WILL LEAD TO LATEX FAILURES
+ # AND IS THEREFORE UNSUPPORTED. PLEASE DO NOT ASK FOR HELP
+ # IF THIS FEATURE DOES NOT DO WHAT YOU EXPECT IT TO.
+ # text.latex.preamble is a single line of LaTeX code that
+ # will be passed on to the LaTeX system. It may contain
+ # any code that is valid for the LaTeX "preamble", i.e.
+ # between the "\documentclass" and "\begin{document}"
+ # statements.
+ # Note that it has to be put on a single line, which may
+ # become quite long.
+ # The following packages are always loaded with usetex,
+ # so beware of package collisions:
+ # geometry, inputenc, type1cm.
+ # PostScript (PSNFSS) font packages may also be
+ # loaded, depending on your font settings.
+
+## The following settings allow you to select the fonts in math mode.
+#mathtext.fontset: dejavusans # Should be 'dejavusans' (default),
+ # 'dejavuserif', 'cm' (Computer Modern), 'stix',
+ # 'stixsans' or 'custom' (unsupported, may go
+ # away in the future)
+## "mathtext.fontset: custom" is defined by the mathtext.bf, .cal, .it, ...
+## settings which map a TeX font name to a fontconfig font pattern. (These
+## settings are not used for other font sets.)
+#mathtext.bf: sans:bold
+#mathtext.cal: cursive
+#mathtext.it: sans:italic
+#mathtext.rm: sans
+#mathtext.sf: sans
+#mathtext.tt: monospace
+#mathtext.fallback: cm # Select fallback font from ['cm' (Computer Modern), 'stix'
+ # 'stixsans'] when a symbol can not be found in one of the
+ # custom math fonts. Select 'None' to not perform fallback
+ # and replace the missing character by a dummy symbol.
+#mathtext.default: it # The default font to use for math.
+ # Can be any of the LaTeX font names, including
+ # the special name "regular" for the same font
+ # used in regular text.
+
+
+## ***************************************************************************
+## * AXES *
+## ***************************************************************************
+## Following are default face and edge colors, default tick sizes,
+## default font sizes for tick labels, and so on. See
+## https://matplotlib.org/api/axes_api.html#module-matplotlib.axes
+#axes.facecolor: white # axes background color
+#axes.edgecolor: black # axes edge color
+#axes.linewidth: 0.8 # edge line width
+#axes.grid: False # display grid or not
+#axes.grid.axis: both # which axis the grid should apply to
+#axes.grid.which: major # grid lines at {major, minor, both} ticks
+#axes.titlelocation: center # alignment of the title: {left, right, center}
+#axes.titlesize: large # font size of the axes title
+#axes.titleweight: normal # font weight of title
+#axes.titlecolor: auto # color of the axes title, auto falls back to
+ # text.color as default value
+#axes.titley: None # position title (axes relative units). None implies auto
+#axes.titlepad: 6.0 # pad between axes and title in points
+#axes.labelsize: medium # font size of the x and y labels
+#axes.labelpad: 4.0 # space between label and axis
+#axes.labelweight: normal # weight of the x and y labels
+#axes.labelcolor: black
+#axes.axisbelow: line # draw axis gridlines and ticks:
+ # - below patches (True)
+ # - above patches but below lines ('line')
+ # - above all (False)
+
+#axes.formatter.limits: -5, 6 # use scientific notation if log10
+ # of the axis range is smaller than the
+ # first or larger than the second
+#axes.formatter.use_locale: False # When True, format tick labels
+ # according to the user's locale.
+ # For example, use ',' as a decimal
+ # separator in the fr_FR locale.
+#axes.formatter.use_mathtext: False # When True, use mathtext for scientific
+ # notation.
+#axes.formatter.min_exponent: 0 # minimum exponent to format in scientific notation
+#axes.formatter.useoffset: True # If True, the tick label formatter
+ # will default to labeling ticks relative
+ # to an offset when the data range is
+ # small compared to the minimum absolute
+ # value of the data.
+#axes.formatter.offset_threshold: 4 # When useoffset is True, the offset
+ # will be used when it can remove
+ # at least this number of significant
+ # digits from tick labels.
+
+#axes.spines.left: True # display axis spines
+#axes.spines.bottom: True
+#axes.spines.top: True
+#axes.spines.right: True
+
+#axes.unicode_minus: True # use Unicode for the minus symbol rather than hyphen. See
+ # https://en.wikipedia.org/wiki/Plus_and_minus_signs#Character_codes
+# adapted from https://davidmathlogic.com/colorblind
+# put yellow last, remove black
+axes.prop_cycle: cycler('color', [ '56b4e9', 'e69f00', '009e73', '0072b2', 'd55e00', 'cc79a7', 'f0e442'])
+#axes.prop_cycle: cycler('color', ['1f77b4', 'ff7f0e', '2ca02c', 'd62728', '9467bd', '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf'])
+ # color cycle for plot lines as list of string color specs:
+ # single letter, long name, or web-style hex
+ # As opposed to all other parameters in this file, the color
+ # values must be enclosed in quotes for this parameter,
+ # e.g. '1f77b4', instead of 1f77b4.
+ # See also https://matplotlib.org/tutorials/intermediate/color_cycle.html
+ # for more details on prop_cycle usage.
+#axes.xmargin: .05 # x margin. See `axes.Axes.margins`
+#axes.ymargin: .05 # y margin. See `axes.Axes.margins`
+#axes.zmargin: .05 # z margin. See `axes.Axes.margins`
+#axes.autolimit_mode: data # If "data", use axes.xmargin and axes.ymargin as is.
+ # If "round_numbers", after application of margins, axis
+ # limits are further expanded to the nearest "round" number.
+#polaraxes.grid: True # display grid on polar axes
+#axes3d.grid: True # display grid on 3D axes
+
+
+## ***************************************************************************
+## * AXIS *
+## ***************************************************************************
+#xaxis.labellocation: center # alignment of the xaxis label: {left, right, center}
+#yaxis.labellocation: center # alignment of the yaxis label: {bottom, top, center}
+
+
+## ***************************************************************************
+## * DATES *
+## ***************************************************************************
+## These control the default format strings used in AutoDateFormatter.
+## Any valid format datetime format string can be used (see the python
+## `datetime` for details). For example, by using:
+## - '%%x' will use the locale date representation
+## - '%%X' will use the locale time representation
+## - '%%c' will use the full locale datetime representation
+## These values map to the scales:
+## {'year': 365, 'month': 30, 'day': 1, 'hour': 1/24, 'minute': 1 / (24 * 60)}
+
+#date.autoformatter.year: %Y
+#date.autoformatter.month: %Y-%m
+#date.autoformatter.day: %Y-%m-%d
+#date.autoformatter.hour: %m-%d %H
+#date.autoformatter.minute: %d %H:%M
+#date.autoformatter.second: %H:%M:%S
+#date.autoformatter.microsecond: %M:%S.%f
+## The reference date for Matplotlib's internal date representation
+## See https://matplotlib.org/examples/ticks_and_spines/date_precision_and_epochs.py
+#date.epoch: 1970-01-01T00:00:00
+## 'auto', 'concise':
+#date.converter: auto
+## For auto converter whether to use interval_multiples:
+#date.interval_multiples: True
+
+## ***************************************************************************
+## * TICKS *
+## ***************************************************************************
+## See https://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick
+#xtick.top: False # draw ticks on the top side
+#xtick.bottom: True # draw ticks on the bottom side
+#xtick.labeltop: False # draw label on the top
+#xtick.labelbottom: True # draw label on the bottom
+#xtick.major.size: 3.5 # major tick size in points
+#xtick.minor.size: 2 # minor tick size in points
+#xtick.major.width: 0.8 # major tick width in points
+#xtick.minor.width: 0.6 # minor tick width in points
+#xtick.major.pad: 3.5 # distance to major tick label in points
+#xtick.minor.pad: 3.4 # distance to the minor tick label in points
+#xtick.color: black # color of the ticks
+#xtick.labelcolor: inherit # color of the tick labels or inherit from xtick.color
+#xtick.labelsize: medium # font size of the tick labels
+#xtick.direction: out # direction: {in, out, inout}
+#xtick.minor.visible: False # visibility of minor ticks on x-axis
+#xtick.major.top: True # draw x axis top major ticks
+#xtick.major.bottom: True # draw x axis bottom major ticks
+#xtick.minor.top: True # draw x axis top minor ticks
+#xtick.minor.bottom: True # draw x axis bottom minor ticks
+#xtick.alignment: center # alignment of xticks
+
+#ytick.left: True # draw ticks on the left side
+#ytick.right: False # draw ticks on the right side
+#ytick.labelleft: True # draw tick labels on the left side
+#ytick.labelright: False # draw tick labels on the right side
+#ytick.major.size: 3.5 # major tick size in points
+#ytick.minor.size: 2 # minor tick size in points
+#ytick.major.width: 0.8 # major tick width in points
+#ytick.minor.width: 0.6 # minor tick width in points
+#ytick.major.pad: 3.5 # distance to major tick label in points
+#ytick.minor.pad: 3.4 # distance to the minor tick label in points
+#ytick.color: black # color of the ticks
+#ytick.labelcolor: inherit # color of the tick labels or inherit from ytick.color
+#ytick.labelsize: medium # font size of the tick labels
+#ytick.direction: out # direction: {in, out, inout}
+#ytick.minor.visible: False # visibility of minor ticks on y-axis
+#ytick.major.left: True # draw y axis left major ticks
+#ytick.major.right: True # draw y axis right major ticks
+#ytick.minor.left: True # draw y axis left minor ticks
+#ytick.minor.right: True # draw y axis right minor ticks
+#ytick.alignment: center_baseline # alignment of yticks
+
+
+## ***************************************************************************
+## * GRIDS *
+## ***************************************************************************
+#grid.color: b0b0b0 # grid color
+#grid.linestyle: - # solid
+#grid.linewidth: 0.8 # in points
+#grid.alpha: 1.0 # transparency, between 0.0 and 1.0
+
+
+## ***************************************************************************
+## * LEGEND *
+## ***************************************************************************
+#legend.loc: best
+#legend.frameon: True # if True, draw the legend on a background patch
+#legend.framealpha: 0.8 # legend patch transparency
+#legend.facecolor: inherit # inherit from axes.facecolor; or color spec
+#legend.edgecolor: 0.8 # background patch boundary color
+#legend.fancybox: True # if True, use a rounded box for the
+ # legend background, else a rectangle
+#legend.shadow: False # if True, give background a shadow effect
+#legend.numpoints: 1 # the number of marker points in the legend line
+#legend.scatterpoints: 1 # number of scatter points
+#legend.markerscale: 1.0 # the relative size of legend markers vs. original
+#legend.fontsize: medium
+#legend.labelcolor: None
+#legend.title_fontsize: None # None sets to the same as the default axes.
+
+## Dimensions as fraction of font size:
+#legend.borderpad: 0.4 # border whitespace
+#legend.labelspacing: 0.5 # the vertical space between the legend entries
+#legend.handlelength: 2.0 # the length of the legend lines
+#legend.handleheight: 0.7 # the height of the legend handle
+#legend.handletextpad: 0.8 # the space between the legend line and legend text
+#legend.borderaxespad: 0.5 # the border between the axes and legend edge
+#legend.columnspacing: 2.0 # column separation
+
+
+## ***************************************************************************
+## * FIGURE *
+## ***************************************************************************
+## See https://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure
+#figure.titlesize: large # size of the figure title (``Figure.suptitle()``)
+#figure.titleweight: normal # weight of the figure title
+#figure.figsize: 6.4, 4.8 # figure size in inches
+#figure.dpi: 100 # figure dots per inch
+#figure.facecolor: white # figure face color
+#figure.edgecolor: white # figure edge color
+#figure.frameon: True # enable figure frame
+#figure.max_open_warning: 20 # The maximum number of figures to open through
+ # the pyplot interface before emitting a warning.
+ # If less than one this feature is disabled.
+#figure.raise_window : True # Raise the GUI window to front when show() is called.
+
+## The figure subplot parameters. All dimensions are a fraction of the figure width and height.
+#figure.subplot.left: 0.125 # the left side of the subplots of the figure
+#figure.subplot.right: 0.9 # the right side of the subplots of the figure
+#figure.subplot.bottom: 0.11 # the bottom of the subplots of the figure
+#figure.subplot.top: 0.88 # the top of the subplots of the figure
+#figure.subplot.wspace: 0.2 # the amount of width reserved for space between subplots,
+ # expressed as a fraction of the average axis width
+#figure.subplot.hspace: 0.2 # the amount of height reserved for space between subplots,
+ # expressed as a fraction of the average axis height
+
+## Figure layout
+#figure.autolayout: False # When True, automatically adjust subplot
+ # parameters to make the plot fit the figure
+ # using `tight_layout`
+#figure.constrained_layout.use: False # When True, automatically make plot
+ # elements fit on the figure. (Not
+ # compatible with `autolayout`, above).
+#figure.constrained_layout.h_pad: 0.04167 # Padding around axes objects. Float representing
+#figure.constrained_layout.w_pad: 0.04167 # inches. Default is 3/72 inches (3 points)
+#figure.constrained_layout.hspace: 0.02 # Space between subplot groups. Float representing
+#figure.constrained_layout.wspace: 0.02 # a fraction of the subplot widths being separated.
+
+
+## ***************************************************************************
+## * IMAGES *
+## ***************************************************************************
+#image.aspect: equal # {equal, auto} or a number
+#image.interpolation: antialiased # see help(imshow) for options
+#image.cmap: viridis # A colormap name, gray etc...
+#image.lut: 256 # the size of the colormap lookup table
+#image.origin: upper # {lower, upper}
+#image.resample: True
+#image.composite_image: True # When True, all the images on a set of axes are
+ # combined into a single composite image before
+ # saving a figure as a vector graphics file,
+ # such as a PDF.
+
+
+## ***************************************************************************
+## * CONTOUR PLOTS *
+## ***************************************************************************
+#contour.negative_linestyle: dashed # string or on-off ink sequence
+#contour.corner_mask: True # {True, False, legacy}
+#contour.linewidth: None # {float, None} Size of the contour line
+ # widths. If set to None, it falls back to
+ # `line.linewidth`.
+
+
+## ***************************************************************************
+## * ERRORBAR PLOTS *
+## ***************************************************************************
+#errorbar.capsize: 0 # length of end cap on error bars in pixels
+
+
+## ***************************************************************************
+## * HISTOGRAM PLOTS *
+## ***************************************************************************
+#hist.bins: 10 # The default number of histogram bins or 'auto'.
+
+
+## ***************************************************************************
+## * SCATTER PLOTS *
+## ***************************************************************************
+#scatter.marker: o # The default marker type for scatter plots.
+#scatter.edgecolors: face # The default edge colors for scatter plots.
+
+
+## ***************************************************************************
+## * AGG RENDERING *
+## ***************************************************************************
+## Warning: experimental, 2008/10/10
+#agg.path.chunksize: 0 # 0 to disable; values in the range
+ # 10000 to 100000 can improve speed slightly
+ # and prevent an Agg rendering failure
+ # when plotting very large data sets,
+ # especially if they are very gappy.
+ # It may cause minor artifacts, though.
+ # A value of 20000 is probably a good
+ # starting point.
+
+
+## ***************************************************************************
+## * PATHS *
+## ***************************************************************************
+#path.simplify: True # When True, simplify paths by removing "invisible"
+ # points to reduce file size and increase rendering
+ # speed
+#path.simplify_threshold: 0.111111111111 # The threshold of similarity below
+ # which vertices will be removed in
+ # the simplification process.
+#path.snap: True # When True, rectilinear axis-aligned paths will be snapped
+ # to the nearest pixel when certain criteria are met.
+ # When False, paths will never be snapped.
+#path.sketch: None # May be None, or a 3-tuple of the form:
+ # (scale, length, randomness).
+ # - *scale* is the amplitude of the wiggle
+ # perpendicular to the line (in pixels).
+ # - *length* is the length of the wiggle along the
+ # line (in pixels).
+ # - *randomness* is the factor by which the length is
+ # randomly scaled.
+#path.effects:
+
+
+## ***************************************************************************
+## * SAVING FIGURES *
+## ***************************************************************************
+## The default savefig parameters can be different from the display parameters
+## e.g., you may want a higher resolution, or to make the figure
+## background white
+#savefig.dpi: figure # figure dots per inch or 'figure'
+#savefig.facecolor: auto # figure face color when saving
+#savefig.edgecolor: auto # figure edge color when saving
+#savefig.format: png # {png, ps, pdf, svg}
+#savefig.bbox: standard # {tight, standard}
+ # 'tight' is incompatible with pipe-based animation
+ # backends (e.g. 'ffmpeg') but will work with those
+ # based on temporary files (e.g. 'ffmpeg_file')
+#savefig.pad_inches: 0.1 # Padding to be used when bbox is set to 'tight'
+#savefig.directory: ~ # default directory in savefig dialog box,
+ # leave empty to always use current working directory
+#savefig.transparent: False # setting that controls whether figures are saved with a
+ # transparent background by default
+#savefig.orientation: portrait # Orientation of saved figure
+
+### tk backend params
+#tk.window_focus: False # Maintain shell focus for TkAgg
+
+### ps backend params
+#ps.papersize: letter # {auto, letter, legal, ledger, A0-A10, B0-B10}
+#ps.useafm: False # use of AFM fonts, results in small files
+#ps.usedistiller: False # {ghostscript, xpdf, None}
+ # Experimental: may produce smaller files.
+ # xpdf intended for production of publication quality files,
+ # but requires ghostscript, xpdf and ps2eps
+#ps.distiller.res: 6000 # dpi
+#ps.fonttype: 3 # Output Type 3 (Type3) or Type 42 (TrueType)
+
+### PDF backend params
+#pdf.compression: 6 # integer from 0 to 9
+ # 0 disables compression (good for debugging)
+#pdf.fonttype: 3 # Output Type 3 (Type3) or Type 42 (TrueType)
+#pdf.use14corefonts: False
+#pdf.inheritcolor: False
+
+### SVG backend params
+#svg.image_inline: True # Write raster image data directly into the SVG file
+#svg.fonttype: path # How to handle SVG fonts:
+ # path: Embed characters as paths -- supported
+ # by most SVG renderers
+ # None: Assume fonts are installed on the
+ # machine where the SVG will be viewed.
+#svg.hashsalt: None # If not None, use this string as hash salt instead of uuid4
+
+### pgf parameter
+## See https://matplotlib.org/tutorials/text/pgf.html for more information.
+#pgf.rcfonts: True
+#pgf.preamble: # See text.latex.preamble for documentation
+#pgf.texsystem: xelatex
+
+### docstring params
+#docstring.hardcopy: False # set this when you want to generate hardcopy docstring
+
+
+## ***************************************************************************
+## * INTERACTIVE KEYMAPS *
+## ***************************************************************************
+## Event keys to interact with figures/plots via keyboard.
+## See https://matplotlib.org/users/navigation_toolbar.html for more details on
+## interactive navigation. Customize these settings according to your needs.
+## Leave the field(s) empty if you don't need a key-map. (i.e., fullscreen : '')
+#keymap.fullscreen: f, ctrl+f # toggling
+#keymap.home: h, r, home # home or reset mnemonic
+#keymap.back: left, c, backspace, MouseButton.BACK # forward / backward keys
+#keymap.forward: right, v, MouseButton.FORWARD # for quick navigation
+#keymap.pan: p # pan mnemonic
+#keymap.zoom: o # zoom mnemonic
+#keymap.save: s, ctrl+s # saving current figure
+#keymap.help: f1 # display help about active tools
+#keymap.quit: ctrl+w, cmd+w, q # close the current figure
+#keymap.quit_all: # close all figures
+#keymap.grid: g # switching on/off major grids in current axes
+#keymap.grid_minor: G # switching on/off minor grids in current axes
+#keymap.yscale: l # toggle scaling of y-axes ('log'/'linear')
+#keymap.xscale: k, L # toggle scaling of x-axes ('log'/'linear')
+#keymap.copy: ctrl+c, cmd+c # copy figure to clipboard
+
+
+## ***************************************************************************
+## * ANIMATION *
+## ***************************************************************************
+#animation.html: none # How to display the animation as HTML in
+ # the IPython notebook:
+ # - 'html5' uses HTML5 video tag
+ # - 'jshtml' creates a JavaScript animation
+#animation.writer: ffmpeg # MovieWriter 'backend' to use
+#animation.codec: h264 # Codec to use for writing movie
+#animation.bitrate: -1 # Controls size/quality trade-off for movie.
+ # -1 implies let utility auto-determine
+#animation.frame_format: png # Controls frame format used by temp files
+#animation.ffmpeg_path: ffmpeg # Path to ffmpeg binary. Without full path
+ # $PATH is searched
+#animation.ffmpeg_args: # Additional arguments to pass to ffmpeg
+#animation.convert_path: convert # Path to ImageMagick's convert binary.
+ # On Windows use the full path since convert
+ # is also the name of a system tool.
+#animation.convert_args: # Additional arguments to pass to convert
+#animation.embed_limit: 20.0 # Limit, in MB, of size of base64 encoded
+ # animation in HTML (i.e. IPython notebook)
\ No newline at end of file
diff --git a/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py b/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py
new file mode 100644
index 0000000000..76f36dca29
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import joblib
+import matplotlib.pyplot as plt
+import xarray
+
+
+def plot_samples(netcdf_file, output_dir, n_samples):
+ """Plot multiple samples"""
+ root = xarray.open_dataset(netcdf_file)
+ ds = (
+ xarray.open_dataset(netcdf_file, group="prediction")
+ .merge(root)
+ .set_coords(["lat", "lon"])
+ )
+ truth = (
+ xarray.open_dataset(netcdf_file, group="truth")
+ .merge(root)
+ .set_coords(["lat", "lon"])
+ )
+ os.makedirs(output_dir, exist_ok=True)
+
+ # concatenate truth data and ensemble mean as an "ensemble" member for easy
+ # plotting
+ truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
+ ens_mean = (
+ ds.mean("ensemble")
+ .assign_coords(ensemble="ensemble_mean")
+ .expand_dims("ensemble")
+ )
+ # add [0, 1, 2, ...] to ensemble dim
+ ds["ensemble"] = [str(i) for i in range(ds.sizes["ensemble"])]
+ merged = xarray.concat([truth_expanded, ens_mean, ds], dim="ensemble")
+
+ # plot the variables in parallel
+ def plot(v):
+ print(v)
+        # the +2 accounts for the "truth" and "ensemble_mean" members prepended above
+        merged[v][: n_samples + 2, :].plot(row="time", col="ensemble")
+ plt.savefig(f"{output_dir}/{v}.png")
+
+ joblib.Parallel(n_jobs=8)(joblib.delayed(plot)(v) for v in merged)
+
+
+if __name__ == "__main__":
+ # Create the parser
+ parser = argparse.ArgumentParser()
+
+    # Add the input/output arguments
+ parser.add_argument("--netcdf_file", help="Path to the NetCDF file")
+ parser.add_argument("--output_dir", help="Path to the output directory")
+ # Add the optional argument
+ parser.add_argument("--n-samples", help="Number of samples", default=5, type=int)
+ # Parse the arguments
+ args = parser.parse_args()
+    plot_samples(args.netcdf_file, args.output_dir, args.n_samples)
diff --git a/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py b/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py
new file mode 100644
index 0000000000..9cf7dbccfd
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py
@@ -0,0 +1,201 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import cftime
+import matplotlib.pyplot as plt
+import netCDF4 as nc
+import numpy as np
+
+
+def pattern_correlation(x, y):
+ """Pattern correlation"""
+ mx = np.mean(x)
+ my = np.mean(y)
+ vx = np.mean((x - mx) ** 2)
+ vy = np.mean((y - my) ** 2)
+
+ a = np.mean((x - mx) * (y - my))
+ b = np.sqrt(vx * vy)
+
+ return a / b
+
+
+def plot_channels(group, time_idx: int):
+ """Plot channels"""
+ # weather sub-plot
+ num_channels = len(group.variables)
+ ncols = 4
+ fig, axs = plt.subplots(
+ nrows=(
+ num_channels // ncols
+ if num_channels % ncols == 0
+ else num_channels // ncols + 1
+ ),
+ ncols=ncols,
+ sharex=True,
+ sharey=True,
+ constrained_layout=True,
+ figsize=(15, 15),
+ )
+
+ for ch, ax in zip(sorted(group.variables), axs.flat):
+ # label row
+ x = group[ch][time_idx]
+ ax.set_title(ch)
+ ax.imshow(x)
+
+
+def channel_eq(a, b):
+ """Check if two channels are equal in variable and pressure."""
+ variable_equal = a["variable"] == b["variable"]
+ pressure_is_nan = np.isnan(a["pressure"]) and np.isnan(b["pressure"])
+ pressure_equal = a["pressure"] == b["pressure"]
+ return variable_equal and (pressure_equal or pressure_is_nan)
+
+
+def channel_repr(channel):
+ """Return a string representation of a channel with variable and pressure."""
+ v = channel["variable"]
+ pressure = channel["pressure"]
+ return f"{v}\n Pressure: {pressure}"
+
+
+def get_clim(output_channels, f):
+ """Get color limits (clim) for output channels based on prediction and truth data."""
+ colorlimits = {}
+ for ch in range(len(output_channels)):
+ channel = output_channels[ch]
+ y = f["prediction"][channel][:]
+ truth = f["truth"][channel][:]
+
+ vmin = min([y.min(), truth.min()])
+ vmax = max([y.max(), truth.max()])
+ colorlimits[channel] = (vmin, vmax)
+ return colorlimits
+
+
+def main(file, output_dir, sample):
+ """Plot single sample"""
+ os.makedirs(output_dir, exist_ok=True)
+ f = nc.Dataset(file, "r")
+
+ # for c in f.time:
+ output_channels = list(f["prediction"].variables)
+ v = f["time"]
+ times = cftime.num2date(v, units=v.units, calendar=v.calendar)
+
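+    # default panel plotter; overridden per channel inside the plotting loop below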
+ def plot_panel(ax, data, **kwargs):
+ return ax.pcolormesh(f["lon"], f["lat"], data, cmap="RdBu_r", **kwargs)
+
+ colorlimits = get_clim(output_channels, f)
+ for idx in range(len(times)):
+ print("idx", idx)
+ # weather sub-plot
+ fig, axs = plt.subplots(
+ nrows=len(output_channels),
+ ncols=3,
+ sharex=True,
+ sharey=True,
+ constrained_layout=True,
+ figsize=(12, 12),
+ )
+ row = axs[0]
+ row[0].set_title("Input")
+ row[1].set_title("Generated")
+ row[2].set_title("Truth")
+
+ for ch in range(len(output_channels)):
+ channel = output_channels[ch]
+ row = axs[ch]
+
+ # label row
+
+ y = f["prediction"][channel][sample, idx]
+ truth = f["truth"][channel][idx]
+
+ # search for input_channel
+ input_channels = list(f["input"].variables)
+ if channel in input_channels:
+ x = f["input"][channel][idx]
+ else:
+ x = None
+
+ vmin, vmax = colorlimits[channel]
+
+ def plot_panel(ax, data, **kwargs):
+ if channel == "maximum_radar_reflectivity":
+ return ax.pcolormesh(
+ f["lon"], f["lat"], data, cmap="magma", vmin=0, vmax=vmax
+ )
+ if channel == "temperature_2m":
+ return ax.pcolormesh(
+ f["lon"], f["lat"], data, cmap="magma", vmin=vmin, vmax=vmax
+ )
+ else:
+ if vmin < 0 < vmax:
+ bound = max(abs(vmin), abs(vmax))
+ vmin1 = -bound
+ vmax1 = bound
+ else:
+ vmin1 = vmin
+ vmax1 = vmax
+ return ax.pcolormesh(
+ f["lon"], f["lat"], data, cmap="RdBu_r", vmin=vmin1, vmax=vmax1
+ )
+
+ if x is not None:
+ plot_panel(row[0], x)
+ pc_x = pattern_correlation(x, truth)
+ label_x = pc_x
+ row[0].set_title(f"Input, Pattern correlation: {label_x:.2f}")
+
+ im = plot_panel(row[1], y)
+ plot_panel(row[2], truth)
+
+ cb = plt.colorbar(im, ax=row.tolist())
+ cb.set_label(channel)
+
+ pc_y = pattern_correlation(y, truth)
+ label_y = pc_y
+ row[1].set_title(f"Generated, Pattern correlation: {label_y:.2f}")
+
+ for ax in axs[-1]:
+ ax.set_xlabel("longitude [deg E]")
+
+ for ax in axs[:, 0]:
+ ax.set_ylabel("latitude [deg N]")
+
+ time = times[idx]
+ plt.suptitle(f"Time {time.isoformat()}")
+ plt.savefig(f"{output_dir}/{time.isoformat()}.sample.png")
+
+ plot_channels(f["input"], idx)
+ plt.savefig(f"{output_dir}/{time.isoformat()}.input.png")
+
+
+if __name__ == "__main__":
+ # Create the parser
+ parser = argparse.ArgumentParser()
+    # Add the input/output arguments
+ parser.add_argument("--netcdf_file", help="Path to the NetCDF file")
+ parser.add_argument("--output_dir", help="Path to the output directory")
+ # Add the optional argument
+ parser.add_argument("--sample", help="Sample to plot", default=0, type=int)
+ # Parse the arguments
+ args = parser.parse_args()
+ main(args.netcdf_file, args.output_dir, args.sample)
diff --git a/examples/generative/corrdiff_plus_plus/inference/power_spectra.py b/examples/generative/corrdiff_plus_plus/inference/power_spectra.py
new file mode 100644
index 0000000000..6d6bfc84c4
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/power_spectra.py
@@ -0,0 +1,243 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import typer
+import xarray
+from scipy.fft import irfft
+from scipy.signal import periodogram
+
+
+def open_data(file, group=False):
+ """
+ Opens a dataset from a NetCDF file.
+
+ Parameters:
+ file (str): Path to the NetCDF file.
+    group (str or bool, optional): Name of the NetCDF group to open. Default is False (root group).
+
+ Returns:
+ xarray.Dataset: An xarray dataset containing the data from the NetCDF file.
+ """
+ root = xarray.open_dataset(file)
+ root = root.set_coords(["lat", "lon"])
+ ds = xarray.open_dataset(file, group=group)
+ ds.coords.update(root.coords)
+ ds.attrs.update(root.attrs)
+
+ return ds
+
+
+def haversine(lat1, lon1, lat2, lon2):
+ """
+ Calculate the Haversine distance between two sets of latitude and longitude coordinates.
+
+ The Haversine formula calculates the shortest distance between two points on the
+ surface of a sphere (in this case, the Earth) given their latitude and longitude
+ coordinates.
+
+ Parameters:
+ lat1 (float): Latitude of the first point in degrees.
+ lon1 (float): Longitude of the first point in degrees.
+ lat2 (float): Latitude of the second point in degrees.
+ lon2 (float): Longitude of the second point in degrees.
+
+ Returns:
+ float: The Haversine distance between the two points in meters.
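+
+    Example (illustrative sanity check):
+        >>> round(haversine(0.0, 0.0, 0.0, 1.0))  # one degree of longitude at the equator
+        111195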
+ """
+ # Convert latitude and longitude from degrees to radians
+ lat1_rad = np.radians(lat1)
+ lon1_rad = np.radians(lon1)
+ lat2_rad = np.radians(lat2)
+ lon2_rad = np.radians(lon2)
+
+ # Earth radius in meters
+ earth_radius = 6371000 # Approximate value for the average Earth radius
+
+ # Calculate differences in latitude and longitude
+ dlat_rad = lat2_rad - lat1_rad
+ dlon_rad = lon2_rad - lon1_rad
+
+ # Haversine formula
+ a = (
+ np.sin(dlat_rad / 2) ** 2
+ + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon_rad / 2) ** 2
+ )
+ c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
+ distance_meters = earth_radius * c
+
+ return distance_meters
+
+
+def compute_power_spectrum(data, d):
+ """
+ Compute the power spectrum of a 2D data array using the Fast Fourier Transform (FFT).
+
+ The power spectrum represents the distribution of signal power as a function of frequency.
+
+ Parameters:
+ data (numpy.ndarray): 2D input data array.
+ d (float): Sampling interval (time between data points).
+
+ Returns:
+ tuple: A tuple containing the frequency values and the corresponding power spectrum.
+ - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum.
+ - power_spectrum (numpy.ndarray): Power spectrum of the input data.
+ """
+
+    # Compute the 1D FFT along the second-to-last dimension
+ fft_data = np.fft.fft(data, axis=-2)
+
+ # Compute the power spectrum by taking the absolute value and squaring
+ power_spectrum = np.abs(fft_data) ** 2
+
+ # Scale the power spectrum based on the sampling interval 'd'
+ power_spectrum /= data.shape[-1] * d
+ freqs = np.fft.fftfreq(data.shape[-1], d)
+
+ return freqs, power_spectrum
+
+
+def power_spectra_to_acf(f, pw):
+ """
+ Convert a one-sided power spectrum to an autocorrelation function.
+
+ Args:
+ f (numpy.ndarray): Frequencies.
+ pw (numpy.ndarray): Power spectral density in units of V^2/Hz.
+
+ Returns:
+ numpy.ndarray: Autocorrelation function (ACF).
+ """
+ pw = pw.copy()
+ pw[0] = 0
+    # magic factor 4 comes from periodogram/irfft conventions:
+    # a factor of 2 comes from the periodogram being one-sided; the origin of the
+    # remaining factor of 2 is not fully understood, but this normalization
+    # ensures the acf is 1 at r=0
+ sig2 = np.sum(pw * f[1]) * 4
+ acf = irfft(pw) / sig2
+ return acf
+
+
+def average_power_spectrum(data, d):
+ """
+ Compute the average power spectrum of a 2D data array.
+
+ This function calculates the power spectrum for each row of the input data and
+ then averages them to obtain the overall power spectrum.
+ The power spectrum represents the distribution of signal power as a function of frequency.
+
+ Parameters:
+ data (numpy.ndarray): 2D input data array.
+ d (float): Sampling interval (time between data points).
+
+ Returns:
+ tuple: A tuple containing the frequency values and the average power spectrum.
+ - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum.
+ - power_spectra (numpy.ndarray): Average power spectrum of the input data.
+ """
+ # Compute the power spectrum along the second dimension for each row
+ freqs, power_spectra = periodogram(data, fs=1 / d, axis=-1)
+
+ # Average along the first dimension
+ while power_spectra.ndim > 1:
+ power_spectra = power_spectra.mean(axis=0)
+
+ return freqs, power_spectra
+
+
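+# Note: with the roughly 2 km grid used in main() below (d = 2), `freqs` is in cycles
+# per km; the sampling rate is fs = 1/d = 0.5 samples/km, so the Nyquist frequency is
+# 0.25 1/km, i.e. the smallest resolvable wavelength is 4 km.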
+def main(file, output):
+ """
+ Generate and save multiple power spectrum plots from input data.
+
+ Parameters:
+ file (str): Path to the input data file.
+ output (str): Directory where the generated plots will be saved.
+
+ This function loads and processes various datasets from the input file,
+ calculates their power spectra, and generates and saves multiple power spectrum plots.
+ The plots include kinetic energy, temperature, and reflectivity power spectra.
+ """
+
+ def savefig(name):
+ path = os.path.join(output, name + ".png")
+ plt.savefig(path)
+
+ samples = {}
+ samples["prediction"] = open_data(file, group="prediction")
+ samples["prediction_mean"] = samples["prediction"].mean("ensemble")
+ samples["truth"] = open_data(file, group="truth")
+ samples["ERA5"] = open_data(file, group="input")
+
+ prediction = samples["prediction"]
+ lat = prediction.lat
+ lon = prediction.lon
+
+ dx = haversine(lat[0, 0], lon[0, 0], lat[1, 0], lon[1, 0])
+ dy = haversine(lat[0, 0], lon[0, 0], lat[0, 1], lon[0, 1])
+    print(f"grid spacing [m]: dx={float(dx):.1f}, dy={float(dy):.1f}")
+    # the approximate resolution is dx = dy = 2000 m
+
+    # grid spacing in km, used as the sampling interval for the spectra
+    d = 2
+
+ # Plot the power spectrum
+ for name, data in samples.items():
+ freqs, spec_x = average_power_spectrum(data.eastward_wind_10m, d=d)
+ _, spec_y = average_power_spectrum(data.northward_wind_10m, d=d)
+ spec = spec_x + spec_y
+ plt.loglog(freqs, spec, label=name)
+ plt.xlabel("Frequency (1/km)")
+ plt.ylabel("Power Spectrum")
+ plt.ylim(bottom=1e-1)
+ plt.title("Kinetic Energy power spectra")
+ plt.grid()
+ plt.legend()
+ savefig("ke-spectra")
+
+ plt.figure()
+ for name, data in samples.items():
+ freqs, spec = average_power_spectrum(data.temperature_2m, d=d)
+ plt.loglog(freqs, spec, label=name)
+ plt.xlabel("Frequency (1/km)")
+ plt.ylabel("Power Spectrum")
+ plt.ylim(bottom=1e-1)
+ plt.title("T2M Power spectra")
+ plt.grid()
+ plt.legend()
+ savefig("t2m-spectra")
+
+ plt.figure()
+ for name, data in samples.items():
+ try:
+ freqs, spec = average_power_spectrum(data.maximum_radar_reflectivity, d=d)
+ except AttributeError:
+ continue
+ plt.loglog(freqs, spec, label=name)
+ plt.xlabel("Frequency (1/km)")
+ plt.ylabel("Power Spectrum")
+ plt.ylim(bottom=1e-1)
+ plt.title("Reflectivity Power spectra")
+ plt.grid()
+ plt.legend()
+ savefig("reflectivity-spectra")
+
+
+if __name__ == "__main__":
+ typer.run(main)
diff --git a/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py b/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py
new file mode 100644
index 0000000000..64164d44aa
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import netCDF4 as nc
+
+# Open the NetCDF file
+file_path = "image_outdir_0_score.nc" # Replace with the path to your NetCDF file
+dataset = nc.Dataset(file_path, "r") # 'r' stands for read mode
+
+# Access variables and attributes
+print("Variables:")
+for var_name, var in dataset.variables.items():
+ print(f"{var_name}: {var[:]}") # Access the data for each variable
+
+print("\nGlobal attributes:")
+for attr_name in dataset.ncattrs():
+ print(f"{attr_name}: {getattr(dataset, attr_name)}") # Access global attributes
+
+# Close the NetCDF file when done
+dataset.close()
diff --git a/examples/generative/corrdiff_plus_plus/train.py b/examples/generative/corrdiff_plus_plus/train.py
new file mode 100644
index 0000000000..1386526d2a
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/train.py
@@ -0,0 +1,496 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import sys
+import time
+
+import hydra
+import psutil
+import torch
+from hydra.utils import to_absolute_path
+from omegaconf import DictConfig, OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.tensorboard import SummaryWriter
+from physicsnemo import Module
+from physicsnemo.models.diffusion import (
+ SongUNet,
+ EDMPrecondSR,
+ SFMPrecondSR,
+ SFMPrecondEmpty,
+)
+from physicsnemo.distributed import DistributedManager
+from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
+from physicsnemo.metrics.diffusion import (
+ RegressionLoss,
+ ResLoss,
+ SFMLoss,
+ # SFMLossSigmaPerChannel,
+ SFMEncoderLoss,
+)
+from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
+
+# Load utilities from the corrdiff example; make the corrdiff path absolute to avoid issues
+# sys.path.append(os.path.join(os.path.dirname(__file__), "../corrdiff"))
+from datasets.dataset import init_train_valid_datasets_from_config
+from helpers.train_helpers import (
+ set_patch_shape,
+ set_seed,
+ configure_cuda_for_consistent_precision,
+ compute_num_accumulation_rounds,
+ handle_and_clip_gradients,
+ is_time_for_periodic_task,
+)
+from helpers.sfm_utils import get_encoder
+
+
+# Train the CorrDiff++ model using the configuration in "conf/config_training.yaml"
+@hydra.main(version_base="1.2", config_path="conf", config_name="config_training")
+def main(cfg: DictConfig) -> None:
+ # Initialize distributed environment for training
+ DistributedManager.initialize()
+ dist = DistributedManager()
+
+ # Initialize loggers
+ if dist.rank == 0:
+ writer = SummaryWriter(log_dir="tensorboard")
+ logger = PythonLogger("main") # General python logger
+ logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger
+
+ # Resolve and parse configs
+ OmegaConf.resolve(cfg)
+ dataset_cfg = OmegaConf.to_container(cfg.dataset)
+ if hasattr(cfg, "validation"):
+ train_test_split = True
+ validation_dataset_cfg = OmegaConf.to_container(cfg.validation)
+ else:
+ train_test_split = False
+ validation_dataset_cfg = None
+
+ fp_optimizations = cfg.training.perf.fp_optimizations
+ fp16 = fp_optimizations == "fp16"
+ enable_amp = fp_optimizations.startswith("amp")
+ amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
+
+ logger.info(f"Saving the outputs in {os.getcwd()}")
+ checkpoint_dir = os.path.join(
+ cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}"
+ )
+ if cfg.training.hp.batch_size_per_gpu == "auto":
+ cfg.training.hp.batch_size_per_gpu = (
+ cfg.training.hp.total_batch_size // dist.world_size
+ )
+
+ # Set seeds and configure CUDA and cuDNN settings to ensure consistent precision
+ set_seed(dist.rank)
+ configure_cuda_for_consistent_precision()
+
+ # Instantiate the dataset
+ data_loader_kwargs = {
+ "pin_memory": True,
+ "num_workers": cfg.training.perf.dataloader_workers,
+ "prefetch_factor": cfg.training.perf.dataloader_workers,
+ }
+ (
+ dataset,
+ dataset_iterator,
+ validation_dataset,
+ validation_dataset_iterator,
+ ) = init_train_valid_datasets_from_config(
+ dataset_cfg,
+ data_loader_kwargs,
+ batch_size=cfg.training.hp.batch_size_per_gpu,
+ seed=0,
+ validation_dataset_cfg=validation_dataset_cfg,
+ train_test_split=train_test_split,
+ )
+
+ # Parse image configuration & update model args
+ dataset_channels = len(dataset.input_channels())
+ img_in_channels = dataset_channels
+ img_shape = dataset.image_shape()
+ img_out_channels = len(dataset.output_channels())
+ patch_shape = (None, None)
+
+ # Instantiate the model and move to device.
+ if cfg.model.name not in (
+ "sfm_encoder",
+ "sfm",
+ "sfm_two_stage",
+ ):
+ raise ValueError("Invalid model")
+ model_args = { # default parameters for all networks
+ "img_out_channels": img_out_channels,
+ "img_resolution": list(img_shape),
+ "use_fp16": fp16,
+ }
+ ## remaining defaults are what we want
+ standard_model_cfgs = { # default parameters for different network types
+ "sfm": {
+ "gridtype": "sinusoidal",
+ "N_grid_channels": 4,
+ },
+ "sfm_two_stage": {
+ "gridtype": "sinusoidal",
+ "N_grid_channels": 4,
+ },
+ "sfm_encoder": {}, # empty preconditioner
+ }
+
+ model_args.update(standard_model_cfgs[cfg.model.name])
+ if hasattr(cfg.model, "model_args"): # override defaults from config file
+ model_args.update(OmegaConf.to_container(cfg.model.model_args))
+
+ if cfg.model.name == "sfm_encoder":
+ # should this be set to no_grad?
+ denoiser_net = SFMPrecondEmpty()
+ else: # sfm or sfm_two_stage
+ denoiser_net = SFMPrecondSR(
+ img_in_channels=img_in_channels + model_args["N_grid_channels"],
+ **model_args,
+ )
+
+ denoiser_net.train().requires_grad_(True).to(dist.device)
+ denoiser_ema = copy.deepcopy(denoiser_net).eval().requires_grad_(False)
+ ema_halflife_nimg = int(cfg.training.hp.ema * 1000000)
+ if hasattr(cfg.training.hp, "ema_rampup_ratio"):
+ ema_rampup_ratio = float(cfg.training.hp.ema_rampup_ratio)
+ else:
+ ema_rampup_ratio = 0.5
+
+ # Create or load the encoder:
+ if cfg.model.name in ["sfm", "sfm_encoder"]:
+ encoder_net = get_encoder(cfg)
+ encoder_net.train().requires_grad_(True).to(dist.device)
+ logger0.success("Constructed encoder network succesfully")
+ else: # "sfm_two_stage"
+ if not hasattr(cfg.training.io, "encoder_checkpoint_path"):
+ raise KeyError(
+ "Need to provide encoder_checkpoint_path when using sfm_two_stage"
+ )
+ encoder_checkpoint_path = to_absolute_path(
+ cfg.training.io.encoder_checkpoint_path
+ )
+ if not os.path.exists(encoder_checkpoint_path):
+ raise FileNotFoundError(
+ f"Expected this encoder checkpoint but not found: {encoder_checkpoint_path}"
+ )
+ encoder_net = Module.from_checkpoint(encoder_checkpoint_path)
+ encoder_net.eval().requires_grad_(False).to(dist.device)
+ logger0.success("Loaded the pre-trained encoder network")
+
+ # Instantiate the loss function(s)
+ if cfg.model.name in ("sfm", "sfm_two_stage"):
+ loss_fn = SFMLoss(
+ encoder_loss_type=cfg.model.encoder_loss_type,
+ encoder_loss_weight=cfg.model.encoder_loss_weight,
+ sigma_min=cfg.model.sigma_min,
+ )
+ # with sfm the encoder and diffusion model are trained together
+ if cfg.model.name == "sfm":
+ loss_fn_encoder = SFMEncoderLoss(encoder_loss_type="l2")
+ elif cfg.model.name == "sfm_encoder":
+ loss_fn = SFMEncoderLoss(
+ encoder_loss_type=cfg.model.encoder_loss_type,
+ )
+ else:
+ raise NotImplementedError(f"Model {cfg.model.name} not supported.")
+
+ # Instantiate the optimizer
+ if cfg.model.name == "sfm_two_stage":
+ params = denoiser_net.parameters()
+ else:
+ params = list(denoiser_net.parameters()) + list(encoder_net.parameters())
+
+ optimizer = torch.optim.Adam(
+ params=params, lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8
+ )
+
+ # Enable distributed data parallel if applicable
+ if dist.world_size > 1:
+ ddp_denoiser_net = DistributedDataParallel(
+ denoiser_net,
+ device_ids=[dist.local_rank],
+ broadcast_buffers=True,
+ output_device=dist.device,
+ find_unused_parameters=dist.find_unused_parameters,
+ )
+ if cfg.model.name != "sfm_two_stage":
+ encoder_net = DistributedDataParallel(
+ encoder_net,
+ device_ids=[dist.local_rank],
+ broadcast_buffers=True,
+ output_device=dist.device,
+ find_unused_parameters=dist.find_unused_parameters,
+ )
+ else:
+ # for convenience when updating the denoiser sigma
+ ddp_denoiser_net = denoiser_net
+
+ # Record the current time to measure the duration of subsequent operations.
+ start_time = time.time()
+
+ # Compute the number of required gradient accumulation rounds
+ # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size
+ batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds(
+ cfg.training.hp.total_batch_size,
+ cfg.training.hp.batch_size_per_gpu,
+ dist.world_size,
+ )
+ batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
+ logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")
+
+ ## Resume training from previous checkpoints if exists
+ ### TODO needs to be redone, need to store model + encoder + optimizer
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+ try:
+ cur_nimg = load_checkpoint(
+ path=checkpoint_dir,
+ models=[denoiser_net, encoder_net],
+ optimizer=optimizer,
+ device=dist.device,
+ )
+    except Exception:
+        # no usable checkpoint found: start training from scratch
+        cur_nimg = 0
+
+ ############################################################################
+ # MAIN TRAINING LOOP #
+ ############################################################################
+
+ logger0.info(f"Training for {cfg.training.hp.training_duration} images...")
+ done = False
+
+ # init variables to monitor running mean of average loss since last periodic
+ average_loss_running_mean = 0
+ n_average_loss_running_mean = 1
+
+ while not done:
+ tick_start_nimg = cur_nimg
+ tick_start_time = time.time()
+ # Compute & accumulate gradients
+ optimizer.zero_grad(set_to_none=True)
+ loss_accum = 0
+ for _ in range(num_accumulation_rounds):
+ img_clean, img_lr, labels = next(dataset_iterator)
+ img_clean = img_clean.to(dist.device).to(torch.float32).contiguous()
+ img_lr = img_lr.to(dist.device).to(torch.float32).contiguous()
+ with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
+ loss = loss_fn(
+ denoiser_net=ddp_denoiser_net,
+ encoder_net=encoder_net,
+ img_clean=img_clean,
+ img_lr=img_lr,
+ )
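+            # normalize by the per-GPU batch size and by the number of accumulation
+            # rounds so the accumulated gradient matches a single large batch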
+ loss = loss.sum() / batch_size_per_gpu
+ loss_accum += loss / num_accumulation_rounds
+ loss.backward()
+
+ loss_sum = torch.tensor([loss_accum], device=dist.device)
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
+ average_loss = (loss_sum / dist.world_size).cpu().item()
+
+ # update running mean of average loss since last periodic task
+ average_loss_running_mean += (
+ average_loss - average_loss_running_mean
+ ) / n_average_loss_running_mean
+ n_average_loss_running_mean += 1
+
+ if dist.rank == 0:
+ writer.add_scalar("training_loss", average_loss, cur_nimg)
+ writer.add_scalar(
+ "training_loss_running_mean", average_loss_running_mean, cur_nimg
+ )
+
+ ptt = is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.print_progress_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ rank_0_only=True,
+ )
+ if ptt:
+ # reset running mean of average loss
+ average_loss_running_mean = 0
+ n_average_loss_running_mean = 1
+
+ # Update weights.
+ lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate
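+    # linear warm-up of the learning rate over the first lr_rampup samples, then a
+    # step decay by lr_decay every 5e6 samples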
+ for g in optimizer.param_groups:
+ if lr_rampup > 0:
+ g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1)
+ if cur_nimg >= lr_rampup:
+ g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6)
+ current_lr = g["lr"]
+ if dist.rank == 0:
+ writer.add_scalar("learning_rate", current_lr, cur_nimg)
+
+ # clear any nans from the denoiser
+ for param in denoiser_net.parameters():
+ if param.grad is not None:
+ torch.nan_to_num(
+ param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad
+ )
+ handle_and_clip_gradients(
+ denoiser_net, grad_clip_threshold=cfg.training.hp.grad_clip_threshold
+ )
+ handle_and_clip_gradients(
+ encoder_net, grad_clip_threshold=cfg.training.hp.grad_clip_threshold
+ )
+ optimizer.step()
+
+ # Update EMA.
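+    # ema_beta is chosen so that the EMA's memory of past weights halves every
+    # ema_halflife_nimg training samples (with an optional ramp-up early in training)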
+ if ema_rampup_ratio is not None:
+ ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
+ ema_beta = 0.5 ** (
+ cfg.training.hp.total_batch_size / max(ema_halflife_nimg, 1e-8)
+ )
+ for p_ema, p_net in zip(denoiser_ema.parameters(), denoiser_net.parameters()):
+ p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
+
+ cur_nimg += cfg.training.hp.total_batch_size
+ done = cur_nimg >= cfg.training.hp.training_duration
+
+ # Validation
+ if validation_dataset_iterator is not None:
+ valid_loss_accum = 0
+ if is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.validation_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ ):
+ rmse_encoder_valid_accum_mean = 0
+ with torch.no_grad():
+ for _ in range(cfg.training.io.validation_steps):
+ img_clean_valid, img_lr_valid, labels_valid = next(
+ validation_dataset_iterator
+ )
+
+ img_clean_valid = (
+ img_clean_valid.to(dist.device)
+ .to(torch.float32)
+ .contiguous()
+ )
+ img_lr_valid = (
+ img_lr_valid.to(dist.device).to(torch.float32).contiguous()
+ )
+ loss_valid = loss_fn(
+ denoiser_net=ddp_denoiser_net,
+ encoder_net=encoder_net,
+ img_clean=img_clean_valid,
+ img_lr=img_lr_valid,
+ )
+ loss_valid = (
+ (loss_valid.sum() / batch_size_per_gpu).cpu().item()
+ )
+ valid_loss_accum += (
+ loss_valid / cfg.training.io.validation_steps
+ )
+
+ if cfg.model.name == "sfm":
+ rmse_encoder_valid = loss_fn_encoder(
+ denoiser_net=ddp_denoiser_net,
+ encoder_net=encoder_net,
+ img_clean=img_clean_valid,
+ img_lr=img_lr_valid,
+ )
+ rmse_encoder_valid_accum_mean += (
+ rmse_encoder_valid.mean((0, 2, 3))
+ / cfg.training.io.validation_steps
+ )
+
+ valid_loss_sum = torch.tensor(
+ [valid_loss_accum], device=dist.device
+ )
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(
+ valid_loss_sum, op=torch.distributed.ReduceOp.SUM
+ )
+ average_valid_loss = valid_loss_sum / dist.world_size
+ if dist.rank == 0:
+ writer.add_scalar(
+ "validation_loss", average_valid_loss, cur_nimg
+ )
+
+ if dist.rank == 0:
+ if (
+ cfg.model.name == "sfm"
+ and cfg.model.model_args["sigma_max"]["learnable"]
+ ):
+ denoiser_net.update_sigma_max(rmse_encoder_valid_accum_mean)
+
+ if is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.print_progress_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ rank_0_only=True,
+ ):
+ # Print stats if we crossed the printing threshold with this batch
+ tick_end_time = time.time()
+ fields = []
+ fields += [f"samples {cur_nimg:<9.1f}"]
+ fields += [f"training_loss {average_loss:<7.2f}"]
+ fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"]
+ fields += [f"learning_rate {current_lr:<7.8f}"]
+ fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"]
+ fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"]
+ fields += [
+ f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}"
+ ]
+ fields += [
+ f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
+ ]
+ fields += [
+ f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}"
+ ]
+ fields += [
+ f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}"
+ ]
+ logger0.info(" ".join(fields))
+ torch.cuda.reset_peak_memory_stats()
+
+ # Save checkpoints
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+ if is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.save_checkpoint_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ rank_0_only=True,
+ ):
+ if cfg.model.name in ["sfm", "sfm_two_stage"]:
+ save_list = [denoiser_net, encoder_net]
+ else:
+ save_list = [encoder_net]
+ save_checkpoint(
+ path=checkpoint_dir,
+ models=save_list,
+ optimizer=optimizer,
+ epoch=cur_nimg,
+ )
+
+ # Done.
+ logger0.info("Training Completed.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/physicsnemo/launch/logging/wandb.py b/physicsnemo/launch/logging/wandb.py
index 9cb4879c8e..24e5a38974 100644
--- a/physicsnemo/launch/logging/wandb.py
+++ b/physicsnemo/launch/logging/wandb.py
@@ -23,9 +23,8 @@
from typing import Literal
import wandb
-from wandb import AlertLevel
-
from physicsnemo.distributed import DistributedManager
+from wandb import AlertLevel
from .utils import create_ddp_group_tag
diff --git a/physicsnemo/metrics/diffusion/__init__.py b/physicsnemo/metrics/diffusion/__init__.py
index 8673ce8eb2..d96c11deb0 100644
--- a/physicsnemo/metrics/diffusion/__init__.py
+++ b/physicsnemo/metrics/diffusion/__init__.py
@@ -25,3 +25,4 @@
VELoss_dfsr,
VPLoss,
)
+from .sfm_loss import SFMEncoderLoss, SFMLoss
diff --git a/physicsnemo/metrics/diffusion/sfm_loss.py b/physicsnemo/metrics/diffusion/sfm_loss.py
new file mode 100644
index 0000000000..7eb83035f7
--- /dev/null
+++ b/physicsnemo/metrics/diffusion/sfm_loss.py
@@ -0,0 +1,251 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+
+
+class SFMLoss:
+ """
+    Loss function corresponding to Stochastic Flow Matching (SFM)
+
+ Parameters
+ ----------
+ encoder_loss_type: str
+ Type of loss to use ["l1", "l2", None]
+ encoder_loss_weight: float
+        Regularizer loss weight, by default 0.1.
+ sigma_min: Union[List[float], float]
+ Minimum value of noise sigma, default 2e-3
+ Protects against values near zero that result in loss explosion.
+ sigma_data: float
+ EDM weighting, default 0.5
+ """
+
+ def __init__(
+ self,
+ encoder_loss_type: str = "l2",
+ encoder_loss_weight: float = 0.1,
+ sigma_min: Union[List[float], float] = 0.002,
+ sigma_data: float = 0.5,
+ ):
+ """
+        Loss function corresponding to Stochastic Flow Matching (SFM)
+
+ Parameters
+ ----------
+ encoder_loss_type: str, optional
+ Type of loss to use ["l1", "l2", None], defaults to 'l2'
+ encoder_loss_weight: float, optional
+            Regularizer loss weight, by default 0.1.
+ sigma_min: Union[List[float], float], optional
+ Minimum value of noise sigma, default 2e-3
+ Protects against values near zero that result in loss explosion.
+ sigma_data: float, optional
+ EDM weighting, default 0.5
+ """
+ self.encoder_loss_type = encoder_loss_type
+ self.encoder_loss_weight = encoder_loss_weight
+ self.sigma_min = sigma_min
+ self.sigma_data = sigma_data
+
+ if encoder_loss_type not in ["l1", "l2", None]:
+ raise ValueError(
+ f"encoder_loss_type should be one of ['l1', 'l2', None] not {encoder_loss_type}"
+ )
+
+ def __call__(
+ self,
+ denoiser_net: torch.nn.Module,
+ encoder_net: torch.nn.Module,
+ img_clean: torch.Tensor,
+ img_lr: torch.Tensor,
+ ):
+ """
+        Calculate the loss corresponding to stochastic flow matching
+
+ Parameters
+ ----------
+        denoiser_net: torch.nn.Module
+            The denoiser network making the predictions
+        encoder_net: torch.nn.Module
+            The encoder network making the predictions
+ img_clean: torch.Tensor
+ Input images (high resolution) to the neural network.
+ img_lr: torch.Tensor
+ Input images (low resolution) to the neural network.
+
+        Returns
+ -------
+ torch.Tensor
+ A tensor representing the combined loss calculated based on the flow matching
+ encoder and denoiser networks
+ """
+        # retrieve the (possibly learnable) per-channel sigma_max from the denoiser
+ if isinstance(denoiser_net, torch.nn.parallel.DistributedDataParallel):
+ sigma_max_per_channel = denoiser_net.module.get_sigma_max().to(
+ device=img_clean.device
+ )
+ else:
+ sigma_max_per_channel = denoiser_net.get_sigma_max().to(
+ device=img_clean.device
+ )
+
+ # clamp to min value
+ if not isinstance(self.sigma_min, float) and len(self.sigma_min) > 1:
+ sigma_max_per_channel = torch.maximum(
+ sigma_max_per_channel,
+ torch.tensor(self.sigma_min, device=img_clean.device),
+ )
+ # Normalize from 0 to 1
+ sigma_max = 1.0
+ else:
+ # just use the first value, ignore the rest
+ sigma_max = torch.maximum(
+ sigma_max_per_channel,
+ torch.tensor(self.sigma_min, device=img_clean.device),
+ )[0]
+
+ rnd_uniform = torch.rand([img_clean.shape[0], 1, 1, 1], device=img_clean.device)
+
+ sampled_sigma = rnd_uniform * sigma_max
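+        # EDM-style loss weight: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2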
+ weight = (sampled_sigma**2 + self.sigma_data**2) / (
+ sampled_sigma * self.sigma_data
+ ) ** 2
+
+        # concatenate target and conditioning inputs for conditional generation
+ x_tot = torch.cat((img_clean, img_lr), dim=1)
+
+ x_1 = x_tot[:, : img_clean.shape[1], :, :] # x_1 - target
+ x_low = x_tot[:, img_clean.shape[1] :, :, :] # x_low - upsampled ERA5
+
+ # encode the low resolution data x_low to x_0
+ # check same if encoder_net is in distributed data parallel
+
+ if isinstance(encoder_net, SongUNetPosEmbd) or (
+ isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+ and isinstance(encoder_net.module, SongUNetPosEmbd)
+ ):
+ x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+ else:
+ x_0 = encoder_net(x_low)
+
+        # convert sigma to time: sampled_sigma = (1 - t) * sigma_max, so time = 1
+        # corresponds to sigma = 0 (the target x_1) and time -> 0 to sigma = sigma_max
+        # (the encoder output x_0)
+        time = 1 - sampled_sigma / sigma_max
+
+ # we don't subtract x_0, this will be done in the sampler
+ x_t = ((1 - time) * x_0) + (time * x_1)
+ if not isinstance(self.sigma_min, float) and len(self.sigma_min) > 1:
+ sigma_t = (
+ sigma_max_per_channel.unsqueeze(0).unsqueeze(2).unsqueeze(2)
+ * sampled_sigma
+ )
+ else:
+ sigma_t = sampled_sigma
+ x_t_noisy = x_t + torch.randn_like(x_0) * sigma_t
+
+ D_x_t = denoiser_net(
+ x=x_t_noisy,
+ sigma=sampled_sigma,
+ condition=x_low,
+ )
+ # time_weight = lambda t: 1/(1 - torch.clamp(t, 0.9))
+ # time_weight = lambda t: 1
+ def time_weight(t):
+ return 1
+
+ sfm_loss = weight * ((time_weight(time) * (D_x_t - x_1)) ** 2)
+
+ if self.encoder_loss_type == "l1":
+ encoder_loss = F.l1_loss(x_1, x_0, reduction="none")
+ weighted_encoder_loss = self.encoder_loss_weight * encoder_loss
+ elif self.encoder_loss_type == "l2":
+ encoder_loss = F.mse_loss(x_1, x_0, reduction="none")
+ weighted_encoder_loss = self.encoder_loss_weight * encoder_loss
+ elif self.encoder_loss_type is None:
+ encoder_loss = torch.tensor(
+ 0.0
+ ) # This covers the case where there is no encoder_loss
+ weighted_encoder_loss = torch.tensor(0.0)
+
+ return sfm_loss + weighted_encoder_loss
+
+
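+# Minimal usage sketch (illustrative; mirrors how train.py in this example calls the loss):
+#   loss_fn = SFMLoss(encoder_loss_type="l2", encoder_loss_weight=0.1, sigma_min=0.002)
+#   loss = loss_fn(denoiser_net=denoiser, encoder_net=encoder,
+#                  img_clean=hr_batch, img_lr=lr_batch)  # per-element loss tensor
+#   loss.sum().backward()
+
+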
+class SFMEncoderLoss:
+ """
+    Loss function corresponding to Stochastic Flow Matching (SFM) for the encoder portion
+
+ Parameters
+ ----------
+    encoder_loss_type: str
+        Type of loss to use ["l1", "l2"]
+ """
+
+ def __init__(self, encoder_loss_type: str = "l2", **kwargs):
+ if encoder_loss_type not in ["l1", "l2"]:
+ raise ValueError(
+ f"encoder_loss_type should be either l1 or l2 not {encoder_loss_type}"
+ )
+ self.encoder_loss_type = encoder_loss_type
+
+ def __call__(
+ self,
+ denoiser_net: torch.nn.Module,
+ encoder_net: torch.nn.Module,
+ img_clean: torch.Tensor,
+ img_lr: torch.Tensor,
+ ):
+ """
+        Calculate the loss for the encoder used in stochastic flow matching
+
+ Parameters
+ ----------
+        denoiser_net: torch.nn.Module
+            The denoiser network (unused here; accepted for API compatibility with SFMLoss)
+        encoder_net: torch.nn.Module
+            The encoder network making the predictions
+ img_clean: torch.Tensor
+ Input images (high resolution) to the neural network.
+ img_lr: torch.Tensor
+ Input images (low resolution) to the neural network.
+
+ Returns
+ -------
+ torch.Tensor
+ A tensor representing the loss calculated based on the encoder's
+ predictions
+ """
+ x_1 = img_clean
+ x_low = img_lr
+
+ if isinstance(encoder_net, SongUNetPosEmbd) or (
+ isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+ and isinstance(encoder_net.module, SongUNetPosEmbd)
+ ):
+ x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+ else:
+ x_0 = encoder_net(x_low)
+
+ if self.encoder_loss_type == "l1":
+ encoder_loss = F.l1_loss(x_1, x_0, reduction="none")
+ elif self.encoder_loss_type == "l2":
+ encoder_loss = F.mse_loss(x_1, x_0, reduction="none")
+
+ return encoder_loss
diff --git a/physicsnemo/models/diffusion/__init__.py b/physicsnemo/models/diffusion/__init__.py
index 3984bffd42..64349700f7 100644
--- a/physicsnemo/models/diffusion/__init__.py
+++ b/physicsnemo/models/diffusion/__init__.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
+
+from .encoders import Conv2dSerializable
from .utils import weight_init
from .layers import (
AttentionOp,
@@ -36,3 +38,4 @@
VEPrecond_dfsr_cond,
VEPrecond_dfsr,
)
+from .sfm_preconditioning import SFMPrecondSR, SFMPrecondEmpty
diff --git a/physicsnemo/models/diffusion/encoders.py b/physicsnemo/models/diffusion/encoders.py
new file mode 100644
index 0000000000..001728fdc0
--- /dev/null
+++ b/physicsnemo/models/diffusion/encoders.py
@@ -0,0 +1,78 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+
+import nvtx
+
+import physicsnemo.models.diffusion as diffusion
+from physicsnemo.models.meta import ModelMetaData
+from physicsnemo.models.module import Module
+
+
+@dataclass
+class MetaData(ModelMetaData):
+ name: str = "Conv2dSerializable"
+ # Optimization
+ jit: bool = True
+ cuda_graphs: bool = True
+ amp_cpu: bool = True
+ amp_gpu: bool = True
+ torch_fx: bool = False
+ # Data type
+ bf16: bool = True
+ # Inference
+ onnx: bool = False
+ # Physics informed
+ func_torch: bool = False
+ auto_grad: bool = True
+
+
+class Conv2dSerializable(Module):
+ """
+ A serializable version of a 2d convolution
+
+ Parameters
+ ----------
+ in_channels: int
+ Number of input channels
+ out_channels: int
+ Number of output channels
+ kernel_size: int
+ Size of the convolution kernel
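+
+ Example
+ -------
+ A minimal sketch; channel counts and input size are illustrative
+ assumptions only:
+
+ >>> import torch
+ >>> conv = Conv2dSerializable(in_channels=8, out_channels=4, kernel_size=3)
+ >>> x = torch.zeros(1, 8, 16, 16)
+ >>> conv(x).shape
+ torch.Size([1, 4, 16, 16])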
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ ):
+ super().__init__(meta=MetaData())
+ self.out_channels = out_channels
+ self.in_channels = in_channels
+ self.kernel_size = kernel_size
+ self.net = diffusion.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ )
+
+ @nvtx.annotate(message="Conv2dSerializable", color="blue")
+ def forward(self, x):
+ """forward pass"""
+ return self.net(x)
diff --git a/physicsnemo/models/diffusion/sfm_preconditioning.py b/physicsnemo/models/diffusion/sfm_preconditioning.py
new file mode 100644
index 0000000000..3143b7e9bb
--- /dev/null
+++ b/physicsnemo/models/diffusion/sfm_preconditioning.py
@@ -0,0 +1,239 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import importlib
+from dataclasses import dataclass
+from typing import List, Union
+
+import nvtx
+import torch
+
+from physicsnemo.models.diffusion import ( # noqa: F401 for globals
+ Conv2dSerializable,
+ DhariwalUNet,
+ SongUNet,
+)
+from physicsnemo.models.meta import ModelMetaData
+from physicsnemo.models.module import Module
+
+network_module = importlib.import_module("physicsnemo.models.diffusion")
+
+
+@dataclass
+class SFMPrecondSRMetaData(ModelMetaData):
+ """EDMPrecondSR meta data"""
+
+ name: str = "SFMPrecondSR"
+ # Optimization
+ jit: bool = False
+ cuda_graphs: bool = False
+ amp_cpu: bool = False
+ amp_gpu: bool = True
+ torch_fx: bool = False
+ # Data type
+ bf16: bool = False
+ # Inference
+ onnx: bool = False
+ # Physics informed
+ func_torch: bool = False
+ auto_grad: bool = False
+
+
+class SFMPrecondSR(Module):
+ """
+ Super-resolution preconditioning based on the Stochastic Flow Matching (SFM) approach
+
+ Parameters
+ ----------
+ img_resolution : Union[List[int], int]
+ Image resolution.
+ img_in_channels : int
+ Number of input color channels.
+ img_out_channels : int
+ Number of output color channels.
+ use_fp16 : bool
+ Whether to execute the underlying model at FP16 precision, by default False.
+ N_grid_channels : int
+ Number of additional positional grid channels expected by the underlying
+ model, by default 0.
+ sigma_max : Union[dict, float]
+ Maximum supported noise level, by default inf. A dict with
+ "initial_values", "ema_weight" and "min_values" enables adaptive,
+ per-channel noise levels.
+ sigma_data : float
+ Expected standard deviation of the training data, by default 0.5.
+ model_type : str
+ Class name of the underlying model, by default "SongUNetPosEmbd".
+ use_x_low_conditioning : bool, optional
+ Whether to concatenate the low-resolution conditioning input to the
+ model input, by default None (disabled).
+ **model_kwargs : dict
+ Keyword arguments for the underlying model.
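+
+ Example
+ -------
+ A minimal sketch; the resolution, channel counts and noise levels are
+ illustrative assumptions only:
+
+ >>> import torch
+ >>> precond = SFMPrecondSR(
+ ... img_resolution=[32, 32],
+ ... img_in_channels=6,
+ ... img_out_channels=4,
+ ... N_grid_channels=4,
+ ... )
+ >>> x = torch.zeros(2, 4, 32, 32)
+ >>> sigma = torch.rand(2, 1, 1, 1)
+ >>> precond(x, sigma=sigma, condition=None).shape
+ torch.Size([2, 4, 32, 32])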
+ """
+
+ def __init__(
+ self,
+ img_resolution: Union[List[int], int],
+ img_in_channels: int,
+ img_out_channels: int,
+ use_fp16: bool = False,
+ N_grid_channels: int = 0,
+ sigma_max: Union[dict, float] = float("inf"),
+ sigma_data: float = 0.5,
+ model_type: str = "SongUNetPosEmbd",
+ use_x_low_conditioning=None,
+ **model_kwargs,
+ ) -> None:
+ Module.__init__(self, meta=SFMPrecondSRMetaData)
+ model_class = getattr(network_module, model_type)
+
+ self.use_fp16 = use_fp16
+ self.sigma_data = sigma_data
+ self.img_shape_y = img_resolution[0]
+ self.img_shape_x = img_resolution[1]
+ self.use_x_low_conditioning = use_x_low_conditioning
+
+ if isinstance(sigma_max, float):
+ self.sigma_max_current = torch.tensor(sigma_max)
+ # disable weighted updates, let that be handled externally
+ self.ema_weight = torch.tensor(0.0)
+ self.min_values = torch.tensor(0.0)
+ else:
+ sigma_max_current = torch.tensor(sigma_max["initial_values"])
+ self.register_buffer("sigma_max_current", sigma_max_current)
+ self.ema_weight = torch.tensor(sigma_max["ema_weight"])
+ self.min_values = torch.tensor(sigma_max["min_values"])
+
+ # drop encoder-specific kwargs that the SongUNetPosEmbd denoiser does not accept
+ if "encoder_type" in model_kwargs:
+ del model_kwargs["encoder_type"]
+ lr_channels = img_in_channels if self.use_x_low_conditioning else 0
+ self.denoiser_net = model_class(
+ img_resolution=img_resolution,
+ in_channels=img_out_channels + lr_channels + N_grid_channels,
+ out_channels=img_out_channels,
+ **model_kwargs,
+ )
+
+ def update_sigma_max(self, sigma_max: float):
+ """
+ Updates the maximum noise level
+
+ Sigma max is updated externally so any needed accumulation
+ and reductions can be handled by the training loop
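+
+ The update applied below is
+ sigma_max_current = max(ema_weight * sigma_max_current
+ + (1 - ema_weight) * sigma_max, min_values)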
+
+ Parameters
+ ----------
+ sigma_max : float
+ Maximum noise level
+ """
+ ema_weight = self.ema_weight.to(self.sigma_max_current.device)
+ sigma_max = torch.tensor(sigma_max).to(self.sigma_max_current.device)
+ # Update sigma_max_current without gradients
+ new_sigma_max_current = (
+ ema_weight * self.sigma_max_current + (1 - ema_weight) * sigma_max
+ )
+ self.sigma_max_current = torch.max(
+ new_sigma_max_current, self.min_values.to(self.sigma_max_current.device)
+ )
+
+ def get_sigma_max(self):
+ """returns the current max sigma"""
+ return self.sigma_max_current
+
+ @nvtx.annotate(message="SFMPrecondSR", color="orange")
+ def forward(
+ self,
+ x: torch.Tensor,
+ sigma: torch.Tensor,
+ condition: torch.Tensor,
+ force_fp32: bool = False,
+ **model_kwargs,
+ ):
+ """
+ Forward pass of the Stochastic Flow Matching (SFM) preconditioner
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ The partially noised input image
+
+ sigma : torch.Tensor
+ The noise level(s) conditioning the denoiser
+
+ condition : torch.Tensor
+ The low-resolution conditioning input
+
+ force_fp32 : bool
+ Whether float 32 computations should be forced, default False
+
+ model_kwargs: dict
+ Keyword arguments for the underlying model.
+
+ Returns
+ -------
+ torch.Tensor : the denoised image
+
+ """
+ x = x.to(torch.float32)
+ dtype = (
+ torch.float16
+ if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
+ else torch.float32
+ )
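+ # EDM-style preconditioning coefficients (Karras et al., 2022)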
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+ c_noise = sigma.log() / 4
+
+ if not self.use_x_low_conditioning:
+ scaled_x = c_in * x
+ else:
+ condition = condition.to(torch.float32)
+ scaled_x = torch.cat([c_in * x, condition], dim=1)
+
+ F_x = self.denoiser_net(
+ scaled_x.to(dtype), noise_labels=c_noise.flatten(), class_labels=None
+ )
+
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
+ return D_x
+
+ def round_sigma(self, sigma: Union[List[float], torch.Tensor]):
+ """
+ Convert the given sigma value(s) to a tensor representation in the
+ same precision as the model
+
+ Parameters
+ ----------
+ sigma : Union[float list, torch.Tensor]
+ The sigma value(s) to convert.
+
+ Returns
+ -------
+ torch.Tensor
+ The tensor representation of the provided sigma value(s).
+ """
+ return torch.as_tensor(sigma)
+
+
+class SFMPrecondEmpty(Module):
+ """
+ A preconditioner that does nothing
+
+ Parameters
+ ----------
+ **kwargs : dict
+ Keyword arguments, unused
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.tensor(0.0))
+ self.label_dim = None
diff --git a/physicsnemo/utils/generative/__init__.py b/physicsnemo/utils/generative/__init__.py
index a708ccb3d6..c5017a0219 100644
--- a/physicsnemo/utils/generative/__init__.py
+++ b/physicsnemo/utils/generative/__init__.py
@@ -15,6 +15,11 @@
# limitations under the License.
from .deterministic_sampler import deterministic_sampler
+from .sfm_samplers import (
+ SFM_encoder_sampler,
+ SFM_Euler_sampler,
+ SFM_Euler_sampler_Adaptive_Sigma,
+)
from .stochastic_sampler import image_batching, image_fuse, stochastic_sampler
from .utils import (
EasyDict,
diff --git a/physicsnemo/utils/generative/sfm_samplers.py b/physicsnemo/utils/generative/sfm_samplers.py
new file mode 100644
index 0000000000..abb87a721a
--- /dev/null
+++ b/physicsnemo/utils/generative/sfm_samplers.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from typing import Dict
+
+import nvtx
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+
+
+def sigma(t):
+ """Linear noise schedule used by the SFM samplers: sigma(t) = t."""
+ return t
+
+
+def sigma_inv(sigma):
+ """Inverse of the linear noise schedule: t = sigma."""
+ return sigma
+
+
+@nvtx.annotate(message="SFM_encoder_sampler", color="red")
+def SFM_encoder_sampler(
+ networks: Dict[str, torch.nn.Module],
+ img_lr: torch.Tensor,
+ randn_like: Callable = None,
+ cfg: DictConfig = None,
+):
+ """
+ Sampler that only runs the SFM encoder (no flow integration steps)
+
+ networks: Dict
+ A dictionary containing "encoder_net" and "denoiser_net" entries
+ for the denoiser and encoder networks.
+ Note: denoiser_net is not used for SFM_encoder_sampler
+ img_lr: torch.tensor
+ The low resolution image used for denoising
+ randn_like: StackedRandomGenerator
+ The random noise generator used for denoising.
+ Note: not used for SFM_encoder_sampler
+ cfg: DictConfig
+ The configuration used for sampling
+ Note: not used for SFM_encoder_sampler
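+
+ A minimal sketch; the identity encoder and tensor shape are illustrative
+ assumptions only:
+
+ >>> import torch
+ >>> nets = {"encoder_net": torch.nn.Identity(), "denoiser_net": None}
+ >>> x_lr = torch.rand(2, 6, 8, 8)
+ >>> out = SFM_encoder_sampler(nets, x_lr)
+ >>> torch.equal(out, x_lr)
+ True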
+ """
+ encoder_net = networks["encoder_net"]
+ x_low = img_lr
+ # in V1 the encoder net was inside the denoiser
+ if isinstance(encoder_net, SongUNetPosEmbd) or (
+ isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+ and isinstance(encoder_net.module, SongUNetPosEmbd)
+ ):
+ x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+ else:
+ x_0 = encoder_net(x_low) # MODULUS
+
+ return x_0
+
+
+@nvtx.annotate(message="SFM_Euler_sampler", color="red")
+def SFM_Euler_sampler(
+ networks: Dict[str, torch.nn.Module],
+ img_lr: torch.Tensor,
+ randn_like: Callable,
+ cfg: DictConfig,
+):
+ """
+ Euler sampler for SFM: the encoder produces the latent base sample from the
+ low-resolution input, and the denoiser is then integrated with Euler steps
+
+ networks: Dict
+ A dictionary containing "encoder_net" and "denoiser_net" entries
+ for the encoder and denoiser networks.
+ img_lr: torch.tensor
+ The low resolution image used for denoising
+ randn_like: StackedRandomGenerator
+ The random noise generator used for denoising.
+ cfg: DictConfig
+ The configuration used for sampling
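+
+ The configuration is expected to provide at least num_steps, rho,
+ sigma_min and t_min, e.g. (values are illustrative only):
+
+ >>> from omegaconf import DictConfig
+ >>> cfg = DictConfig(
+ ... {"num_steps": 5, "rho": 7, "sigma_min": 0.01, "t_min": 0.002}
+ ... )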
+ """
+ denoiser_net = networks["denoiser_net"]
+ encoder_net = networks["encoder_net"]
+
+ x_low = img_lr
+
+ # Define time steps in terms of noise level.
+ step_indices = torch.arange(cfg.num_steps, device=denoiser_net.device)
+ # TODO: sigma_max should be treated per channel; for now only the first channel is used
+ sigma_max = (
+ denoiser_net.get_sigma_max()[0]
+ if len(denoiser_net.get_sigma_max().shape) > 0
+ else denoiser_net.get_sigma_max()
+ )
+ sigma_steps = (
+ sigma_max ** (1 / cfg.rho)
+ + step_indices
+ / (cfg.num_steps - 1)
+ * (cfg.sigma_min ** (1 / cfg.rho) - sigma_max ** (1 / cfg.rho))
+ ) ** cfg.rho
+
+ # Only the linear noise schedule sigma(t) = t is currently supported;
+ # it is provided by the module-level sigma/sigma_inv helpers.
+
+ # Compute final time steps based on the corresponding noise levels.
+ t_steps = sigma_inv(denoiser_net.round_sigma(sigma_steps))
+ t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
+
+ # Main sampling loop.
+ t_next = t_steps[0]
+
+ if isinstance(encoder_net, SongUNetPosEmbd) or (
+ isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+ and isinstance(encoder_net.module, SongUNetPosEmbd)
+ ):
+ x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+ else:
+ x_0 = encoder_net(x_low) # MODULUS
+
+ # x_0 = x_0.to(torch.float64)
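+ # start the trajectory from the encoder prediction perturbed by sigma_max noise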
+ x_t = x_0 + sigma_max * randn_like(x_0)
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
+ x_t_cur = x_t
+ t_hat = t_cur
+ x_t_hat = x_t_cur
+
+ # Euler step.
+ x_t_denoised = denoiser_net(x_t_hat, sigma(t_hat), condition=x_low).to(
+ torch.float64
+ )
+
+ # estimated flow velocity towards the denoised sample
+ u_t = (x_t_denoised - x_t_hat) / (torch.clamp(t_hat, min=cfg.t_min))
+
+ # t decreases from t_hat to t_next, so the step size dt is positive
+ dt = t_hat - t_next
+ x_t = x_t_hat + u_t * dt
+
+ return x_t
+
+
+@nvtx.annotate(message="SFM_Euler_sampler_Adaptive_Sigma", color="red")
+def SFM_Euler_sampler_Adaptive_Sigma(
+ networks: Dict[str, torch.nn.Module],
+ img_lr: torch.Tensor,
+ randn_like: Callable,
+ cfg: DictConfig,
+):
+ """
+ Euler sampler for SFM with adaptive noise scaling: the encoder produces the
+ latent base sample and calibrated, per-channel noise from the denoiser's
+ current sigma_max is injected before the Euler integration
+
+ networks: Dict
+ A dictionary containing "encoder_net" and "denoiser_net" entries
+ for the encoder and denoiser networks.
+ img_lr: torch.tensor
+ The low resolution image used for denoising
+ randn_like: StackedRandomGenerator
+ The random noise generator used for denoising.
+ cfg: DictConfig
+ The configuration used for sampling
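+
+ The same configuration fields as SFM_Euler_sampler are expected
+ (num_steps, rho, sigma_min, t_min).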
+ """
+ denoiser_net = networks["denoiser_net"]
+ encoder_net = networks["encoder_net"]
+
+ x_low = img_lr
+
+ # Define time steps in terms of noise level.
+ step_indices = torch.arange(cfg.num_steps, device=denoiser_net.device)
+ sigma_max_adaptive = denoiser_net.get_sigma_max()
+
+ # set sigma_max to 1.0 for sampling purposes; this normalizes time to [0, 1]
+ sigma_max = 1.0
+ sigma_steps = (
+ sigma_max ** (1 / cfg.rho)
+ + step_indices
+ / (cfg.num_steps - 1)
+ * (cfg.sigma_min ** (1 / cfg.rho) - sigma_max ** (1 / cfg.rho))
+ ) ** cfg.rho
+
+ # Only the linear noise schedule sigma(t) = t is currently supported;
+ # it is provided by the module-level sigma/sigma_inv helpers.
+
+ # Compute final time steps based on the corresponding noise levels.
+ t_steps = sigma_inv(denoiser_net.round_sigma(sigma_steps))
+ t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
+
+ # Main sampling loop.
+ t_next = t_steps[0]
+
+ if isinstance(encoder_net, SongUNetPosEmbd) or (
+ isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+ and isinstance(encoder_net.module, SongUNetPosEmbd)
+ ):
+ x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+ else:
+ x_0 = encoder_net(x_low) # MODULUS
+
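+ # inject calibrated, per-channel noise around the encoder prediction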
+ x_t = x_0 + sigma_max_adaptive.view(1, -1, 1, 1) * randn_like(x_0)
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
+ x_t_cur = x_t
+
+ t_hat = t_cur
+ x_t_hat = x_t_cur
+
+ # Euler step.
+ x_t_denoised = denoiser_net(x_t_hat, sigma(t_hat), condition=x_low).to(
+ torch.float64
+ )
+
+ # estimated flow velocity towards the denoised sample
+ u_t = (x_t_denoised - x_t_hat) / (torch.clamp(t_hat, min=cfg.t_min))
+
+ # t decreases from t_hat to t_next, so the step size dt is positive
+ dt = t_hat - t_next
+ x_t = x_t_hat + u_t * dt
+
+ return x_t
diff --git a/test/metrics/diffusion/test_sfm_losses.py b/test/metrics/diffusion/test_sfm_losses.py
new file mode 100644
index 0000000000..37eaa0a830
--- /dev/null
+++ b/test/metrics/diffusion/test_sfm_losses.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+import pytest
+import torch
+
+from physicsnemo.metrics.diffusion import (
+ SFMEncoderLoss,
+ SFMLoss,
+)
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+
+
+class fake_net(torch.nn.Module):
+ """dummy class to test sfm encoder"""
+
+ def get_sigma_max(self):
+ return torch.tensor([1.0])
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+
+def get_songunet():
+ """helper that creates a songunet for testing"""
+ songunet_kwargs = {
+ "img_resolution": 8,
+ "in_channels": 2,
+ "out_channels": 4,
+ "embedding_type": "zero",
+ "label_dim": 0,
+ "encoder_type": "standard",
+ "decoder_type": "standard",
+ "channel_mult_noise": 1,
+ "resample_filter": [1, 1],
+ "channel_mult": [1, 2, 2],
+ "attn_resolutions": [28],
+ "N_grid_channels": 0,
+ "model_channels": 4,
+ }
+ return SongUNetPosEmbd(**songunet_kwargs)
+
+
+def test_sfmloss_initialization():
+ """checks SFMLoss __init__"""
+
+ loss_fn = SFMLoss()
+
+ assert loss_fn.encoder_loss_type == "l2"
+ assert loss_fn.encoder_loss_weight == 0.1
+ assert loss_fn.sigma_min == 0.002
+ assert loss_fn.sigma_data == 0.5
+
+ loss_fn = SFMLoss(
+ encoder_loss_type="l1",
+ encoder_loss_weight=[0.1, 0.2],
+ sigma_min=5e-4,
+ sigma_data=0.1,
+ )
+
+ assert loss_fn.encoder_loss_type == "l1"
+ assert loss_fn.encoder_loss_weight == [0.1, 0.2]
+ assert loss_fn.sigma_min == 5e-4
+ assert loss_fn.sigma_data == 0.1
+
+ # test for invalid loss type
+ with pytest.raises(
+ ValueError,
+ match=re.escape(
+ "encoder_loss_type should be one of ['l1', 'l2', None] not bogus"
+ ),
+ ):
+ loss_fn = SFMLoss(
+ encoder_loss_type="bogus",
+ )
+
+
+def test_sfmloss_call():
+ """checks SFMLoss __call__"""
+
+ # dummy network for loss
+ dummy_denoiser = fake_net()
+ dummy_encoder = get_songunet()
+ dummy_net = torch.nn.Identity()
+
+ image_zeros = torch.zeros((2, 4, 8, 8))
+ image_rnd = torch.rand((2, 2, 8, 8))
+
+ # test defaults, encoder l2 loss, sigma_min is float
+ loss_fn = SFMLoss()
+
+ loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros)
+ assert isinstance(loss_value, torch.Tensor)
+
+ loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros)
+ assert isinstance(loss_value, torch.Tensor)
+
+ # test encoder l1 loss, sigma_min is list
+ loss_fn = SFMLoss(encoder_loss_type="l1", sigma_min=[0.001, 0.001, 0.001, 0.001])
+ loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros)
+ assert isinstance(loss_value, torch.Tensor)
+
+ # test no encoder loss, sigma_min is list
+ loss_fn = SFMLoss(encoder_loss_type=None)
+ loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros)
+ assert isinstance(loss_value, torch.Tensor)
+
+ # testing using songunet and no encoder loss
+ loss_value = loss_fn(dummy_denoiser, dummy_encoder, image_zeros, image_rnd)
+ assert isinstance(loss_value, torch.Tensor)
+
+
+def test_sfmencoderloss_initialization():
+ """checks SFMEncoderLoss __init__"""
+ loss_fn = SFMEncoderLoss()
+
+ assert loss_fn.encoder_loss_type == "l2"
+
+ # test for invalid loss type
+ with pytest.raises(
+ ValueError,
+ match="encoder_loss_type should be either l1 or l2 not bogus",
+ ):
+ loss_fn = SFMEncoderLoss(
+ encoder_loss_type="bogus",
+ )
+
+
+def test_sfmencoderloss_call():
+ """checks SFMEncoderLoss __call__"""
+ # dummy network for loss
+ dummy_net = torch.nn.Identity()
+ dummy_encoder = get_songunet()
+
+ image_zeros = torch.zeros((2, 4, 8, 8))
+ image_twos = torch.ones((2, 4, 8, 8)) * 2
+ image_rnd = torch.rand((2, 2, 8, 8))
+
+ # test l2 loss
+ loss_fn = SFMEncoderLoss()
+
+ # encoder loss is deterministic
+ loss_value = loss_fn(dummy_net, dummy_net, image_twos, image_twos)
+ assert torch.equal(loss_value, image_zeros)
+
+ loss_value = loss_fn(dummy_net, dummy_net, image_zeros, image_twos)
+ assert torch.equal(loss_value, image_twos * image_twos)
+
+ # test l1 loss
+ loss_fn = SFMEncoderLoss("l1")
+
+ # encoder loss is deterministic
+ loss_value = loss_fn(dummy_net, dummy_net, image_twos, image_twos)
+ assert torch.equal(loss_value, image_zeros)
+
+ loss_value = loss_fn(dummy_net, dummy_net, image_zeros, image_twos)
+ assert torch.equal(loss_value, image_twos)
+
+ # using songunet
+ loss_value = loss_fn(dummy_net, dummy_encoder, image_twos, image_rnd)
+ assert isinstance(loss_value, torch.Tensor)
diff --git a/test/models/diffusion/test_encoders.py b/test/models/diffusion/test_encoders.py
new file mode 100644
index 0000000000..52a5173df3
--- /dev/null
+++ b/test/models/diffusion/test_encoders.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from physicsnemo.models.diffusion import Conv2dSerializable
+
+
+def test_conv2dserializable_initialization():
+ conv_net = Conv2dSerializable(in_channels=8, out_channels=4, kernel_size=3)
+
+ assert conv_net.in_channels == 8
+ assert conv_net.out_channels == 4
+ assert conv_net.kernel_size == 3
+
+
+def test_conv2dserializable_forward():
+ test_data = torch.zeros(1, 8, 16, 16)
+ conv_net = Conv2dSerializable(in_channels=8, out_channels=4, kernel_size=3)
+
+ out_tensor = conv_net(test_data)
+ assert out_tensor.shape == (1, 4, 16, 16)
diff --git a/test/models/diffusion/test_sfm_preconditioning.py b/test/models/diffusion/test_sfm_preconditioning.py
new file mode 100644
index 0000000000..d8a7279e78
--- /dev/null
+++ b/test/models/diffusion/test_sfm_preconditioning.py
@@ -0,0 +1,144 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from physicsnemo.models.diffusion import SFMPrecondEmpty, SFMPrecondSR
+
+
+# SFMPrecondEmpty test
+def test_sfmprecondempty_initialization():
+ """checks SFMPrecondEmpty __init__"""
+ precond = SFMPrecondEmpty()
+
+ assert isinstance(precond, SFMPrecondEmpty)
+ assert precond.label_dim is None
+ assert precond.param == torch.tensor(0.0)
+
+
+# SFMPrecondSR tests
+def test_sfmprecondsr_initialization():
+ """checks SFMPrecondSR __init__"""
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ sigma_data=0.3,
+ )
+
+ assert isinstance(precond, SFMPrecondSR)
+ assert precond.img_shape_y == 32
+ assert precond.img_shape_x == 32
+ assert precond.sigma_data == 0.3
+
+ # test with sigma_max as a dict
+ sigma_max = {
+ "initial_values": [0.5, 0.5, 0.5],
+ "ema_weight": 0.2,
+ "min_values": 0.1,
+ }
+
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ sigma_max=sigma_max,
+ sigma_data=0.3,
+ encoder_type="l1",
+ )
+
+ assert isinstance(precond, SFMPrecondSR)
+ assert precond.img_shape_y == 32
+ assert precond.img_shape_x == 32
+ assert torch.equal(precond.sigma_max_current, torch.Tensor([0.5, 0.5, 0.5]))
+ assert torch.equal(precond.ema_weight, torch.tensor(0.2))
+ assert torch.equal(precond.min_values, torch.tensor(0.1))
+ assert precond.sigma_data == 0.3
+
+
+# SFMPrecondSR tests
+def test_sfmprecondsr_sigma():
+ """checks update_sigma_max and get_sigma_max"""
+
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ sigma_data=0.3,
+ sigma_max=8.0,
+ )
+
+ assert precond.get_sigma_max() == 8.0
+
+ precond.update_sigma_max(16.0)
+ assert precond.sigma_max_current == 16.0
+
+ # test with ema as a dict
+ ema_weight = 0.5
+ init_vals = torch.tensor([16.0, 1.0, 0.5])
+ updated_vals = torch.tensor([8.0, 2.0, 2.0])
+ expected_vals = init_vals * 0.5 + updated_vals * 0.5
+
+ sigma_max = {
+ "initial_values": init_vals.tolist(),
+ "ema_weight": ema_weight,
+ "min_values": 0.1,
+ }
+
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ sigma_max=sigma_max,
+ sigma_data=0.3,
+ encoder_type="l1",
+ )
+
+ assert torch.equal(precond.sigma_max_current, init_vals)
+
+ precond.update_sigma_max(updated_vals.tolist())
+ assert torch.equal(precond.sigma_max_current, expected_vals)
+
+ assert torch.equal(precond.round_sigma(init_vals.tolist()), init_vals)
+
+
+def test_sfmprecondsr_forward():
+ """checks forward"""
+
+ image_in = torch.zeros((2, 4, 32, 32))
+ image_cond = torch.zeros((2, 6, 32, 32))
+ sigma = torch.rand((2, 1, 1, 1))
+
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=6,
+ img_out_channels=4,
+ N_grid_channels=4,
+ )
+
+ out_val = precond(image_in, sigma=sigma, condition=None)
+ assert isinstance(out_val, torch.Tensor)
+
+ precond = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=6,
+ img_out_channels=4,
+ N_grid_channels=4,
+ use_x_low_conditioning=True,
+ )
+
+ out_val = precond(image_in, sigma=sigma, condition=image_cond)
+ assert isinstance(out_val, torch.Tensor)
diff --git a/test/utils/generative/test_sfm_sampler.py b/test/utils/generative/test_sfm_sampler.py
new file mode 100644
index 0000000000..abee9ffdcf
--- /dev/null
+++ b/test/utils/generative/test_sfm_sampler.py
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from omegaconf import DictConfig
+
+from physicsnemo.models.diffusion import SFMPrecondSR, SongUNetPosEmbd
+from physicsnemo.utils.generative import (
+ SFM_encoder_sampler,
+ SFM_Euler_sampler,
+ SFM_Euler_sampler_Adaptive_Sigma,
+ StackedRandomGenerator,
+)
+
+
+def get_songunet():
+ """helper that creates a songunet for testing"""
+ songunet_kwargs = {
+ "img_resolution": 32,
+ "in_channels": 6,
+ "out_channels": 6,
+ "embedding_type": "zero",
+ "label_dim": 0,
+ "encoder_type": "standard",
+ "decoder_type": "standard",
+ "channel_mult_noise": 1,
+ "resample_filter": [1, 1],
+ "channel_mult": [1, 2, 2],
+ "attn_resolutions": [28],
+ "N_grid_channels": 0,
+ "model_channels": 4,
+ }
+ return SongUNetPosEmbd(**songunet_kwargs)
+
+
+class fake_net(torch.nn.Module):
+ """dummy class to test sfm encoder"""
+
+ def __init__(self, sigma_max=[0.5]):
+ self.sigma_max = torch.tensor(sigma_max)
+
+ def get_sigma_max(self):
+ return self.sigma_max
+
+ def round_sigma(self, x):
+ return torch.tensor(x)
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+
+def test_sfm_encoder_sampler():
+ """SFM_encoder_sampler"""
+ dummy_net = torch.nn.Identity()
+ encoder = get_songunet()
+
+ image_rnd = torch.rand((2, 6, 8, 8))
+
+ networks = {
+ "encoder_net": dummy_net,
+ "denoiser_net": None,
+ }
+
+ out_val = SFM_encoder_sampler(networks, image_rnd)
+ assert torch.equal(out_val, image_rnd)
+
+ networks = {
+ "encoder_net": encoder,
+ "denoiser_net": None,
+ }
+ out_val = SFM_encoder_sampler(networks, image_rnd)
+ assert isinstance(out_val, torch.Tensor)
+
+
+def test_sfm_euler_sampler():
+ """SFM_Euler_sampler"""
+ encoder = get_songunet()
+ denoiser = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ N_grid_channels=4,
+ ).to("cpu")
+ dummy_net = torch.nn.Identity()
+
+ batch_seeds = torch.as_tensor([1, 2]).to("cpu")
+ image_rnd = torch.rand((2, 6, 32, 32)).to("cpu")
+ rnd = StackedRandomGenerator(image_rnd.device, batch_seeds)
+
+ cfg = {"rho": 7, "num_steps": 5, "sigma_min": 0.01, "t_min": 0.002}
+ cfg = DictConfig(cfg)
+
+ networks = {
+ "encoder_net": dummy_net,
+ "denoiser_net": denoiser,
+ }
+
+ out_val = SFM_Euler_sampler(
+ networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+ )
+ assert isinstance(out_val, torch.Tensor)
+
+ networks = {
+ "encoder_net": encoder,
+ "denoiser_net": denoiser,
+ }
+
+ out_val = SFM_Euler_sampler(
+ networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+ )
+ assert isinstance(out_val, torch.Tensor)
+
+
+def test_sfm_euler_sampler_adaptive_sigma():
+ """SFM_Euler_sampler"""
+ encoder = get_songunet()
+ denoiser = SFMPrecondSR(
+ img_resolution=[32, 32],
+ img_in_channels=4,
+ img_out_channels=6,
+ N_grid_channels=4,
+ ).to("cpu")
+ dummy_net = torch.nn.Identity()
+
+ batch_seeds = torch.as_tensor([1, 2]).to("cpu")
+ image_rnd = torch.rand((2, 6, 32, 32)).to("cpu")
+ rnd = StackedRandomGenerator(image_rnd.device, batch_seeds)
+
+ cfg = {"rho": 7, "num_steps": 5, "sigma_min": 0.01, "t_min": 0.002}
+ cfg = DictConfig(cfg)
+
+ networks = {
+ "encoder_net": dummy_net,
+ "denoiser_net": denoiser,
+ }
+
+ out_val = SFM_Euler_sampler_Adaptive_Sigma(
+ networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+ )
+ assert isinstance(out_val, torch.Tensor)
+
+ networks = {
+ "encoder_net": encoder,
+ "denoiser_net": denoiser,
+ }
+
+ out_val = SFM_Euler_sampler_Adaptive_Sigma(
+ networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+ )
+ assert isinstance(out_val, torch.Tensor)