diff --git a/docs/img/sfm_model.png b/docs/img/sfm_model.png
new file mode 100755
index 0000000000..81b22634e5
Binary files /dev/null and b/docs/img/sfm_model.png differ
diff --git a/examples/generative/corrdiff_plus_plus/README.md b/examples/generative/corrdiff_plus_plus/README.md
new file mode 100644
index 0000000000..b6b18db18a
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/README.md
@@ -0,0 +1,149 @@

## CorrDiff++: Adaptive Flow Matching for Resolving Small-Scale Physics

## Problem overview

Conditional diffusion and flow-based models have shown promise for super-resolving fine-scale structure in natural images. However, applying them to weather and physical domains introduces unique challenges:
(i) spatial misalignment between input and target fields due to differing PDE resolutions,
(ii) mismatched and heterogeneous input-output channels (i.e., channel synthesis), and
(iii) channel-dependent stochasticity.

**CorrDiff** was proposed to address these issues but suffers from poor generalization—its regression-based residuals can overfit training data, leading to degraded performance on out-of-distribution inputs.

To overcome these limitations, **CorrDiff++** was introduced at NVIDIA. It relies on adaptive flow matching (AFM), which improves upon CorrDiff in generalization, calibration, and efficiency through several key innovations:

- A joint encoder–generator architecture trained via flow matching
- A latent base distribution that reconstructs the large-scale, deterministic signal
- An adaptive noise scaling mechanism informed by the encoder’s RMSE, used to inject calibrated uncertainty
- A final flow matching step to refine latent samples and synthesize fine-scale physical details

AFM outperforms previous methods across both real-world (e.g., 25 → 2 km super-resolution in Taiwan) and synthetic (Kolmogorov flow) benchmarks—especially for highly stochastic output channels.

📄 For details, see the [SFM paper (arXiv:2410.19814)](https://arxiv.org/abs/2410.19814).
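To make the adaptive noise scaling idea more concrete, here is a minimal, illustrative sketch (not the actual CorrDiff++ implementation): it assumes the per-channel noise scale `sigma_max` is tracked as an exponential moving average of the encoder's per-channel RMSE and clamped from below, loosely mirroring the `sigma_max.initial_values`, `min_values`, and `ema_weight` options in `conf/model/sfm.yaml`. The function name and array shapes are hypothetical.

```python
import numpy as np

def update_sigma_max(sigma_max, encoder_pred, target, ema_weight=0.99, min_value=0.05):
    """Illustrative sketch only: track a per-channel noise scale as an EMA of
    the encoder's per-channel RMSE, clamped from below (assumed behavior)."""
    # Per-channel RMSE of the encoder reconstruction; arrays are (batch, channel, h, w).
    rmse = np.sqrt(((encoder_pred - target) ** 2).mean(axis=(0, 2, 3)))
    # Exponential moving average toward the current RMSE.
    sigma_max = ema_weight * sigma_max + (1.0 - ema_weight) * rmse
    # Keep the injected noise scale above the configured floor.
    return np.maximum(sigma_max, min_value)

# Hypothetical usage with 4 output channels:
sigma_max = np.ones(4)                                 # initial_values
pred = np.random.randn(8, 4, 32, 32)
target = pred + 0.1 * np.random.randn(8, 4, 32, 32)
sigma_max = update_sigma_max(sigma_max, pred, target)  # calibrated per-channel noise scales
```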

+ +

## Getting started

To build custom CorrDiff++ versions, you can get started by training the "Mini" version of CorrDiff++, which uses smaller training samples and a smaller network to reduce training costs from thousands of GPU hours to around 10 hours on A100 GPUs while still producing reasonable results. It also includes a simple data loader that can be used as a baseline for training CorrDiff++ on custom datasets.

### Preliminaries
Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/corrdiff_plus_plus`) to a system with a GPU available. Also download the CorrDiff++ dataset from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini).

### Configuration basics

Similar to its earlier version, CorrDiff++ training is handled by `train.py` and controlled by YAML configuration files managed by [Hydra](https://hydra.cc/docs/intro/). Prebuilt configuration files are found in the `conf` directory. You can choose the configuration file using the `--config-name` option. The main configuration file specifies the training dataset, the model configuration, and the training options. The details of these are given in the corresponding configuration files. To change a configuration option, you can either edit the configuration files or use Hydra command-line overrides. For example, the training batch size is controlled by the option `training.hp.total_batch_size`. We can override this from the command line with the `++` syntax: `python train.py ++training.hp.total_batch_size=64` would run the training with the batch size set to 64.

### Joint training of encoder and generator via flow matching

To train CorrDiff++, we use the main configuration file [config_training.yaml](conf/config_training.yaml). This includes the following components: the training dataset (`dataset/cwb_train`), the model (`model/sfm`), the training options (`training/sfm`), and, optionally, validation (`validation/cwb`).
To start the training, run:
```bash
python train.py --config-name=config_training.yaml ++dataset.data_path=<path_to_dataset>/hrrr_mini_train.nc ++dataset.stats_path=<path_to_dataset>/stats.json
```

The training will require a few hours on a single A100 GPU. If training is interrupted, it will automatically continue from the latest checkpoint when restarted. Multi-GPU and multi-node training are supported and will launch automatically when the training is run in a `torchrun` or MPI environment.

The results, including logs and checkpoints, are saved by default to `outputs/sfm/`. You can direct the checkpoints to be saved elsewhere by setting `++training.io.checkpoint_dir=<checkpoint_directory>`.

> **_Out of memory?_** CorrDiff++ trains by default with a batch size of 256 (set by `training.hp.total_batch_size`). If you're using a single GPU, especially one with a smaller amount of memory, you might see an out-of-memory error. If that happens, set a smaller batch size per GPU, e.g.: `++training.hp.batch_size_per_gpu=16`. CorrDiff++ training will then automatically use gradient accumulation to train with an effective batch size of `training.hp.total_batch_size`.


### Generation

Use the `generate.py` script to generate samples with the trained networks:
```bash
python generate.py --config-name="config_generate.yaml" ++generation.io.encoder_ckpt_filename=<encoder_checkpoint> ++generation.io.denoiser_ckpt_filename=<denoiser_checkpoint> ++generation.io.output_filename=<output_file>
```
where `<encoder_checkpoint>` and `<denoiser_checkpoint>` should point to the encoder and diffusion model checkpoints, respectively, and `<output_file>` indicates the output NetCDF4 file.
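As a quick check of the results, the generated file can be inspected with the Python `netCDF4` package. The snippet below is a minimal sketch: the filename is a placeholder for whatever was passed to `generation.io.output_filename`, and the group layout is described in the paragraph that follows.

```python
import netCDF4 as nc

# Open the file written by generate.py; "corrdiff_pp_output.nc" is a placeholder
# for the path passed to generation.io.output_filename above.
with nc.Dataset("corrdiff_pp_output.nc", "r") as ds:
    print(list(ds.groups))          # expected groups: input, truth, prediction
    pred = ds["prediction"]
    for name, var in pred.variables.items():
        print(name, var.shape)      # one variable per generated output channel
```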
As shown above, you can open the output file with, e.g., the Python NetCDF4 library. The inputs are saved in the `input` group of the file, the ground truth data in the `truth` group, and the CorrDiff++ prediction in the `prediction` group.

## Configs

The `conf` directory contains the configuration files for the model, data,
training, etc. The configs are given in YAML format and use the `omegaconf`
library to manage them. Several example configs are provided for training
the different model variants: encoder only, flow matching with a pretrained
encoder, and joint training of the encoder with the diffusion model.
The default configs are set to train the encoder jointly with the diffusion model.
To train the other models, please adjust `conf/config_training.yaml`
according to the comments. Alternatively, you can create a new config file
and specify it using the `--config-name` option.


## Dataset & Datapipe

In this example, CorrDiff++ training is demonstrated on the Taiwan dataset,
conditioned on the [ERA5 dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5).
We have made this dataset available for non-commercial use under the
[CC BY-NC-ND 4.0 license](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en),
and it can be downloaded from [https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa)
with `ngc registry resource download-version "nvidia/modulus/modulus_datasets_cwa:v1"`.
The datapipe in this example is tailored specifically for the Taiwan dataset.
A lightweight datapipe for the HRRR dataset is also available and can be used
with the CorrDiff++ model.
For other datasets, you will need to create a custom datapipe.
You can use the lightweight HRRR datapipe as a starting point for developing your new one.


### Sampling and Model Evaluation

Model evaluation is split into two components. `generate.py` creates a netCDF file
for the generated outputs, and `score_samples.py` computes deterministic and probabilistic
scores.

To generate samples and save output in a netCDF file, run:

```bash
python generate.py
```
This will use the base configs specified in the `conf/config_generate.yaml` file.

Next, to score the generated samples, run:

```bash
python score_samples.py path=<path_to_generated_nc_file> output=<path_to_score_file>
```

Some legacy plotting scripts are also available in the `inference` directory.
You can also bring your checkpoints to [Earth2Studio](https://github.com/NVIDIA/earth2studio)
for further analysis and visualization.

## Logging

We use TensorBoard for logging training and validation losses, as well as
the learning rate during training. To visualize TensorBoard running in a
Docker container on a remote server from your local desktop, follow these steps:

1. **Expose the Port in Docker:**
   Expose port 6006 in the Docker container by including
   `-p 6006:6006` in your docker run command.

2. **Launch TensorBoard:**
   Start TensorBoard within the Docker container:
   ```bash
   tensorboard --logdir=/path/to/logdir --port=6006
   ```

3. **Set Up SSH Tunneling:**
   Create an SSH tunnel to forward port 6006 from the remote server to your local machine:
   ```bash
   ssh -L 6006:localhost:6006 <user>@<remote-server-ip>
   ```
   Replace `<user>` with your SSH username and `<remote-server-ip>` with the IP address
   of your remote server. You can use a different port if necessary.

4. **Access TensorBoard:**
   Open your web browser and navigate to `http://localhost:6006` to view TensorBoard.
+ +**Note:** Ensure the remote server’s firewall allows connections on port `6006` +and that your local machine’s firewall allows outgoing connections. + + +## References + +- [Adaptive Flow Matching for Resolving Small-Scale Physics](https://openreview.net/forum?id=YJ1My9ttEN) diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml new file mode 100644 index 0000000000..30a0f38fe1 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_generate.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: generation + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_generate + + # Generation + - generation/sfm diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml new file mode 100644 index 0000000000..05e69f12c4 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_encoder.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: generation + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_generate + + # Generation + - generation/sfm_encoder diff --git a/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml new file mode 100644 index 0000000000..5b1661ff56 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_generate_sfm_two_stage.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: generation + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_generate + + # Generation + - generation/sfm_two_stage diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training.yaml new file mode 100644 index 0000000000..3fa07a5e08 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_training.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"] + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_train + + # Model + #- model/sfm_encoder + - model/sfm + #- model/sfm_two_stage + + # Training + #- training/sfm_encoder + - training/sfm + #- training/sfm_two_stage + + # Validation (comment out to disable validation) + - validation/cwb diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml new file mode 100644 index 0000000000..73dd814b39 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_encoder.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +hydra: + job: + chdir: true + name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"] + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_train + + # Model + - model/sfm_encoder + #- model/sfm + #- model/sfm_two_stage + + # Training + - training/sfm_encoder + #- training/sfm + #- training/sfm_two_stage + + # Validation (comment out to disable validation) + - validation/cwb diff --git a/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml new file mode 100644 index 0000000000..bf8d07273e --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/config_training_sfm_two_stage.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: sfm # choose from ["sfm_encoder", "sfm", "sfm_two_stage"] + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/cwb_train + + # Model + #- model/sfm_encoder + #- model/sfm + - model/sfm_two_stage + + # Training + #- training/sfm_encoder + #- training/sfm + - training/sfm_two_stage + + # Validation (comment out to disable validation) + - validation/cwb diff --git a/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml new file mode 100644 index 0000000000..787e1c9847 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_generate.yaml @@ -0,0 +1,31 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +type: cwb +data_path: /code/2023-01-24-cwb-4years.zarr +in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19] +out_channels: [0, 17, 18, 19] +img_shape_x: 448 +img_shape_y: 448 +add_grid: true +ds_factor: 4 +min_path: null +max_path: null +global_means_path: null +global_stds_path: null +train: False +all_times: True \ No newline at end of file diff --git a/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml new file mode 100644 index 0000000000..5bcd741031 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/dataset/cwb_train.yaml @@ -0,0 +1,29 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +type: cwb +data_path: /code/2023-01-24-cwb-4years.zarr +in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19] +out_channels: [0, 17, 18, 19] +img_shape_x: 448 +img_shape_y: 448 +add_grid: true +ds_factor: 4 +min_path: null +max_path: null +global_means_path: null +global_stds_path: null diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml new file mode 100644 index 0000000000..f5d6f7e28d --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm.yaml @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +num_ensembles: 64 + # Number of ensembles to generate per input +seed_batch_size: 4 + # Size of the batched inference +inference_mode: sfm + # Choose between "sfm", "sfm_encoder" or "sfm_two_stage" +learnable_sigma: True + # whether the model uses a learnable sigma +gridtype: "sinusoidal" + +sampler: + rho: 7 + # shape of noise schedule curve + num_steps: 50 + # Number of sampling steps + sigma_min: 0.01 + # Lowest noise level + t_min: 0.002 + # clamp time step to this value + +N_grid_channels: 4 +times_range: null +times: + - 2021-02-02T00:00:00 + - 2021-03-02T00:00:00 + - 2021-04-02T00:00:00 + # hurricane + - 2021-09-12T00:00:00 + - 2021-09-12T12:00:00 + +perf: + force_fp16: false + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. 
+  use_torch_compile: false
+  # whether to use torch.compile on the diffusion model
+  # this will make the first time stamp generation very slow due to compilation overheads
+  # but will significantly speed up subsequent inference runs
+  num_writer_workers: 1
+  # number of workers to use for writing file
+  # To support multiple workers a threadsafe version of the netCDF library must be used
+
+io:
+  encoder_ckpt_filename: Conv2dSerializable.0.895232.mdlus
+  # Checkpoint filename for the encoder model
+  denoiser_ckpt_filename: EDMPrecondSR.0.895232.mdlus
+  # Checkpoint filename for the diffusion (denoiser) model
diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml
new file mode 100644
index 0000000000..07a0663ef8
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_encoder.yaml
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+num_ensembles: 64
+  # Number of ensembles to generate per input
+seed_batch_size: 4
+  # Size of the batched inference
+inference_mode: sfm_encoder
+  # Choose between "sfm", "sfm_encoder" or "sfm_two_stage"
+gridtype: "sinusoidal"
+
+sampler: null
+  # no config options for the encoder sampler
+
+N_grid_channels: 4
+times_range: null
+times:
+  - 2021-02-02T00:00:00
+  - 2021-03-02T00:00:00
+  - 2021-04-02T00:00:00
+  # hurricane
+  - 2021-09-12T00:00:00
+  - 2021-09-12T12:00:00
+
+perf:
+  force_fp16: false
+  # Whether to force fp16 precision for the model. If false, it'll use the precision
+  # specified upon training.
+  use_torch_compile: false
+  # whether to use torch.compile on the diffusion model
+  # this will make the first time stamp generation very slow due to compilation overheads
+  # but will significantly speed up subsequent inference runs
+  num_writer_workers: 1
+  # number of workers to use for writing file
+  # To support multiple workers a threadsafe version of the netCDF library must be used
+
+io:
+  encoder_ckpt_filename: outputs/sfm/checkpoints_sfm_encoder/Conv2dSerializable.0.5120.mdlus
+  # Checkpoint filename for the encoder model
diff --git a/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml
new file mode 100644
index 0000000000..14d586a9b4
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/generation/sfm_two_stage.yaml
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - sfm + +io: + encoder_ckpt_filename: outputs/sfm/checkpoints_sfm_two_stage/Conv2dSerializable.0.5120.mdlus + # Checkpoint filename for the encoder model + denoiser_ckpt_filename: outputs/sfm/checkpoints_sfm_two_stage/SFMPrecondSR.0.5120.mdlus + # Checkpoint filename for the diffusion model diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml new file mode 100644 index 0000000000..60a0661b44 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm.yaml @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: sfm +# Name of the model + +encoder_type: 1x1conv +# Different encoder types for SFM: 1x1conv, songunet_s, songunet_xs, songunet_2xs, songunet_3xs +encoder_loss_type: l2 +# Regularizer for the SFM model [null, l2, l1] +encoder_loss_weight: 0.25 +# Regularizer weight for the SFM model [null, float] + +sigma_min: [0.002, 0.002, 0.002, 0.002] # for sampling +# Minimum value of the noise sigma + +model_args: + dropout: 0.13 + # Dropout probability + # Maximum value of the noise sigma + sigma_max: + initial_values: [1.0, 1.0, 1.0, 1.0] + learnable: True + # For learnable sigma_max + min_values: [0.05, 0.05, 0.05, 0.05] + # Minimum value of sigma_max distribution (only if learnable) + ema_weight: 0.99 + # EMA weight for sigma_max (only if learnable) + use_x_low_conditioning: False + # Conditioning for the diffusion model on the upsampled ERA data [True, False] diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml new file mode 100644 index 0000000000..205538a592 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm_encoder.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +name: sfm_encoder +# Name of the model + +encoder_type: 1x1conv +# Different encoder types for SFM: 1x1conv, songunet_s, songunet_xs, songunet_2xs, songunet_3xs +encoder_loss_type: l2 +# Regularizer for the SFM model [null, l2, l1] diff --git a/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml new file mode 100644 index 0000000000..09801532e3 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/model/sfm_two_stage.yaml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: sfm_two_stage +# Name of the model + +encoder_loss_type: l2 +# Regularizer for the SFM model [null, l2, l1] +encoder_loss_weight: 0.25 +# Regularizer weight for the SFM model [null, float] + +sigma_min: [0.002, 0.002, 0.002, 0.002] # for sampling +# Minimum value of the noise sigma + +model_args: + dropout: 0.13 + # Dropout probability + # Maximum value of the noise sigma + sigma_max: + initial_values: [1.0, 1.0, 1.0, 1.0] + learnable: False + # For learnable sigma_max + min_values: [0.05, 0.05, 0.05, 0.05] + # Minimum value of sigma_max distribution (only if learnable) + ema_weight: 0.99 + # EMA weight for sigma_max (only if learnable) + use_x_low_conditioning: False + # Conditioning for the diffusion model on the upsampled ERA data [True, False] diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml new file mode 100644 index 0000000000..00a1a767a6 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm.yaml @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Hyperparameters
+hp:
+  training_duration: 8000000
+  # Training duration based on the number of processed samples
+  total_batch_size: 256
+  # Total batch size
+  batch_size_per_gpu: 4
+  # Batch size per GPU
+  lr: 0.0002
+  # Learning rate
+  grad_clip_threshold: null
+  # no gradient clipping for default non-patch-based training
+  lr_decay: 1
+  # LR decay rate
+  lr_rampup: 0
+  # Rampup for learning rate, in number of samples
+  ema: 0.5
+  # EMA half-life
+  ema_rampup_ratio: 0.05
+  # EMA ramp-up coefficient, None = no rampup.
+
+# Performance
+perf:
+  fp_optimizations: amp-bf16
+  # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+  # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+  dataloader_workers: 1
+  # DataLoader worker processes
+  songunet_checkpoint_level: 0 # 0 means no checkpointing
+  # Gradient checkpointing level, value is number of layers to checkpoint
+
+# I/O
+io:
+  # encoder_checkpoint_path: checkpoints/sfm_encoder.mdlus
+  # Where to load the encoder checkpoint
+  print_progress_freq: 10000
+  # How often to print progress
+  save_checkpoint_freq: 5000
+  # How often to save the checkpoints, measured in number of processed samples
+  validation_freq: 5000
+  # how often to record the validation loss, measured in number of processed samples
+  validation_steps: 10
+  # how many loss evaluations are used to compute the validation loss per checkpoint
diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml
new file mode 100644
index 0000000000..7a88566287
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm_encoder.yaml
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+defaults:
+  - sfm
+
+# Hyperparameters
+hp:
+  training_duration: 2000000
+  # Training duration based on the number of processed samples
+  total_batch_size: 256
+  # Total batch size
+  batch_size_per_gpu: 32
+  # Batch size per GPU
+  lr: 0.0001
+  # Learning rate
+  grad_clip_threshold: null
+  # no gradient clipping for default non-patch-based training
diff --git a/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml b/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml
new file mode 100644
index 0000000000..dec4ae4983
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/conf/training/sfm_two_stage.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - sfm + +# Hyperparameters +hp: + training_duration: 4000000 + # Training duration based on the number of processed samples + batch_size_per_gpu: 4 + # Batch size per GPU + lr: 0.0001 + # Learning rate + +# I/O +io: + encoder_checkpoint_path: outputs/sfm/checkpoints_sfm_encoder/Conv2dSerializable.0.5120.mdlus + # Where to load the encoder checkpoint diff --git a/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml b/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml new file mode 100644 index 0000000000..f29f412a7c --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/conf/validation/cwb.yaml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Validation dataset options +# (need to set dataset.train_test_split == true to have an effect) +train: false +all_times: false diff --git a/examples/generative/corrdiff_plus_plus/datasets/base.py b/examples/generative/corrdiff_plus_plus/datasets/base.py new file mode 100644 index 0000000000..22b00d252c --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/base.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch + + +@dataclass +class ChannelMetadata: + """Metadata describing a data channel.""" + + name: str + level: str = "" + auxiliary: bool = False + + +class DownscalingDataset(torch.utils.data.Dataset, ABC): + """An abstract class that defines the interface for downscaling datasets.""" + + @abstractmethod + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + pass + + @abstractmethod + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + pass + + @abstractmethod + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def time(self) -> List: + """Get time values from the dataset.""" + pass + + @abstractmethod + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + pass + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return x + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return x + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x + + def info(self) -> dict: + """Get information about the dataset.""" + return {} diff --git a/examples/generative/corrdiff_plus_plus/datasets/cwb.py b/examples/generative/corrdiff_plus_plus/datasets/cwb.py new file mode 100644 index 0000000000..91f469633c --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/cwb.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Streaming images and labels from datasets created with dataset_tool.py.""" + +import logging +import random + +import cftime +import cv2 +from hydra.utils import to_absolute_path +import numpy as np +import zarr + +from .base import ChannelMetadata, DownscalingDataset +from .img_utils import reshape_fields +from .norm import denormalize, normalize + +logger = logging.getLogger(__file__) + + +def get_target_normalizations_v1(group): + """Get target normalizations using center and scale values from the 'group'.""" + return group["cwb_center"][:], group["cwb_scale"][:] + + +def get_target_normalizations_v2(group): + """Change the normalizations of the non-gaussian output variables""" + center = group["cwb_center"] + scale = group["cwb_scale"] + variable = group["cwb_variable"] + + center = np.where(variable == "maximum_radar_reflectivity", 25.0, center) + center = np.where(variable == "eastward_wind_10m", 0.0, center) + center = np.where(variable == "northward_wind_10m", 0, center) + + scale = np.where(variable == "maximum_radar_reflectivity", 25.0, scale) + scale = np.where(variable == "eastward_wind_10m", 20.0, scale) + scale = np.where(variable == "northward_wind_10m", 20.0, scale) + return center, scale + + +class _ZarrDataset(DownscalingDataset): + """A Dataset for loading paired training data from a Zarr-file + + This dataset should not be modified to add image processing contributions. + """ + + path: str + + def __init__( + self, path: str, get_target_normalization=get_target_normalizations_v1 + ): + self.path = path + self.group = zarr.open_consolidated(path) + self.get_target_normalization = get_target_normalization + + # valid indices + cwb_valid = self.group["cwb_valid"] + era5_valid = self.group["era5_valid"] + if not ( + era5_valid.ndim == 2 + and cwb_valid.ndim == 1 + and cwb_valid.shape[0] == era5_valid.shape[0] + ): + raise ValueError("Invalid dataset shape") + era5_all_channels_valid = np.all(era5_valid, axis=-1) + valid_times = cwb_valid & era5_all_channels_valid + # need to cast to bool since cwb_valis is stored as an int8 type in zarr. + self.valid_times = valid_times != 0 + + logger.info("Number of valid times: %d", len(self)) + logger.info("input_channels:%s", self.input_channels()) + logger.info("output_channels:%s", self.output_channels()) + + def _get_valid_time_index(self, idx): + time_indexes = np.arange(self.group["time"].size) + if not self.valid_times.dtype == np.bool_: + raise ValueError("valid_times must be a boolean array") + valid_time_indexes = time_indexes[self.valid_times] + return valid_time_indexes[idx] + + def __getitem__(self, idx): + idx_to_load = self._get_valid_time_index(idx) + target = self.group["cwb"][idx_to_load] + input = self.group["era5"][idx_to_load] + label = 0 + + target = self.normalize_output(target[None, ...])[0] + input = self.normalize_input(input[None, ...])[0] + + return target, input, label + + def longitude(self): + """The longitude. useful for plotting""" + return self.group["XLONG"] + + def latitude(self): + """The latitude. useful for plotting""" + return self.group["XLAT"] + + def _get_channel_meta(self, variable, level): + if np.isnan(level): + level = "" + return ChannelMetadata(name=variable, level=str(level)) + + def input_channels(self): + """Metadata for the input channels. 
A list of dictionaries, one for each channel""" + variable = self.group["era5_variable"] + level = self.group["era5_pressure"] + return [self._get_channel_meta(*v) for v in zip(variable, level)] + + def output_channels(self): + """Metadata for the output channels. A list of dictionaries, one for each channel""" + variable = self.group["cwb_variable"] + level = self.group["cwb_pressure"] + return [self._get_channel_meta(*v) for v in zip(variable, level)] + + def _read_time(self): + """The vector of time coordinate has length (self)""" + + return cftime.num2date( + self.group["time"], units=self.group["time"].attrs["units"] + ) + + def time(self): + """The vector of time coordinate has length (self)""" + time = self._read_time() + return time[self.valid_times].tolist() + + def image_shape(self): + """Get the shape of the image (same for input and output).""" + return self.group["cwb"].shape[-2:] + + def _select_norm_channels(self, means, stds, channels): + if channels is not None: + means = means[channels] + stds = stds[channels] + return (means, stds) + + def normalize_input(self, x, channels=None): + """Convert input from physical units to normalized data.""" + norm = self._select_norm_channels( + self.group["era5_center"], self.group["era5_scale"], channels + ) + return normalize(x, *norm) + + def denormalize_input(self, x, channels=None): + """Convert input from normalized data to physical units.""" + norm = self._select_norm_channels( + self.group["era5_center"], self.group["era5_scale"], channels + ) + return denormalize(x, *norm) + + def normalize_output(self, x, channels=None): + """Convert output from physical units to normalized data.""" + norm = self.get_target_normalization(self.group) + norm = self._select_norm_channels(*norm, channels) + return normalize(x, *norm) + + def denormalize_output(self, x, channels=None): + """Convert output from normalized data to physical units.""" + norm = self.get_target_normalization(self.group) + norm = self._select_norm_channels(*norm, channels) + return denormalize(x, *norm) + + def info(self): + return { + "target_normalization": self.get_target_normalization(self.group), + "input_normalization": ( + self.group["era5_center"][:], + self.group["era5_scale"][:], + ), + } + + def __len__(self): + return self.valid_times.sum() + + +class FilterTime(DownscalingDataset): + """Filter a time dependent dataset""" + + def __init__(self, dataset, filter_fn): + """ + Args: + filter_fn: if filter_fn(time) is True then return point + """ + self._dataset = dataset + self._filter_fn = filter_fn + self._indices = [i for i, t in enumerate(self._dataset.time()) if filter_fn(t)] + + def longitude(self): + """Get longitude values from the dataset.""" + return self._dataset.longitude() + + def latitude(self): + """Get latitude values from the dataset.""" + return self._dataset.latitude() + + def input_channels(self): + """Metadata for the input channels. A list of dictionaries, one for each channel""" + return self._dataset.input_channels() + + def output_channels(self): + """Metadata for the output channels. 
A list of dictionaries, one for each channel""" + return self._dataset.output_channels() + + def time(self): + """Get time values from the dataset.""" + time = self._dataset.time() + return [time[i] for i in self._indices] + + def info(self): + """Get information about the dataset.""" + return self._dataset.info() + + def image_shape(self): + """Get the shape of the image (same for input and output).""" + return self._dataset.image_shape() + + def normalize_input(self, x, channels=None): + """Convert input from physical units to normalized data.""" + return self._dataset.normalize_input(x, channels=channels) + + def denormalize_input(self, x, channels=None): + """Convert input from normalized data to physical units.""" + return self._dataset.denormalize_input(x, channels=channels) + + def normalize_output(self, x, channels=None): + """Convert output from physical units to normalized data.""" + return self._dataset.normalize_output(x, channels=channels) + + def denormalize_output(self, x, channels=None): + """Convert output from normalized data to physical units.""" + return self._dataset.denormalize_output(x, channels=channels) + + def __getitem__(self, idx): + return self._dataset[self._indices[idx]] + + def __len__(self): + return len(self._indices) + + +def is_2021(time): + """Check if the given time is in the year 2021.""" + return time.year == 2021 + + +def is_not_2021(time): + """Check if the given time is not in the year 2021.""" + return not is_2021(time) + + +class ZarrDataset(DownscalingDataset): + """A Dataset for loading paired training data from a Zarr-file with the + following schema:: + + xarray.Dataset { + dimensions: + south_north = 450 ; + west_east = 450 ; + west_east_stag = 451 ; + south_north_stag = 451 ; + time = 8760 ; + cwb_channel = 20 ; + era5_channel = 20 ; + + variables: + float32 XLAT(south_north, west_east) ; + XLAT:FieldType = 104 ; + XLAT:MemoryOrder = XY ; + XLAT:description = LATITUDE, SOUTH IS NEGATIVE ; + XLAT:stagger = ; + XLAT:units = degree_north ; + float32 XLAT_U(south_north, west_east_stag) ; + XLAT_U:FieldType = 104 ; + XLAT_U:MemoryOrder = XY ; + XLAT_U:description = LATITUDE, SOUTH IS NEGATIVE ; + XLAT_U:stagger = X ; + XLAT_U:units = degree_north ; + float32 XLAT_V(south_north_stag, west_east) ; + XLAT_V:FieldType = 104 ; + XLAT_V:MemoryOrder = XY ; + XLAT_V:description = LATITUDE, SOUTH IS NEGATIVE ; + XLAT_V:stagger = Y ; + XLAT_V:units = degree_north ; + float32 XLONG(south_north, west_east) ; + XLONG:FieldType = 104 ; + XLONG:MemoryOrder = XY ; + XLONG:description = LONGITUDE, WEST IS NEGATIVE ; + XLONG:stagger = ; + XLONG:units = degree_east ; + float32 XLONG_U(south_north, west_east_stag) ; + XLONG_U:FieldType = 104 ; + XLONG_U:MemoryOrder = XY ; + XLONG_U:description = LONGITUDE, WEST IS NEGATIVE ; + XLONG_U:stagger = X ; + XLONG_U:units = degree_east ; + float32 XLONG_V(south_north_stag, west_east) ; + XLONG_V:FieldType = 104 ; + XLONG_V:MemoryOrder = XY ; + XLONG_V:description = LONGITUDE, WEST IS NEGATIVE ; + XLONG_V:stagger = Y ; + XLONG_V:units = degree_east ; + datetime64[ns] XTIME() ; + XTIME:FieldType = 104 ; + XTIME:MemoryOrder = 0 ; + XTIME:description = minutes since 2022-12-18 13:00:00 ; + XTIME:stagger = ; + float32 cwb(time, cwb_channel, south_north, west_east) ; + float32 cwb_center(cwb_channel) ; + float64 cwb_pressure(cwb_channel) ; + float32 cwb_scale(cwb_channel) ; + bool cwb_valid(time) ; + 1: + target = self._create_lowres_(target, factor=self.ds_factor) + + reshape_args = ( + y_roll, + self.train, + self.n_history, 
+ self.in_channels, + self.out_channels, + self.img_shape_x, + self.img_shape_y, + self.min_path, + self.max_path, + self.global_means_path, + self.global_stds_path, + self.normalization, + self.roll, + ) + # SR + input = reshape_fields( + input, + "inp", + *reshape_args, + normalize=False, + ) # 3x720x1440 + target = reshape_fields( + target, "tar", *reshape_args, normalize=False + ) # 3x720x1440 + + return target, input, idx + + def input_channels(self): + """Metadata for the input channels. A list of dictionaries, one for each channel""" + in_channels = self._dataset.input_channels() + in_channels = [in_channels[i] for i in self.in_channels] + return in_channels + + def output_channels(self): + """Metadata for the output channels. A list of dictionaries, one for each channel""" + out_channels = self._dataset.output_channels() + return [out_channels[i] for i in self.out_channels] + + def __len__(self): + return len(self._dataset) + + def longitude(self): + """Get longitude values from the dataset.""" + lon = self._dataset.longitude() + return lon if self.train else lon[..., : self.img_shape_y, : self.img_shape_x] + + def latitude(self): + """Get latitude values from the dataset.""" + lat = self._dataset.latitude() + return lat if self.train else lat[..., : self.img_shape_y, : self.img_shape_x] + + def time(self): + """Get time values from the dataset.""" + return self._dataset.time() + + def image_shape(self): + """Get the shape of the image (same for input and output).""" + return (self.img_shape_x, self.img_shape_y) + + def normalize_input(self, x): + """Convert input from physical units to normalized data.""" + x_norm = self._dataset.normalize_input( + x[:, : len(self.in_channels)], channels=self.in_channels + ) + return np.concatenate((x_norm, x[:, self.in_channels :]), axis=1) + + def denormalize_input(self, x): + """Convert input from normalized data to physical units.""" + x_denorm = self._dataset.denormalize_input( + x[:, : len(self.in_channels)], channels=self.in_channels + ) + return np.concatenate((x_denorm, x[:, len(self.in_channels) :]), axis=1) + + def normalize_output(self, x): + """Convert output from physical units to normalized data.""" + return self._dataset.normalize_output(x, channels=self.out_channels) + + def denormalize_output(self, x): + """Convert output from normalized data to physical units.""" + return self._dataset.denormalize_output(x, channels=self.out_channels) + + def _create_highres_(self, x, shape): + # downsample the high res imag + x = x.transpose(1, 2, 0) + # upsample with bicubic interpolation to bring the image to the nominal size + x = cv2.resize( + x, (shape[0], shape[1]), interpolation=cv2.INTER_CUBIC + ) # 32x32x3 + x = x.transpose(2, 0, 1) # 3x32x32 + return x + + def _create_lowres_(self, x, factor=4): + # downsample the high res imag + x = x.transpose(1, 2, 0) + x = x[::factor, ::factor, :] # 8x8x3 #subsample + # upsample with bicubic interpolation to bring the image to the nominal size + x = cv2.resize( + x, (x.shape[1] * factor, x.shape[0] * factor), interpolation=cv2.INTER_CUBIC + ) # 32x32x3 + x = x.transpose(2, 0, 1) # 3x32x32 + return x + + +def get_zarr_dataset(*, data_path, normalization="v1", all_times=False, **kwargs): + """Get a Zarr dataset for training or evaluation.""" + data_path = to_absolute_path(data_path) + get_target_normalization = { + "v1": get_target_normalizations_v1, + "v2": get_target_normalizations_v2, + }[normalization] + logger.info(f"Normalization: {normalization}") + zdataset = _ZarrDataset( + data_path, 
get_target_normalization=get_target_normalization + ) + return ZarrDataset( + dataset=zdataset, normalization=normalization, all_times=all_times, **kwargs + ) diff --git a/examples/generative/corrdiff_plus_plus/datasets/dataset.py b/examples/generative/corrdiff_plus_plus/datasets/dataset.py new file mode 100644 index 0000000000..d55f6038af --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/dataset.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Tuple, Union +import copy +import torch + +from physicsnemo.utils.generative import InfiniteSampler +from physicsnemo.distributed import DistributedManager + +from . import base, cwb, hrrrmini + + +# this maps all known dataset types to the corresponding init function +known_datasets = {"cwb": cwb.get_zarr_dataset, "hrrr_mini": hrrrmini.HRRRMiniDataset} + + +def init_train_valid_datasets_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, + validation_dataset_cfg: Union[dict, None] = None, + train_test_split: bool = True, +) -> Tuple[ + base.DownscalingDataset, + Iterable, + Union[base.DownscalingDataset, None], + Union[Iterable, None], +]: + """ + A wrapper function for managing the train-test split for the CWB dataset. + + Parameters: + - dataset_cfg (dict): Configuration for the dataset. + - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None. + - batch_size (int): The number of samples in each batch of data. Defaults to 1. + - seed (int): The random seed for dataset shuffling. Defaults to 0. + - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + + Returns: + - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. 
+ """ + + config = copy.deepcopy(dataset_cfg) + (dataset, dataset_iter) = init_dataset_from_config( + config, dataloader_cfg, batch_size=batch_size, seed=seed + ) + if train_test_split: + valid_dataset_cfg = copy.deepcopy(config) + if validation_dataset_cfg: + valid_dataset_cfg.update(validation_dataset_cfg) + (valid_dataset, valid_dataset_iter) = init_dataset_from_config( + valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed + ) + else: + valid_dataset = valid_dataset_iter = None + + return dataset, dataset_iter, valid_dataset, valid_dataset_iter + + +def init_dataset_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, +) -> Tuple[base.DownscalingDataset, Iterable]: + dataset_cfg = copy.deepcopy(dataset_cfg) + dataset_type = dataset_cfg.pop("type", "cwb") + if "train_test_split" in dataset_cfg: + # handled by init_train_valid_datasets_from_config + del dataset_cfg["train_test_split"] + dataset_init_func = known_datasets[dataset_type] + + dataset_obj = dataset_init_func(**dataset_cfg) + if dataloader_cfg is None: + dataloader_cfg = {} + + dist = DistributedManager() + dataset_sampler = InfiniteSampler( + dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed + ) + + dataset_iterator = iter( + torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_size, + worker_init_fn=None, + **dataloader_cfg, + ) + ) + + return (dataset_obj, dataset_iterator) diff --git a/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py b/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py new file mode 100644 index 0000000000..4537e0780a --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/hrrrmini.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
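For orientation, here is a minimal sketch of how the dataset factory in `datasets/dataset.py` above could be driven from a training script. All paths, config values, and dataloader options below are illustrative placeholders, and `DistributedManager` must be initialized first because `init_dataset_from_config` queries it for the rank and world size.

```python
# Hypothetical usage sketch for datasets.dataset.init_train_valid_datasets_from_config;
# paths and hyperparameters are placeholders, not values shipped with the example.
from physicsnemo.distributed import DistributedManager
from datasets.dataset import init_train_valid_datasets_from_config

DistributedManager.initialize()  # the InfiniteSampler needs rank/world_size

dataset_cfg = {
    "type": "hrrr_mini",
    "data_path": "hrrr_mini_train.nc",
    "stats_path": "stats.json",
}
dataset, dataset_iter, valid_dataset, valid_iter = init_train_valid_datasets_from_config(
    dataset_cfg,
    dataloader_cfg={"num_workers": 4},
    batch_size=16,
    seed=0,
    train_test_split=True,
)
# Each draw follows HRRRMiniDataset.__getitem__: (normalized target, normalized input, index).
img_clean, img_lr, _ = next(dataset_iter)
```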
+ +import datetime +import json +import math +from typing import List, Tuple, Union + +import numpy as np +from numba import jit, prange +import xarray as xr + +from physicsnemo.utils.generative import convert_datetime_to_cftime + +from .base import ChannelMetadata, DownscalingDataset + + +class HRRRMiniDataset(DownscalingDataset): + """Reader for reduced-size HRRR dataset used for CorrDiff-mini.""" + + def __init__( + self, + data_path: str, + stats_path: str, + input_variables: Union[List[str], None] = None, + output_variables: Union[List[str], None] = None, + invariant_variables: Union[List[str], None] = ("elev_mean", "lsm_mean"), + ): + # load data + (self.input, self.input_variables) = _load_dataset( + data_path, "input", input_variables + ) + (self.output, self.output_variables) = _load_dataset( + data_path, "output", output_variables + ) + (self.invariants, self.invariant_variables) = _load_dataset( + data_path, "invariant", invariant_variables, stack_axis=0 + ) + + # load temporal and spatial coordinates + with xr.open_dataset(data_path) as ds: + self.times = np.array(ds["time"]) + self.coords = np.array(ds["coord"]) + + self.img_shape = self.output.shape[-2:] + self.upsample_factor = self.output.shape[-1] // self.input.shape[-1] + + # load normalization stats + with open(stats_path, "r") as f: + stats = json.load(f) + (input_mean, input_std) = _load_stats(stats, self.input_variables, "input") + (inv_mean, inv_std) = _load_stats(stats, self.invariant_variables, "invariant") + self.input_mean = np.concatenate([input_mean, inv_mean], axis=0) + self.input_std = np.concatenate([input_std, inv_std], axis=0) + (self.output_mean, self.output_std) = _load_stats( + stats, self.output_variables, "output" + ) + + def __getitem__(self, idx): + """Return the data sample (output, input, 0) at index idx.""" + x = self.upsample(self.input[idx].copy()) + + # add invariants to input + (i, j) = self.coords[idx] + inv = self.invariants[:, i : i + self.img_shape[0], j : j + self.img_shape[1]] + x = np.concatenate([x, inv], axis=0) + + y = self.output[idx] + + x = self.normalize_input(x) + y = self.normalize_output(y) + return (y, x, 0) + + def __len__(self): + return self.input.shape[0] + + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + return np.full(self.img_shape, np.nan) + + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + return np.full(self.img_shape, np.nan) + + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + inputs = [ChannelMetadata(name=v) for v in self.input_variables] + invariants = [ + ChannelMetadata(name=v, auxiliary=True) for v in self.invariant_variables + ] + return inputs + invariants + + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. 
A list of ChannelMetadata, one for each channel""" + return [ChannelMetadata(name=v) for v in self.output_variables] + + def time(self) -> List: + """Get time values from the dataset.""" + datetimes = ( + datetime.datetime.utcfromtimestamp(t.tolist() / 1e9) for t in self.times + ) + return [convert_datetime_to_cftime(t) for t in datetimes] + + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + return self.img_shape + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return (x - self.input_mean) / self.input_std + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x * self.input_std + self.input_mean + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return (x - self.output_mean) / self.output_std + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x * self.output_std + self.output_mean + + def upsample(self, x): + """Extend x around edges with linear extrapolation.""" + y_shape = ( + x.shape[0], + x.shape[1] * self.upsample_factor, + x.shape[2] * self.upsample_factor, + ) + y = np.empty(y_shape, dtype=np.float32) + _zoom_extrapolate(x, y, self.upsample_factor) + return y + + +def _load_dataset(data_path, group, variables=None, stack_axis=1): + with xr.open_dataset(data_path, group=group) as ds: + if variables is None: + variables = list(ds.keys()) + data = np.stack([ds[v] for v in variables], axis=stack_axis) + return (data, variables) + + +def _load_stats(stats, variables, group): + mean = np.array([stats[group][v]["mean"] for v in variables])[:, None, None].astype( + np.float32 + ) + std = np.array([stats[group][v]["std"] for v in variables])[:, None, None].astype( + np.float32 + ) + return (mean, std) + + +@jit(nopython=True) +def _zoom_extrapolate(x, y, factor): + """Bilinear zoom with extrapolation. + Use a numba function here because numpy/scipy options are rather slow. + """ + s = 1 / factor + for k in prange(y.shape[0]): + for iy in range(y.shape[1]): + ix = (iy + 0.5) * s - 0.5 + ix0 = int(math.floor(ix)) + ix0 = max(0, min(ix0, x.shape[1] - 2)) + ix1 = ix0 + 1 + for jy in range(y.shape[2]): + jx = (jy + 0.5) * s - 0.5 + jx0 = int(math.floor(jx)) + jx0 = max(0, min(jx0, x.shape[2] - 2)) + jx1 = jx0 + 1 + + x00 = x[k, ix0, jx0] + x01 = x[k, ix0, jx1] + x10 = x[k, ix1, jx0] + x11 = x[k, ix1, jx1] + djx = jx - jx0 + x0 = x00 + djx * (x01 - x00) + x1 = x10 + djx * (x11 - x10) + y[k, iy, jy] = x0 + (ix - ix0) * (x1 - x0) diff --git a/examples/generative/corrdiff_plus_plus/datasets/img_utils.py b/examples/generative/corrdiff_plus_plus/datasets/img_utils.py new file mode 100644 index 0000000000..c354985c81 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/img_utils.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def reshape_fields( + img, + inp_or_tar, + y_roll, + train, + n_history, + in_channels, + out_channels, + img_shape_x, + img_shape_y, + min_path, + max_path, + global_means_path, + global_stds_path, + normalization, + roll, + normalize=True, +): + """ + Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of + size ((n_channels*(n_history+1), img_shape_x, img_shape_y) + """ + + if len(np.shape(img)) == 3: + img = np.expand_dims(img, 0) + + if img.shape[3] > 720: + img = img[:, :, 0:720] # remove last pixel for era5 data + + n_history = n_history + + n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels + channels = in_channels if inp_or_tar == "inp" else out_channels + + if normalize and train: + mins = np.load(min_path)[:, channels] + maxs = np.load(max_path)[:, channels] + means = np.load(global_means_path)[:, channels] + stds = np.load(global_stds_path)[:, channels] + + img = img[:, :, :img_shape_x, :img_shape_y] + + if normalize and train: + if normalization == "minmax": + img -= mins + img /= maxs - mins + elif normalization == "zscore": + img -= means + img /= stds + + if roll: + img = np.roll(img, y_roll, axis=-1) + + if inp_or_tar == "inp": + img = np.reshape(img, (n_channels * (n_history + 1), img_shape_x, img_shape_y)) + elif inp_or_tar == "tar": + img = np.reshape(img, (n_channels, img_shape_x, img_shape_y)) + + return torch.as_tensor(img) diff --git a/examples/generative/corrdiff_plus_plus/datasets/norm.py b/examples/generative/corrdiff_plus_plus/datasets/norm.py new file mode 100644 index 0000000000..f1c50d13a3 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/datasets/norm.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
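As a quick illustration of the shape contract documented in `reshape_fields` above, the sketch below (with normalization disabled, so the stats-file paths are unused and passed as `None`) turns a `(n_history+1, C, H, W)` array into a `(C*(n_history+1), img_shape_x, img_shape_y)` tensor; the shapes and channel lists are placeholders.

```python
# Shape-only sketch for datasets.img_utils.reshape_fields; values are illustrative.
import numpy as np
from datasets.img_utils import reshape_fields

img = np.random.rand(1, 3, 720, 1440).astype(np.float32)  # (n_history+1, C, H, W)
out = reshape_fields(
    img,
    "inp",
    y_roll=0,
    train=True,
    n_history=0,
    in_channels=[0, 1, 2],
    out_channels=[0, 1, 2],
    img_shape_x=720,
    img_shape_y=1440,
    min_path=None,
    max_path=None,
    global_means_path=None,
    global_stds_path=None,
    normalization="zscore",
    roll=False,
    normalize=False,  # skip loading the stats files in this sketch
)
print(out.shape)  # torch.Size([3, 720, 1440]) == (C * (n_history + 1), img_shape_x, img_shape_y)
```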
+ +import numpy as np + + +def normalize(x, center, scale): + """Normalize input data 'x' using center and scale values.""" + center = np.asarray(center) + scale = np.asarray(scale) + if not (center.ndim == 1 and scale.ndim == 1): + raise ValueError("center and scale must be 1D arrays") + return (x - center[np.newaxis, :, np.newaxis, np.newaxis]) / scale[ + np.newaxis, :, np.newaxis, np.newaxis + ] + + +def denormalize(x, center, scale): + """Denormalize input data 'x' using center and scale values.""" + center = np.asarray(center) + scale = np.asarray(scale) + if not (center.ndim == 1 and scale.ndim == 1): + raise ValueError("center and scale must be 1D arrays") + return ( + x * scale[np.newaxis, :, np.newaxis, np.newaxis] + + center[np.newaxis, :, np.newaxis, np.newaxis] + ) diff --git a/examples/generative/corrdiff_plus_plus/generate.py b/examples/generative/corrdiff_plus_plus/generate.py new file mode 100644 index 0000000000..a758255ff2 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/generate.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +from omegaconf import OmegaConf, DictConfig +import torch +import torch._dynamo +import nvtx +import numpy as np +import netCDF4 as nc +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo import Module +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from einops import rearrange +from torch.distributed import gather +import tqdm + + +from hydra.utils import to_absolute_path +from physicsnemo.utils.generative import ( + SFM_Euler_sampler, + SFM_Euler_sampler_Adaptive_Sigma, + StackedRandomGenerator, + SFM_encoder_sampler, +) +from physicsnemo.utils.corrdiff import ( + NetCDFWriter, + get_time_from_range, +) + + +from helpers.generate_helpers import ( + get_dataset_and_sampler, + save_images, +) +from helpers.train_helpers import set_patch_shape + + +@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + """Generate random images using the techniques described in the paper + "Elucidating the Design Space of Diffusion-Based Generative Models". 
+ """ + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger + logger = PythonLogger("generate") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + logger.file_logging("generate.log") + + # Handle the batch size + seeds = list(np.arange(cfg.generation.num_ensembles)) + num_batches = ( + (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + + # Synchronize + if dist.world_size > 1: + torch.distributed.barrier() + + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range) + else: + times = cfg.generation.times + + # Create dataset object + dataset_cfg = OmegaConf.to_container(cfg.dataset) + dataset, sampler = get_dataset_and_sampler(dataset_cfg=dataset_cfg, times=times) + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + + # patching not supported for + patch_shape = (None, None) + img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + + # Parse the inference mode + if cfg.generation.inference_mode not in ["sfm", "sfm_encoder", "sfm_two_stage"]: + raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}") + + # Load networks, move to device, change precision + encoder_ckpt_filename = cfg.generation.io.encoder_ckpt_filename + logger0.info(f'Loading encoder network from "{encoder_ckpt_filename}"...') + encoder_net = Module.from_checkpoint(to_absolute_path(encoder_ckpt_filename)) + encoder_net = encoder_net.eval().to(device).to(memory_format=torch.channels_last) + + if cfg.generation.inference_mode in ["sfm", "sfm_two_stage"]: + denoiser_ckpt_filename = cfg.generation.io.denoiser_ckpt_filename + logger0.info(f'Loading residual network from "{denoiser_ckpt_filename}"...') + denoiser_net = Module.from_checkpoint(to_absolute_path(denoiser_ckpt_filename)) + denoiser_net = ( + denoiser_net.eval().to(device).to(memory_format=torch.channels_last) + ) + else: + denoiser_net = None + + if cfg.generation.perf.force_fp16: + encoder_net.use_fp16 = True + denoiser_net.use_fp16 = True + + # Reset since we are using a different mode. 
+ if cfg.generation.perf.use_torch_compile: + torch._dynamo.reset() + encoder_net = torch.compile(encoder_net, mode="reduce-overhead") + if denoiser_net: + denoiser_net = torch.compile(denoiser_net, mode="reduce-overhead") + networks = {"denoiser_net": denoiser_net, "encoder_net": encoder_net} + + # Partially instantiate the sampler based on the configs + if cfg.generation.inference_mode in ["sfm", "sfm_two_stage"]: + if cfg.generation.learnable_sigma: + sampler_fn = SFM_Euler_sampler_Adaptive_Sigma + else: + sampler_fn = SFM_Euler_sampler + elif cfg.generation.inference_mode == "sfm_encoder": + sampler_fn = SFM_encoder_sampler + else: + raise ValueError(f"Unknown sampling method {cfg.generation.inference_mode}") + + # Main generation definition + def generate_fn(): + img_shape_y, img_shape_x = img_shape + with nvtx.annotate("generate_fn", color="green"): + all_images = [] + for batch_seeds in tqdm.tqdm( + rank_batches, unit="batch", disable=(dist.rank != 0) + ): + batch_size = len(batch_seeds) + if batch_size == 0: + continue + rnd = StackedRandomGenerator(device, batch_seeds) + with nvtx.annotate( + f"{cfg.generation.inference_mode} model", color="rapids" + ): + with torch.inference_mode(): + images = sampler_fn( + networks=networks, + img_lr=image_lr, + randn_like=rnd.randn_like, + cfg=cfg.generation.sampler, + ) + all_images.append(images) + image_out = torch.cat(all_images) + + # Gather tensors on rank 0 + if dist.world_size > 1: + if dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if dist.rank == 0 else None, + dst=0, + ) + + if dist.rank == 0: + return torch.cat(gathered_tensors) + else: + return None + else: + return image_out + + # generate images + output_path = getattr(cfg.generation.io, "output_filename", "corrdiff_output.nc") + logger0.info(f"Generating images, saving results to {output_path}...") + batch_size = 1 + warmup_steps = min(len(times) - 1, 2) + # Generates model predictions from the input data using the specified + # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates + # through the dataset using a data loader, computes predictions, and saves them along + # with associated metadata. 
+ if dist.rank == 0: + f = nc.Dataset(output_path, "w") + # add attributes + f.cfg = str(cfg) + + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True + ) + time_index = -1 + if dist.rank == 0: + writer = NetCDFWriter( + f, + lat=dataset.latitude(), + lon=dataset.longitude(), + input_channels=dataset.input_channels(), + output_channels=dataset.output_channels(), + ) + + # Initialize threadpool for writers + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + times = dataset.time() + for image_tar, image_lr, index in iter(data_loader): + time_index += 1 + if dist.rank == 0: + logger0.info(f"starting index: {time_index}") + + if time_index == warmup_steps: + start.record() + + # continue + image_lr = ( + image_lr.to(device=device) + .to(torch.float32) + .to(memory_format=torch.channels_last) + ) + # expand to batch size + image_lr = image_lr.expand( + cfg.generation.seed_batch_size, -1, -1, -1 + ).to(memory_format=torch.channels_last) + image_tar = image_tar.to(device=device).to(torch.float32) + image_out = generate_fn() + + if dist.rank == 0: + batch_size = image_out.shape[0] + # write out data in a seperate thread so we don't hold up inferencing + writer_threads.append( + writer_executor.submit( + save_images, + writer, + dataset, + list(times), + image_out.cpu(), + image_tar.cpu(), + image_lr.cpu(), + time_index, + index[0], + ) + ) + end.record() + end.synchronize() + elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + timed_steps = time_index + 1 - warmup_steps + if dist.rank == 0: + average_time_per_batch_element = elapsed_time / timed_steps / batch_size + logger.info( + f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" + ) + logger.info( + f"Average time per batch element = {average_time_per_batch_element} s" + ) + + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + + if dist.rank == 0: + f.close() + logger0.info("Generation Completed.") + + +if __name__ == "__main__": + main() diff --git a/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py b/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py new file mode 100644 index 0000000000..755cb95759 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/helpers/generate_helpers.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
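The ensemble loop in `generate.py` above splits the requested seeds into batches and interleaves them across ranks; a self-contained sketch of that partitioning logic, with arbitrarily chosen sizes, is:

```python
# Standalone sketch of the seed partitioning used in generate.py; sizes are illustrative.
import numpy as np
import torch

num_ensembles, seed_batch_size, world_size = 10, 4, 2
seeds = list(np.arange(num_ensembles))
num_batches = (
    (len(seeds) - 1) // (seed_batch_size * world_size) + 1
) * world_size
all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
for rank in range(world_size):
    rank_batches = all_batches[rank::world_size]
    print(rank, [b.tolist() for b in rank_batches])
# rank 0 -> [[0, 1, 2], [6, 7]], rank 1 -> [[3, 4, 5], [8, 9]]:
# every ensemble seed appears exactly once, in batches of at most seed_batch_size.
```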
+ +import datetime +from datasets.base import DownscalingDataset +from datasets.dataset import init_dataset_from_config +from physicsnemo.utils.generative import convert_datetime_to_cftime + + +def get_dataset_and_sampler(dataset_cfg, times): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + plot_times = [ + convert_datetime_to_cftime( + datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + ) + for time in times + ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in plot_times] + sampler = time_indices + + return dataset, sampler + + +def save_images( + writer, + dataset: DownscalingDataset, + times, + image_out, + image_tar, + image_lr, + time_index, + t_index, +): + """ + Saves inferencing result along with the baseline + + Parameters + ---------- + + writer (NetCDFWriter): Where the data is being written + in_channels (List): List of the input channels being used + input_channel_info (Dict): Description of the input channels + out_channels (List): List of the output channels being used + output_channel_info (Dict): Description of the output channels + input_norm (Tuple): Normalization data for input + target_norm (Tuple): Normalization data for the target + image_out (torch.Tensor): Generated output data + image_tar (torch.Tensor): Ground truth data + image_lr (torch.Tensor): Low resolution input data + time_index (int): Epoch number + t_index (int): index where times are located + """ + # weather sub-plot + image_lr2 = image_lr[0].unsqueeze(0) + image_lr2 = image_lr2.cpu().numpy() + image_lr2 = dataset.denormalize_input(image_lr2) + + image_tar2 = image_tar[0].unsqueeze(0) + image_tar2 = image_tar2.cpu().numpy() + image_tar2 = dataset.denormalize_output(image_tar2) + + # some runtime assertions + if image_tar2.ndim != 4: + raise ValueError("image_tar2 must be 4-dimensional") + + for idx in range(image_out.shape[0]): + image_out2 = image_out[idx].unsqueeze(0) + if image_out2.ndim != 4: + raise ValueError("image_out2 must be 4-dimensional") + + # Denormalize the input and outputs + image_out2 = image_out2.cpu().numpy() + image_out2 = dataset.denormalize_output(image_out2) + + time = times[t_index] + writer.write_time(time_index, time) + for channel_idx in range(image_out2.shape[1]): + info = dataset.output_channels()[channel_idx] + channel_name = info.name + info.level + truth = image_tar2[0, channel_idx] + + writer.write_truth(channel_name, time_index, truth) + writer.write_prediction( + channel_name, time_index, idx, image_out2[0, channel_idx] + ) + + input_channel_info = dataset.input_channels() + for channel_idx in range(len(input_channel_info)): + info = input_channel_info[channel_idx] + channel_name = info.name + info.level + writer.write_input(channel_name, time_index, image_lr2[0, channel_idx]) + if channel_idx == image_lr2.shape[1] - 1: + break diff --git a/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py b/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py new file mode 100644 index 0000000000..99243ff4dd --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/helpers/sfm_utils.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+from physicsnemo.models.diffusion import Conv2dSerializable
+import torch
+from omegaconf import DictConfig
+
+
+def get_encoder(cfg: DictConfig):
+    """
+    Helper that instantiates the encoder network specified by the configuration.
+
+    Parameters
+    ----------
+    cfg: DictConfig
+        configuration for the encoder
+
+    Returns
+        torch.nn.Module: The encoder
+    """
+    in_channels = len(cfg.dataset["in_channels"])
+    out_channels = len(cfg.dataset["out_channels"])
+    encoder_type = cfg.model["encoder_type"]
+
+    if encoder_type == "1x1conv":
+        encoder = Conv2dSerializable(in_channels, out_channels, kernel_size=1)
+    elif "songunet" in encoder_type:
+        model_channels_dict = {
+            "songunet_s": 32,  # 11.60M
+            "songunet_xs": 16,  # 2.90M
+            "songunet_2xs": 8,  # 0.74M
+            "songunet_3xs": 4,  # 0.19M
+        }
+        if hasattr(cfg.model, "songunet_checkpoint_level"):
+            songunet_checkpoint_level = cfg.model.songunet_checkpoint_level
+        else:
+            songunet_checkpoint_level = None
+
+        songunet_kwargs = {
+            "embedding_type": "zero",
+            "label_dim": 0,
+            "encoder_type": "standard",
+            "decoder_type": "standard",
+            "channel_mult_noise": 1,
+            "resample_filter": [1, 1],
+            "channel_mult": [1, 2, 2, 4, 4],
+            "attn_resolutions": [28],
+            "N_grid_channels": 0,
+            "dropout": cfg.model.dropout,
+            "checkpoint_level": songunet_checkpoint_level,
+            "model_channels": model_channels_dict[encoder_type],
+        }
+        encoder = SongUNetPosEmbd(
+            img_resolution=cfg.dataset["img_shape_x"],
+            in_channels=in_channels,
+            out_channels=out_channels,
+            **songunet_kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown encoder type: {encoder_type}")
+
+    return encoder
diff --git a/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py b/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py
new file mode 100644
index 0000000000..d4529ac821
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/helpers/train_helpers.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
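A hypothetical configuration for the `get_encoder` helper above might look like the sketch below; the channel lists, image size, dropout, and checkpoint level are placeholders rather than values from the shipped configs.

```python
# Illustrative only: minimal config consumed by helpers.sfm_utils.get_encoder.
from omegaconf import OmegaConf
from helpers.sfm_utils import get_encoder

cfg = OmegaConf.create(
    {
        "dataset": {"in_channels": [0, 1, 2, 3], "out_channels": [0, 1, 2], "img_shape_x": 448},
        "model": {"encoder_type": "songunet_xs", "dropout": 0.0, "songunet_checkpoint_level": 0},
    }
)
encoder = get_encoder(cfg)  # SongUNetPosEmbd with model_channels=16 (~2.90M parameters per the table above)
```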
+ +import torch +import numpy as np +from omegaconf import ListConfig + + +def set_patch_shape(img_shape, patch_shape): + img_shape_y, img_shape_x = img_shape + patch_shape_y, patch_shape_x = patch_shape + if (patch_shape_x is None) or (patch_shape_x > img_shape_x): + patch_shape_x = img_shape_x + if (patch_shape_y is None) or (patch_shape_y > img_shape_y): + patch_shape_y = img_shape_y + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x != patch_shape_y: + raise NotImplementedError("Rectangular patch not supported yet") + if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: + raise ValueError("Patch shape needs to be a multiple of 32") + return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + + +def set_seed(rank): + """ + Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings + """ + np.random.seed(rank % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + + +def configure_cuda_for_consistent_precision(): + """ + Configures CUDA and cuDNN settings to ensure consistent precision by + disabling TensorFloat-32 (TF32) and reduced precision settings. + """ + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + +def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size): + """ + Calculate the total batch size per GPU in a distributed setting, log the batch size per GPU, ensure it's within valid limits, + determine the number of accumulation rounds, and validate that the global batch size matches the expected value. + """ + batch_gpu_total = total_batch_size // world_size + batch_size_per_gpu = batch_size_per_gpu + if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total: + batch_size_per_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu + if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size: + raise ValueError( + "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size" + ) + return batch_gpu_total, num_accumulation_rounds + + +def handle_and_clip_gradients(model, grad_clip_threshold=None): + """ + Handles NaNs and infinities in the gradients and optionally clips the gradients. + + Parameters: + - model (torch.nn.Module): The model whose gradients need to be processed. + - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed. 
+ """ + # Replace NaNs and infinities in gradients + for param in model.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + # Clip gradients if a threshold is provided + if grad_clip_threshold is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) + + +def parse_model_args(args): + """Convert ListConfig values in args to tuples.""" + return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} + + +def is_time_for_periodic_task( + cur_nimg, freq, done, batch_size, rank, rank_0_only=False +): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return cur_nimg % freq < batch_size diff --git a/examples/generative/corrdiff_plus_plus/inference/concat.py b/examples/generative/corrdiff_plus_plus/inference/concat.py new file mode 100644 index 0000000000..ee44633d59 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/inference/concat.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import dask.diagnostics +import xarray + +base = sys.argv[1:-1] +out = sys.argv[-1] + +with dask.diagnostics.ProgressBar(): + t = xarray.open_mfdataset( + base, + group="prediction", + concat_dim="ensemble", + combine="nested", + chunks={"time": 1, "ensemble": 10}, + ) + t.to_zarr(out, group="prediction") + + t = xarray.open_dataset(base[0], group="input", chunks={"time": 1}) + t.to_zarr(out, group="input", mode="a") + + t = xarray.open_dataset(base[0], group="truth", chunks={"time": 1}) + t.to_zarr(out, group="truth", mode="a") + + t = xarray.open_dataset(base[0]) + t.to_zarr(out, mode="a") diff --git a/examples/generative/corrdiff_plus_plus/inference/matplotlibrc b/examples/generative/corrdiff_plus_plus/inference/matplotlibrc new file mode 100644 index 0000000000..54e097e836 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/inference/matplotlibrc @@ -0,0 +1,770 @@ +#### MATPLOTLIBRC FORMAT + +## NOTE FOR END USERS: DO NOT EDIT THIS FILE! +## +## This is a sample Matplotlib configuration file - you can find a copy +## of it on your system in site-packages/matplotlib/mpl-data/matplotlibrc +## (relative to your Python installation location). +## +## You should find a copy of it on your system at +## site-packages/matplotlib/mpl-data/matplotlibrc (relative to your Python +## installation location). DO NOT EDIT IT! +## +## If you wish to change your default style, copy this file to one of the +## following locations: +## Unix/Linux: +## $HOME/.config/matplotlib/matplotlibrc OR +## $XDG_CONFIG_HOME/matplotlib/matplotlibrc (if $XDG_CONFIG_HOME is set) +## Other platforms: +## $HOME/.matplotlib/matplotlibrc +## and edit that copy. 
+## +## See https://matplotlib.org/users/customizing.html#the-matplotlibrc-file +## for more details on the paths which are checked for the configuration file. +## +## Blank lines, or lines starting with a comment symbol, are ignored, as are +## trailing comments. Other lines must have the format: +## key: val # optional comment +## +## Formatting: Use PEP8-like style (as enforced in the rest of the codebase). +## All lines start with an additional '#', so that removing all leading '#'s +## yields a valid style file. +## +## Colors: for the color values below, you can either use +## - a Matplotlib color string, such as r, k, or b +## - an RGB tuple, such as (1.0, 0.5, 0.0) +## - a hex string, such as ff00ff +## - a scalar grayscale intensity such as 0.75 +## - a legal html color name, e.g., red, blue, darkslategray +## +## Matplotlib configuration are currently divided into following parts: +## - BACKENDS +## - LINES +## - PATCHES +## - HATCHES +## - BOXPLOT +## - FONT +## - TEXT +## - LaTeX +## - AXES +## - DATES +## - TICKS +## - GRIDS +## - LEGEND +## - FIGURE +## - IMAGES +## - CONTOUR PLOTS +## - ERRORBAR PLOTS +## - HISTOGRAM PLOTS +## - SCATTER PLOTS +## - AGG RENDERING +## - PATHS +## - SAVING FIGURES +## - INTERACTIVE KEYMAPS +## - ANIMATION + +##### CONFIGURATION BEGINS HERE + + +## *************************************************************************** +## * BACKENDS * +## *************************************************************************** +## The default backend. If you omit this parameter, the first working +## backend from the following list is used: +## MacOSX QtAgg Gtk4Agg Gtk3Agg TkAgg WxAgg Agg +## Other choices include: +## QtCairo GTK4Cairo GTK3Cairo TkCairo WxCairo Cairo +## Qt5Agg Qt5Cairo Wx # deprecated. +## PS PDF SVG Template +## You can also deploy your own backend outside of Matplotlib by referring to +## the module name (which must be in the PYTHONPATH) as 'module://my_backend'. +##backend: Agg + +## The port to use for the web server in the WebAgg backend. +#webagg.port: 8988 + +## The address on which the WebAgg web server should be reachable +#webagg.address: 127.0.0.1 + +## If webagg.port is unavailable, a number of other random ports will +## be tried until one that is available is found. +#webagg.port_retries: 50 + +## When True, open the web browser to the plot that is shown +#webagg.open_in_browser: True + +## If you are running pyplot inside a GUI and your backend choice +## conflicts, we will automatically try to find a compatible one for +## you if backend_fallback is True +#backend_fallback: True + +#interactive: False +#toolbar: toolbar2 # {None, toolbar2, toolmanager} +#timezone: UTC # a pytz timezone string, e.g., US/Central or Europe/Paris + + +## *************************************************************************** +## * LINES * +## *************************************************************************** +## See https://matplotlib.org/api/artist_api.html#module-matplotlib.lines +## for more information on line properties. 
+#lines.linewidth: 1.5 # line width in points +#lines.linestyle: - # solid line +#lines.color: C0 # has no affect on plot(); see axes.prop_cycle +#lines.marker: None # the default marker +#lines.markerfacecolor: auto # the default marker face color +#lines.markeredgecolor: auto # the default marker edge color +#lines.markeredgewidth: 1.0 # the line width around the marker symbol +#lines.markersize: 6 # marker size, in points +#lines.dash_joinstyle: round # {miter, round, bevel} +#lines.dash_capstyle: butt # {butt, round, projecting} +#lines.solid_joinstyle: round # {miter, round, bevel} +#lines.solid_capstyle: projecting # {butt, round, projecting} +#lines.antialiased: True # render lines in antialiased (no jaggies) + +## The three standard dash patterns. These are scaled by the linewidth. +#lines.dashed_pattern: 3.7, 1.6 +#lines.dashdot_pattern: 6.4, 1.6, 1, 1.6 +#lines.dotted_pattern: 1, 1.65 +#lines.scale_dashes: True + +#markers.fillstyle: full # {full, left, right, bottom, top, none} + +#pcolor.shading: auto +#pcolormesh.snap: True # Whether to snap the mesh to pixel boundaries. This is + # provided solely to allow old test images to remain + # unchanged. Set to False to obtain the previous behavior. + +## *************************************************************************** +## * PATCHES * +## *************************************************************************** +## Patches are graphical objects that fill 2D space, like polygons or circles. +## See https://matplotlib.org/api/artist_api.html#module-matplotlib.patches +## for more information on patch properties. +#patch.linewidth: 1 # edge width in points. +#patch.facecolor: C0 +#patch.edgecolor: black # if forced, or patch is not filled +#patch.force_edgecolor: False # True to always use edgecolor +#patch.antialiased: True # render patches in antialiased (no jaggies) + + +## *************************************************************************** +## * HATCHES * +## *************************************************************************** +#hatch.color: black +#hatch.linewidth: 1.0 + + +## *************************************************************************** +## * BOXPLOT * +## *************************************************************************** +#boxplot.notch: False +#boxplot.vertical: True +#boxplot.whiskers: 1.5 +#boxplot.bootstrap: None +#boxplot.patchartist: False +#boxplot.showmeans: False +#boxplot.showcaps: True +#boxplot.showbox: True +#boxplot.showfliers: True +#boxplot.meanline: False + +#boxplot.flierprops.color: black +#boxplot.flierprops.marker: o +#boxplot.flierprops.markerfacecolor: none +#boxplot.flierprops.markeredgecolor: black +#boxplot.flierprops.markeredgewidth: 1.0 +#boxplot.flierprops.markersize: 6 +#boxplot.flierprops.linestyle: none +#boxplot.flierprops.linewidth: 1.0 + +#boxplot.boxprops.color: black +#boxplot.boxprops.linewidth: 1.0 +#boxplot.boxprops.linestyle: - + +#boxplot.whiskerprops.color: black +#boxplot.whiskerprops.linewidth: 1.0 +#boxplot.whiskerprops.linestyle: - + +#boxplot.capprops.color: black +#boxplot.capprops.linewidth: 1.0 +#boxplot.capprops.linestyle: - + +#boxplot.medianprops.color: C1 +#boxplot.medianprops.linewidth: 1.0 +#boxplot.medianprops.linestyle: - + +#boxplot.meanprops.color: C2 +#boxplot.meanprops.marker: ^ +#boxplot.meanprops.markerfacecolor: C2 +#boxplot.meanprops.markeredgecolor: C2 +#boxplot.meanprops.markersize: 6 +#boxplot.meanprops.linestyle: -- +#boxplot.meanprops.linewidth: 1.0 + + +## 
*************************************************************************** +## * FONT * +## *************************************************************************** +## The font properties used by `text.Text`. +## See https://matplotlib.org/api/font_manager_api.html for more information +## on font properties. The 6 font properties used for font matching are +## given below with their default values. +## +## The font.family property can take either a concrete font name (not supported +## when rendering text with usetex), or one of the following five generic +## values: +## - 'serif' (e.g., Times), +## - 'sans-serif' (e.g., Helvetica), +## - 'cursive' (e.g., Zapf-Chancery), +## - 'fantasy' (e.g., Western), and +## - 'monospace' (e.g., Courier). +## Each of these values has a corresponding default list of font names +## (font.serif, etc.); the first available font in the list is used. Note that +## for font.serif, font.sans-serif, and font.monospace, the first element of +## the list (a DejaVu font) will always be used because DejaVu is shipped with +## Matplotlib and is thus guaranteed to be available; the other entries are +## left as examples of other possible values. +## +## The font.style property has three values: normal (or roman), italic +## or oblique. The oblique style will be used for italic, if it is not +## present. +## +## The font.variant property has two values: normal or small-caps. For +## TrueType fonts, which are scalable fonts, small-caps is equivalent +## to using a font size of 'smaller', or about 83%% of the current font +## size. +## +## The font.weight property has effectively 13 values: normal, bold, +## bolder, lighter, 100, 200, 300, ..., 900. Normal is the same as +## 400, and bold is 700. bolder and lighter are relative values with +## respect to the current weight. +## +## The font.stretch property has 11 values: ultra-condensed, +## extra-condensed, condensed, semi-condensed, normal, semi-expanded, +## expanded, extra-expanded, ultra-expanded, wider, and narrower. This +## property is not currently implemented. +## +## The font.size property is the default font size for text, given in points. +## 10 pt is the standard value. +## +## Note that font.size controls default text sizes. To configure +## special text sizes tick labels, axes, labels, title, etc., see the rc +## settings for axes and ticks. 
Special text sizes can be defined +## relative to font.size, using the following values: xx-small, x-small, +## small, medium, large, x-large, xx-large, larger, or smaller + +#font.family: sans-serif +#font.style: normal +#font.variant: normal +#font.weight: normal +#font.stretch: normal +#font.size: 10.0 + +#font.serif: DejaVu Serif, Bitstream Vera Serif, Computer Modern Roman, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif +#font.sans-serif: DejaVu Sans, Bitstream Vera Sans, Computer Modern Sans Serif, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif +#font.cursive: Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, Comic Neue, Comic Sans MS, cursive +#font.fantasy: Chicago, Charcoal, Impact, Western, Humor Sans, xkcd, fantasy +#font.monospace: DejaVu Sans Mono, Bitstream Vera Sans Mono, Computer Modern Typewriter, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace + + +## *************************************************************************** +## * TEXT * +## *************************************************************************** +## The text properties used by `text.Text`. +## See https://matplotlib.org/api/artist_api.html#module-matplotlib.text +## for more information on text properties +#text.color: black + +## FreeType hinting flag ("foo" corresponds to FT_LOAD_FOO); may be one of the +## following (Proprietary Matplotlib-specific synonyms are given in parentheses, +## but their use is discouraged): +## - default: Use the font's native hinter if possible, else FreeType's auto-hinter. +## ("either" is a synonym). +## - no_autohint: Use the font's native hinter if possible, else don't hint. +## ("native" is a synonym.) +## - force_autohint: Use FreeType's auto-hinter. ("auto" is a synonym.) +## - no_hinting: Disable hinting. ("none" is a synonym.) +#text.hinting: force_autohint + +#text.hinting_factor: 8 # Specifies the amount of softness for hinting in the + # horizontal direction. A value of 1 will hint to full + # pixels. A value of 2 will hint to half pixels etc. +#text.kerning_factor: 0 # Specifies the scaling factor for kerning values. This + # is provided solely to allow old test images to remain + # unchanged. Set to 6 to obtain previous behavior. + # Values other than 0 or 6 have no defined meaning. +#text.antialiased: True # If True (default), the text will be antialiased. + # This only affects raster outputs. + + +## *************************************************************************** +## * LaTeX * +## *************************************************************************** +## For more information on LaTeX properties, see +## https://matplotlib.org/tutorials/text/usetex.html +#text.usetex: False # use latex for all text handling. The following fonts + # are supported through the usual rc parameter settings: + # new century schoolbook, bookman, times, palatino, + # zapf chancery, charter, serif, sans-serif, helvetica, + # avant garde, courier, monospace, computer modern roman, + # computer modern sans serif, computer modern typewriter +#text.latex.preamble: # IMPROPER USE OF THIS FEATURE WILL LEAD TO LATEX FAILURES + # AND IS THEREFORE UNSUPPORTED. PLEASE DO NOT ASK FOR HELP + # IF THIS FEATURE DOES NOT DO WHAT YOU EXPECT IT TO. + # text.latex.preamble is a single line of LaTeX code that + # will be passed on to the LaTeX system. 
It may contain + # any code that is valid for the LaTeX "preamble", i.e. + # between the "\documentclass" and "\begin{document}" + # statements. + # Note that it has to be put on a single line, which may + # become quite long. + # The following packages are always loaded with usetex, + # so beware of package collisions: + # geometry, inputenc, type1cm. + # PostScript (PSNFSS) font packages may also be + # loaded, depending on your font settings. + +## The following settings allow you to select the fonts in math mode. +#mathtext.fontset: dejavusans # Should be 'dejavusans' (default), + # 'dejavuserif', 'cm' (Computer Modern), 'stix', + # 'stixsans' or 'custom' (unsupported, may go + # away in the future) +## "mathtext.fontset: custom" is defined by the mathtext.bf, .cal, .it, ... +## settings which map a TeX font name to a fontconfig font pattern. (These +## settings are not used for other font sets.) +#mathtext.bf: sans:bold +#mathtext.cal: cursive +#mathtext.it: sans:italic +#mathtext.rm: sans +#mathtext.sf: sans +#mathtext.tt: monospace +#mathtext.fallback: cm # Select fallback font from ['cm' (Computer Modern), 'stix' + # 'stixsans'] when a symbol can not be found in one of the + # custom math fonts. Select 'None' to not perform fallback + # and replace the missing character by a dummy symbol. +#mathtext.default: it # The default font to use for math. + # Can be any of the LaTeX font names, including + # the special name "regular" for the same font + # used in regular text. + + +## *************************************************************************** +## * AXES * +## *************************************************************************** +## Following are default face and edge colors, default tick sizes, +## default font sizes for tick labels, and so on. See +## https://matplotlib.org/api/axes_api.html#module-matplotlib.axes +#axes.facecolor: white # axes background color +#axes.edgecolor: black # axes edge color +#axes.linewidth: 0.8 # edge line width +#axes.grid: False # display grid or not +#axes.grid.axis: both # which axis the grid should apply to +#axes.grid.which: major # grid lines at {major, minor, both} ticks +#axes.titlelocation: center # alignment of the title: {left, right, center} +#axes.titlesize: large # font size of the axes title +#axes.titleweight: normal # font weight of title +#axes.titlecolor: auto # color of the axes title, auto falls back to + # text.color as default value +#axes.titley: None # position title (axes relative units). None implies auto +#axes.titlepad: 6.0 # pad between axes and title in points +#axes.labelsize: medium # font size of the x and y labels +#axes.labelpad: 4.0 # space between label and axis +#axes.labelweight: normal # weight of the x and y labels +#axes.labelcolor: black +#axes.axisbelow: line # draw axis gridlines and ticks: + # - below patches (True) + # - above patches but below lines ('line') + # - above all (False) + +#axes.formatter.limits: -5, 6 # use scientific notation if log10 + # of the axis range is smaller than the + # first or larger than the second +#axes.formatter.use_locale: False # When True, format tick labels + # according to the user's locale. + # For example, use ',' as a decimal + # separator in the fr_FR locale. +#axes.formatter.use_mathtext: False # When True, use mathtext for scientific + # notation. 
+#axes.formatter.min_exponent: 0 # minimum exponent to format in scientific notation +#axes.formatter.useoffset: True # If True, the tick label formatter + # will default to labeling ticks relative + # to an offset when the data range is + # small compared to the minimum absolute + # value of the data. +#axes.formatter.offset_threshold: 4 # When useoffset is True, the offset + # will be used when it can remove + # at least this number of significant + # digits from tick labels. + +#axes.spines.left: True # display axis spines +#axes.spines.bottom: True +#axes.spines.top: True +#axes.spines.right: True + +#axes.unicode_minus: True # use Unicode for the minus symbol rather than hyphen. See + # https://en.wikipedia.org/wiki/Plus_and_minus_signs#Character_codes +# adapted from https://davidmathlogic.com/colorblind +# put yellow last, remove black +axes.prop_cycle: cycler('color', [ '56b4e9', 'e69f00', '009e73', '0072b2', 'd55e00', 'cc79a7', 'f0e442']) +#axes.prop_cycle: cycler('color', ['1f77b4', 'ff7f0e', '2ca02c', 'd62728', '9467bd', '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf']) + # color cycle for plot lines as list of string color specs: + # single letter, long name, or web-style hex + # As opposed to all other parameters in this file, the color + # values must be enclosed in quotes for this parameter, + # e.g. '1f77b4', instead of 1f77b4. + # See also https://matplotlib.org/tutorials/intermediate/color_cycle.html + # for more details on prop_cycle usage. +#axes.xmargin: .05 # x margin. See `axes.Axes.margins` +#axes.ymargin: .05 # y margin. See `axes.Axes.margins` +#axes.zmargin: .05 # z margin. See `axes.Axes.margins` +#axes.autolimit_mode: data # If "data", use axes.xmargin and axes.ymargin as is. + # If "round_numbers", after application of margins, axis + # limits are further expanded to the nearest "round" number. +#polaraxes.grid: True # display grid on polar axes +#axes3d.grid: True # display grid on 3D axes + + +## *************************************************************************** +## * AXIS * +## *************************************************************************** +#xaxis.labellocation: center # alignment of the xaxis label: {left, right, center} +#yaxis.labellocation: center # alignment of the yaxis label: {bottom, top, center} + + +## *************************************************************************** +## * DATES * +## *************************************************************************** +## These control the default format strings used in AutoDateFormatter. +## Any valid format datetime format string can be used (see the python +## `datetime` for details). 
For example, by using: +## - '%%x' will use the locale date representation +## - '%%X' will use the locale time representation +## - '%%c' will use the full locale datetime representation +## These values map to the scales: +## {'year': 365, 'month': 30, 'day': 1, 'hour': 1/24, 'minute': 1 / (24 * 60)} + +#date.autoformatter.year: %Y +#date.autoformatter.month: %Y-%m +#date.autoformatter.day: %Y-%m-%d +#date.autoformatter.hour: %m-%d %H +#date.autoformatter.minute: %d %H:%M +#date.autoformatter.second: %H:%M:%S +#date.autoformatter.microsecond: %M:%S.%f +## The reference date for Matplotlib's internal date representation +## See https://matplotlib.org/examples/ticks_and_spines/date_precision_and_epochs.py +#date.epoch: 1970-01-01T00:00:00 +## 'auto', 'concise': +#date.converter: auto +## For auto converter whether to use interval_multiples: +#date.interval_multiples: True + +## *************************************************************************** +## * TICKS * +## *************************************************************************** +## See https://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick +#xtick.top: False # draw ticks on the top side +#xtick.bottom: True # draw ticks on the bottom side +#xtick.labeltop: False # draw label on the top +#xtick.labelbottom: True # draw label on the bottom +#xtick.major.size: 3.5 # major tick size in points +#xtick.minor.size: 2 # minor tick size in points +#xtick.major.width: 0.8 # major tick width in points +#xtick.minor.width: 0.6 # minor tick width in points +#xtick.major.pad: 3.5 # distance to major tick label in points +#xtick.minor.pad: 3.4 # distance to the minor tick label in points +#xtick.color: black # color of the ticks +#xtick.labelcolor: inherit # color of the tick labels or inherit from xtick.color +#xtick.labelsize: medium # font size of the tick labels +#xtick.direction: out # direction: {in, out, inout} +#xtick.minor.visible: False # visibility of minor ticks on x-axis +#xtick.major.top: True # draw x axis top major ticks +#xtick.major.bottom: True # draw x axis bottom major ticks +#xtick.minor.top: True # draw x axis top minor ticks +#xtick.minor.bottom: True # draw x axis bottom minor ticks +#xtick.alignment: center # alignment of xticks + +#ytick.left: True # draw ticks on the left side +#ytick.right: False # draw ticks on the right side +#ytick.labelleft: True # draw tick labels on the left side +#ytick.labelright: False # draw tick labels on the right side +#ytick.major.size: 3.5 # major tick size in points +#ytick.minor.size: 2 # minor tick size in points +#ytick.major.width: 0.8 # major tick width in points +#ytick.minor.width: 0.6 # minor tick width in points +#ytick.major.pad: 3.5 # distance to major tick label in points +#ytick.minor.pad: 3.4 # distance to the minor tick label in points +#ytick.color: black # color of the ticks +#ytick.labelcolor: inherit # color of the tick labels or inherit from ytick.color +#ytick.labelsize: medium # font size of the tick labels +#ytick.direction: out # direction: {in, out, inout} +#ytick.minor.visible: False # visibility of minor ticks on y-axis +#ytick.major.left: True # draw y axis left major ticks +#ytick.major.right: True # draw y axis right major ticks +#ytick.minor.left: True # draw y axis left minor ticks +#ytick.minor.right: True # draw y axis right minor ticks +#ytick.alignment: center_baseline # alignment of yticks + + +## *************************************************************************** +## * GRIDS * +## 
*************************************************************************** +#grid.color: b0b0b0 # grid color +#grid.linestyle: - # solid +#grid.linewidth: 0.8 # in points +#grid.alpha: 1.0 # transparency, between 0.0 and 1.0 + + +## *************************************************************************** +## * LEGEND * +## *************************************************************************** +#legend.loc: best +#legend.frameon: True # if True, draw the legend on a background patch +#legend.framealpha: 0.8 # legend patch transparency +#legend.facecolor: inherit # inherit from axes.facecolor; or color spec +#legend.edgecolor: 0.8 # background patch boundary color +#legend.fancybox: True # if True, use a rounded box for the + # legend background, else a rectangle +#legend.shadow: False # if True, give background a shadow effect +#legend.numpoints: 1 # the number of marker points in the legend line +#legend.scatterpoints: 1 # number of scatter points +#legend.markerscale: 1.0 # the relative size of legend markers vs. original +#legend.fontsize: medium +#legend.labelcolor: None +#legend.title_fontsize: None # None sets to the same as the default axes. + +## Dimensions as fraction of font size: +#legend.borderpad: 0.4 # border whitespace +#legend.labelspacing: 0.5 # the vertical space between the legend entries +#legend.handlelength: 2.0 # the length of the legend lines +#legend.handleheight: 0.7 # the height of the legend handle +#legend.handletextpad: 0.8 # the space between the legend line and legend text +#legend.borderaxespad: 0.5 # the border between the axes and legend edge +#legend.columnspacing: 2.0 # column separation + + +## *************************************************************************** +## * FIGURE * +## *************************************************************************** +## See https://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure +#figure.titlesize: large # size of the figure title (``Figure.suptitle()``) +#figure.titleweight: normal # weight of the figure title +#figure.figsize: 6.4, 4.8 # figure size in inches +#figure.dpi: 100 # figure dots per inch +#figure.facecolor: white # figure face color +#figure.edgecolor: white # figure edge color +#figure.frameon: True # enable figure frame +#figure.max_open_warning: 20 # The maximum number of figures to open through + # the pyplot interface before emitting a warning. + # If less than one this feature is disabled. +#figure.raise_window : True # Raise the GUI window to front when show() is called. + +## The figure subplot parameters. All dimensions are a fraction of the figure width and height. +#figure.subplot.left: 0.125 # the left side of the subplots of the figure +#figure.subplot.right: 0.9 # the right side of the subplots of the figure +#figure.subplot.bottom: 0.11 # the bottom of the subplots of the figure +#figure.subplot.top: 0.88 # the top of the subplots of the figure +#figure.subplot.wspace: 0.2 # the amount of width reserved for space between subplots, + # expressed as a fraction of the average axis width +#figure.subplot.hspace: 0.2 # the amount of height reserved for space between subplots, + # expressed as a fraction of the average axis height + +## Figure layout +#figure.autolayout: False # When True, automatically adjust subplot + # parameters to make the plot fit the figure + # using `tight_layout` +#figure.constrained_layout.use: False # When True, automatically make plot + # elements fit on the figure. (Not + # compatible with `autolayout`, above). 
+#figure.constrained_layout.h_pad: 0.04167 # Padding around axes objects. Float representing +#figure.constrained_layout.w_pad: 0.04167 # inches. Default is 3/72 inches (3 points) +#figure.constrained_layout.hspace: 0.02 # Space between subplot groups. Float representing +#figure.constrained_layout.wspace: 0.02 # a fraction of the subplot widths being separated. + + +## *************************************************************************** +## * IMAGES * +## *************************************************************************** +#image.aspect: equal # {equal, auto} or a number +#image.interpolation: antialiased # see help(imshow) for options +#image.cmap: viridis # A colormap name, gray etc... +#image.lut: 256 # the size of the colormap lookup table +#image.origin: upper # {lower, upper} +#image.resample: True +#image.composite_image: True # When True, all the images on a set of axes are + # combined into a single composite image before + # saving a figure as a vector graphics file, + # such as a PDF. + + +## *************************************************************************** +## * CONTOUR PLOTS * +## *************************************************************************** +#contour.negative_linestyle: dashed # string or on-off ink sequence +#contour.corner_mask: True # {True, False, legacy} +#contour.linewidth: None # {float, None} Size of the contour line + # widths. If set to None, it falls back to + # `line.linewidth`. + + +## *************************************************************************** +## * ERRORBAR PLOTS * +## *************************************************************************** +#errorbar.capsize: 0 # length of end cap on error bars in pixels + + +## *************************************************************************** +## * HISTOGRAM PLOTS * +## *************************************************************************** +#hist.bins: 10 # The default number of histogram bins or 'auto'. + + +## *************************************************************************** +## * SCATTER PLOTS * +## *************************************************************************** +#scatter.marker: o # The default marker type for scatter plots. +#scatter.edgecolors: face # The default edge colors for scatter plots. + + +## *************************************************************************** +## * AGG RENDERING * +## *************************************************************************** +## Warning: experimental, 2008/10/10 +#agg.path.chunksize: 0 # 0 to disable; values in the range + # 10000 to 100000 can improve speed slightly + # and prevent an Agg rendering failure + # when plotting very large data sets, + # especially if they are very gappy. + # It may cause minor artifacts, though. + # A value of 20000 is probably a good + # starting point. + + +## *************************************************************************** +## * PATHS * +## *************************************************************************** +#path.simplify: True # When True, simplify paths by removing "invisible" + # points to reduce file size and increase rendering + # speed +#path.simplify_threshold: 0.111111111111 # The threshold of similarity below + # which vertices will be removed in + # the simplification process. +#path.snap: True # When True, rectilinear axis-aligned paths will be snapped + # to the nearest pixel when certain criteria are met. + # When False, paths will never be snapped. 
+#path.sketch: None # May be None, or a 3-tuple of the form: + # (scale, length, randomness). + # - *scale* is the amplitude of the wiggle + # perpendicular to the line (in pixels). + # - *length* is the length of the wiggle along the + # line (in pixels). + # - *randomness* is the factor by which the length is + # randomly scaled. +#path.effects: + + +## *************************************************************************** +## * SAVING FIGURES * +## *************************************************************************** +## The default savefig parameters can be different from the display parameters +## e.g., you may want a higher resolution, or to make the figure +## background white +#savefig.dpi: figure # figure dots per inch or 'figure' +#savefig.facecolor: auto # figure face color when saving +#savefig.edgecolor: auto # figure edge color when saving +#savefig.format: png # {png, ps, pdf, svg} +#savefig.bbox: standard # {tight, standard} + # 'tight' is incompatible with pipe-based animation + # backends (e.g. 'ffmpeg') but will work with those + # based on temporary files (e.g. 'ffmpeg_file') +#savefig.pad_inches: 0.1 # Padding to be used when bbox is set to 'tight' +#savefig.directory: ~ # default directory in savefig dialog box, + # leave empty to always use current working directory +#savefig.transparent: False # setting that controls whether figures are saved with a + # transparent background by default +#savefig.orientation: portrait # Orientation of saved figure + +### tk backend params +#tk.window_focus: False # Maintain shell focus for TkAgg + +### ps backend params +#ps.papersize: letter # {auto, letter, legal, ledger, A0-A10, B0-B10} +#ps.useafm: False # use of AFM fonts, results in small files +#ps.usedistiller: False # {ghostscript, xpdf, None} + # Experimental: may produce smaller files. + # xpdf intended for production of publication quality files, + # but requires ghostscript, xpdf and ps2eps +#ps.distiller.res: 6000 # dpi +#ps.fonttype: 3 # Output Type 3 (Type3) or Type 42 (TrueType) + +### PDF backend params +#pdf.compression: 6 # integer from 0 to 9 + # 0 disables compression (good for debugging) +#pdf.fonttype: 3 # Output Type 3 (Type3) or Type 42 (TrueType) +#pdf.use14corefonts: False +#pdf.inheritcolor: False + +### SVG backend params +#svg.image_inline: True # Write raster image data directly into the SVG file +#svg.fonttype: path # How to handle SVG fonts: + # path: Embed characters as paths -- supported + # by most SVG renderers + # None: Assume fonts are installed on the + # machine where the SVG will be viewed. +#svg.hashsalt: None # If not None, use this string as hash salt instead of uuid4 + +### pgf parameter +## See https://matplotlib.org/tutorials/text/pgf.html for more information. +#pgf.rcfonts: True +#pgf.preamble: # See text.latex.preamble for documentation +#pgf.texsystem: xelatex + +### docstring params +#docstring.hardcopy: False # set this when you want to generate hardcopy docstring + + +## *************************************************************************** +## * INTERACTIVE KEYMAPS * +## *************************************************************************** +## Event keys to interact with figures/plots via keyboard. +## See https://matplotlib.org/users/navigation_toolbar.html for more details on +## interactive navigation. Customize these settings according to your needs. +## Leave the field(s) empty if you don't need a key-map. 
(i.e., fullscreen : '') +#keymap.fullscreen: f, ctrl+f # toggling +#keymap.home: h, r, home # home or reset mnemonic +#keymap.back: left, c, backspace, MouseButton.BACK # forward / backward keys +#keymap.forward: right, v, MouseButton.FORWARD # for quick navigation +#keymap.pan: p # pan mnemonic +#keymap.zoom: o # zoom mnemonic +#keymap.save: s, ctrl+s # saving current figure +#keymap.help: f1 # display help about active tools +#keymap.quit: ctrl+w, cmd+w, q # close the current figure +#keymap.quit_all: # close all figures +#keymap.grid: g # switching on/off major grids in current axes +#keymap.grid_minor: G # switching on/off minor grids in current axes +#keymap.yscale: l # toggle scaling of y-axes ('log'/'linear') +#keymap.xscale: k, L # toggle scaling of x-axes ('log'/'linear') +#keymap.copy: ctrl+c, cmd+c # copy figure to clipboard + + +## *************************************************************************** +## * ANIMATION * +## *************************************************************************** +#animation.html: none # How to display the animation as HTML in + # the IPython notebook: + # - 'html5' uses HTML5 video tag + # - 'jshtml' creates a JavaScript animation +#animation.writer: ffmpeg # MovieWriter 'backend' to use +#animation.codec: h264 # Codec to use for writing movie +#animation.bitrate: -1 # Controls size/quality trade-off for movie. + # -1 implies let utility auto-determine +#animation.frame_format: png # Controls frame format used by temp files +#animation.ffmpeg_path: ffmpeg # Path to ffmpeg binary. Without full path + # $PATH is searched +#animation.ffmpeg_args: # Additional arguments to pass to ffmpeg +#animation.convert_path: convert # Path to ImageMagick's convert binary. + # On Windows use the full path since convert + # is also the name of a system tool. +#animation.convert_args: # Additional arguments to pass to convert +#animation.embed_limit: 20.0 # Limit, in MB, of size of base64 encoded + # animation in HTML (i.e. IPython notebook) \ No newline at end of file diff --git a/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py b/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py new file mode 100644 index 0000000000..76f36dca29 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/inference/plot_multiple_samples.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import argparse
+import os
+
+import joblib
+import matplotlib.pyplot as plt
+import xarray
+
+
+def plot_samples(netcdf_file, output_dir, n_samples):
+    """Plot multiple samples"""
+    root = xarray.open_dataset(netcdf_file)
+    ds = (
+        xarray.open_dataset(netcdf_file, group="prediction")
+        .merge(root)
+        .set_coords(["lat", "lon"])
+    )
+    truth = (
+        xarray.open_dataset(netcdf_file, group="truth")
+        .merge(root)
+        .set_coords(["lat", "lon"])
+    )
+    os.makedirs(output_dir, exist_ok=True)
+
+    # concatenate truth data and ensemble mean as an "ensemble" member for easy
+    # plotting
+    truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
+    ens_mean = (
+        ds.mean("ensemble")
+        .assign_coords(ensemble="ensemble_mean")
+        .expand_dims("ensemble")
+    )
+    # add [0, 1, 2, ...] to ensemble dim
+    ds["ensemble"] = [str(i) for i in range(ds.sizes["ensemble"])]
+    merged = xarray.concat([truth_expanded, ens_mean, ds], dim="ensemble")
+
+    # plot the variables in parallel
+    def plot(v):
+        print(v)
+        # + 2 accounts for the truth and ensemble_mean entries prepended above
+        merged[v][: n_samples + 2, :].plot(row="time", col="ensemble")
+        plt.savefig(f"{output_dir}/{v}.png")
+
+    joblib.Parallel(n_jobs=8)(joblib.delayed(plot)(v) for v in merged)
+
+
+if __name__ == "__main__":
+    # Create the parser
+    parser = argparse.ArgumentParser()
+
+    # Add the positional arguments
+    parser.add_argument("--netcdf_file", help="Path to the NetCDF file")
+    parser.add_argument("--output_dir", help="Path to the output directory")
+    # Add the optional argument
+    parser.add_argument("--n-samples", help="Number of samples", default=5, type=int)
+    # Parse the arguments
+    args = parser.parse_args()
+    plot_samples(args.netcdf_file, args.output_dir, args.n_samples)
diff --git a/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py b/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py
new file mode 100644
index 0000000000..9cf7dbccfd
--- /dev/null
+++ b/examples/generative/corrdiff_plus_plus/inference/plot_single_sample.py
@@ -0,0 +1,201 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
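+
+# For every time step in a CorrDiff++ output NetCDF file, plot the coarse
+# input, one generated sample, and the truth side by side for each output
+# channel, annotated with pattern correlations against the truth.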
+ +import os +import argparse +import cftime +import matplotlib.pyplot as plt +import netCDF4 as nc +import numpy as np + + +def pattern_correlation(x, y): + """Pattern correlation""" + mx = np.mean(x) + my = np.mean(y) + vx = np.mean((x - mx) ** 2) + vy = np.mean((y - my) ** 2) + + a = np.mean((x - mx) * (y - my)) + b = np.sqrt(vx * vy) + + return a / b + + +def plot_channels(group, time_idx: int): + """Plot channels""" + # weather sub-plot + num_channels = len(group.variables) + ncols = 4 + fig, axs = plt.subplots( + nrows=( + num_channels // ncols + if num_channels % ncols == 0 + else num_channels // ncols + 1 + ), + ncols=ncols, + sharex=True, + sharey=True, + constrained_layout=True, + figsize=(15, 15), + ) + + for ch, ax in zip(sorted(group.variables), axs.flat): + # label row + x = group[ch][time_idx] + ax.set_title(ch) + ax.imshow(x) + + +def channel_eq(a, b): + """Check if two channels are equal in variable and pressure.""" + variable_equal = a["variable"] == b["variable"] + pressure_is_nan = np.isnan(a["pressure"]) and np.isnan(b["pressure"]) + pressure_equal = a["pressure"] == b["pressure"] + return variable_equal and (pressure_equal or pressure_is_nan) + + +def channel_repr(channel): + """Return a string representation of a channel with variable and pressure.""" + v = channel["variable"] + pressure = channel["pressure"] + return f"{v}\n Pressure: {pressure}" + + +def get_clim(output_channels, f): + """Get color limits (clim) for output channels based on prediction and truth data.""" + colorlimits = {} + for ch in range(len(output_channels)): + channel = output_channels[ch] + y = f["prediction"][channel][:] + truth = f["truth"][channel][:] + + vmin = min([y.min(), truth.min()]) + vmax = max([y.max(), truth.max()]) + colorlimits[channel] = (vmin, vmax) + return colorlimits + + +def main(file, output_dir, sample): + """Plot single sample""" + os.makedirs(output_dir, exist_ok=True) + f = nc.Dataset(file, "r") + + # for c in f.time: + output_channels = list(f["prediction"].variables) + v = f["time"] + times = cftime.num2date(v, units=v.units, calendar=v.calendar) + + def plot_panel(ax, data, **kwargs): + return ax.pcolormesh(f["lon"], f["lat"], data, cmap="RdBu_r", **kwargs) + + colorlimits = get_clim(output_channels, f) + for idx in range(len(times)): + print("idx", idx) + # weather sub-plot + fig, axs = plt.subplots( + nrows=len(output_channels), + ncols=3, + sharex=True, + sharey=True, + constrained_layout=True, + figsize=(12, 12), + ) + row = axs[0] + row[0].set_title("Input") + row[1].set_title("Generated") + row[2].set_title("Truth") + + for ch in range(len(output_channels)): + channel = output_channels[ch] + row = axs[ch] + + # label row + + y = f["prediction"][channel][sample, idx] + truth = f["truth"][channel][idx] + + # search for input_channel + input_channels = list(f["input"].variables) + if channel in input_channels: + x = f["input"][channel][idx] + else: + x = None + + vmin, vmax = colorlimits[channel] + + def plot_panel(ax, data, **kwargs): + if channel == "maximum_radar_reflectivity": + return ax.pcolormesh( + f["lon"], f["lat"], data, cmap="magma", vmin=0, vmax=vmax + ) + if channel == "temperature_2m": + return ax.pcolormesh( + f["lon"], f["lat"], data, cmap="magma", vmin=vmin, vmax=vmax + ) + else: + if vmin < 0 < vmax: + bound = max(abs(vmin), abs(vmax)) + vmin1 = -bound + vmax1 = bound + else: + vmin1 = vmin + vmax1 = vmax + return ax.pcolormesh( + f["lon"], f["lat"], data, cmap="RdBu_r", vmin=vmin1, vmax=vmax1 + ) + + if x is not None: + plot_panel(row[0], x) 
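+                # Pattern correlation of the coarse input against the truth
+                # serves as a baseline for the generated sample shown below.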
+ pc_x = pattern_correlation(x, truth) + label_x = pc_x + row[0].set_title(f"Input, Pattern correlation: {label_x:.2f}") + + im = plot_panel(row[1], y) + plot_panel(row[2], truth) + + cb = plt.colorbar(im, ax=row.tolist()) + cb.set_label(channel) + + pc_y = pattern_correlation(y, truth) + label_y = pc_y + row[1].set_title(f"Generated, Pattern correlation: {label_y:.2f}") + + for ax in axs[-1]: + ax.set_xlabel("longitude [deg E]") + + for ax in axs[:, 0]: + ax.set_ylabel("latitude [deg N]") + + time = times[idx] + plt.suptitle(f"Time {time.isoformat()}") + plt.savefig(f"{output_dir}/{time.isoformat()}.sample.png") + + plot_channels(f["input"], idx) + plt.savefig(f"{output_dir}/{time.isoformat()}.input.png") + + +if __name__ == "__main__": + # Create the parser + parser = argparse.ArgumentParser() + # Add the positional arguments + parser.add_argument("--netcdf_file", help="Path to the NetCDF file") + parser.add_argument("--output_dir", help="Path to the output directory") + # Add the optional argument + parser.add_argument("--sample", help="Sample to plot", default=0, type=int) + # Parse the arguments + args = parser.parse_args() + main(args.netcdf_file, args.output_dir, args.sample) diff --git a/examples/generative/corrdiff_plus_plus/inference/power_spectra.py b/examples/generative/corrdiff_plus_plus/inference/power_spectra.py new file mode 100644 index 0000000000..6d6bfc84c4 --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/inference/power_spectra.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import matplotlib.pyplot as plt +import numpy as np +import typer +import xarray +from scipy.fft import irfft +from scipy.signal import periodogram + + +def open_data(file, group=False): + """ + Opens a dataset from a NetCDF file. + + Parameters: + file (str): Path to the NetCDF file. + group (bool, optional): Whether to open the file as a group. Default is False. + + Returns: + xarray.Dataset: An xarray dataset containing the data from the NetCDF file. + """ + root = xarray.open_dataset(file) + root = root.set_coords(["lat", "lon"]) + ds = xarray.open_dataset(file, group=group) + ds.coords.update(root.coords) + ds.attrs.update(root.attrs) + + return ds + + +def haversine(lat1, lon1, lat2, lon2): + """ + Calculate the Haversine distance between two sets of latitude and longitude coordinates. + + The Haversine formula calculates the shortest distance between two points on the + surface of a sphere (in this case, the Earth) given their latitude and longitude + coordinates. + + Parameters: + lat1 (float): Latitude of the first point in degrees. + lon1 (float): Longitude of the first point in degrees. + lat2 (float): Latitude of the second point in degrees. + lon2 (float): Longitude of the second point in degrees. + + Returns: + float: The Haversine distance between the two points in meters. 
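+
+    Example: haversine(0.0, 0.0, 0.0, 1.0) returns roughly 111,195 m, i.e.
+    one degree of longitude along the equator.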
+ """ + # Convert latitude and longitude from degrees to radians + lat1_rad = np.radians(lat1) + lon1_rad = np.radians(lon1) + lat2_rad = np.radians(lat2) + lon2_rad = np.radians(lon2) + + # Earth radius in meters + earth_radius = 6371000 # Approximate value for the average Earth radius + + # Calculate differences in latitude and longitude + dlat_rad = lat2_rad - lat1_rad + dlon_rad = lon2_rad - lon1_rad + + # Haversine formula + a = ( + np.sin(dlat_rad / 2) ** 2 + + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon_rad / 2) ** 2 + ) + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + distance_meters = earth_radius * c + + return distance_meters + + +def compute_power_spectrum(data, d): + """ + Compute the power spectrum of a 2D data array using the Fast Fourier Transform (FFT). + + The power spectrum represents the distribution of signal power as a function of frequency. + + Parameters: + data (numpy.ndarray): 2D input data array. + d (float): Sampling interval (time between data points). + + Returns: + tuple: A tuple containing the frequency values and the corresponding power spectrum. + - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum. + - power_spectrum (numpy.ndarray): Power spectrum of the input data. + """ + + # Compute the 2D FFT along the second dimension + fft_data = np.fft.fft(data, axis=-2) + + # Compute the power spectrum by taking the absolute value and squaring + power_spectrum = np.abs(fft_data) ** 2 + + # Scale the power spectrum based on the sampling interval 'd' + power_spectrum /= data.shape[-1] * d + freqs = np.fft.fftfreq(data.shape[-1], d) + + return freqs, power_spectrum + + +def power_spectra_to_acf(f, pw): + """ + Convert a one-sided power spectrum to an autocorrelation function. + + Args: + f (numpy.ndarray): Frequencies. + pw (numpy.ndarray): Power spectral density in units of V^2/Hz. + + Returns: + numpy.ndarray: Autocorrelation function (ACF). + """ + pw = pw.copy() + pw[0] = 0 + # magic factor 4 comes from periodogram/irfft stuff + # 1) a factor 2 comes from the periodogram being one-sided. + # I don't fully understasnd, but this ensures the acf is 1 at r=0 + sig2 = np.sum(pw * f[1]) * 4 + acf = irfft(pw) / sig2 + return acf + + +def average_power_spectrum(data, d): + """ + Compute the average power spectrum of a 2D data array. + + This function calculates the power spectrum for each row of the input data and + then averages them to obtain the overall power spectrum. + The power spectrum represents the distribution of signal power as a function of frequency. + + Parameters: + data (numpy.ndarray): 2D input data array. + d (float): Sampling interval (time between data points). + + Returns: + tuple: A tuple containing the frequency values and the average power spectrum. + - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum. + - power_spectra (numpy.ndarray): Average power spectrum of the input data. + """ + # Compute the power spectrum along the second dimension for each row + freqs, power_spectra = periodogram(data, fs=1 / d, axis=-1) + + # Average along the first dimension + while power_spectra.ndim > 1: + power_spectra = power_spectra.mean(axis=0) + + return freqs, power_spectra + + +def main(file, output): + """ + Generate and save multiple power spectrum plots from input data. + + Parameters: + file (str): Path to the input data file. + output (str): Directory where the generated plots will be saved. 
+ + This function loads and processes various datasets from the input file, + calculates their power spectra, and generates and saves multiple power spectrum plots. + The plots include kinetic energy, temperature, and reflectivity power spectra. + """ + + def savefig(name): + path = os.path.join(output, name + ".png") + plt.savefig(path) + + samples = {} + samples["prediction"] = open_data(file, group="prediction") + samples["prediction_mean"] = samples["prediction"].mean("ensemble") + samples["truth"] = open_data(file, group="truth") + samples["ERA5"] = open_data(file, group="input") + + prediction = samples["prediction"] + lat = prediction.lat + lon = prediction.lon + + dx = haversine(lat[0, 0], lon[0, 0], lat[1, 0], lon[1, 0]) + dy = haversine(lat[0, 0], lon[0, 0], lat[0, 1], lon[0, 1]) + print(dx, dy) + # the approximate resolution is dx=dy=2000m + + # in km + d = 2 + + # Plot the power spectrum + for name, data in samples.items(): + freqs, spec_x = average_power_spectrum(data.eastward_wind_10m, d=d) + _, spec_y = average_power_spectrum(data.northward_wind_10m, d=d) + spec = spec_x + spec_y + plt.loglog(freqs, spec, label=name) + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + plt.title("Kinetic Energy power spectra") + plt.grid() + plt.legend() + savefig("ke-spectra") + + plt.figure() + for name, data in samples.items(): + freqs, spec = average_power_spectrum(data.temperature_2m, d=d) + plt.loglog(freqs, spec, label=name) + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + plt.title("T2M Power spectra") + plt.grid() + plt.legend() + savefig("t2m-spectra") + + plt.figure() + for name, data in samples.items(): + try: + freqs, spec = average_power_spectrum(data.maximum_radar_reflectivity, d=d) + except AttributeError: + continue + plt.loglog(freqs, spec, label=name) + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + plt.title("Reflectivity Power spectra") + plt.grid() + plt.legend() + savefig("reflectivity-spectra") + + +if __name__ == "__main__": + typer.run(main) diff --git a/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py b/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py new file mode 100644 index 0000000000..64164d44aa --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/inference/read_netcdf.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
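+
+# Minimal inspection script: print every variable and every global attribute
+# of a NetCDF file (the score file below is used as an example path).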
+ + +import netCDF4 as nc + +# Open the NetCDF file +file_path = "image_outdir_0_score.nc" # Replace with the path to your NetCDF file +dataset = nc.Dataset(file_path, "r") # 'r' stands for read mode + +# Access variables and attributes +print("Variables:") +for var_name, var in dataset.variables.items(): + print(f"{var_name}: {var[:]}") # Access the data for each variable + +print("\nGlobal attributes:") +for attr_name in dataset.ncattrs(): + print(f"{attr_name}: {getattr(dataset, attr_name)}") # Access global attributes + +# Close the NetCDF file when done +dataset.close() diff --git a/examples/generative/corrdiff_plus_plus/train.py b/examples/generative/corrdiff_plus_plus/train.py new file mode 100644 index 0000000000..1386526d2a --- /dev/null +++ b/examples/generative/corrdiff_plus_plus/train.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os, time, psutil, hydra, torch, sys +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf +from torch.nn.parallel import DistributedDataParallel +from torch.utils.tensorboard import SummaryWriter +from physicsnemo import Module +from physicsnemo.models.diffusion import ( + SongUNet, + EDMPrecondSR, + SFMPrecondSR, + SFMPrecondEmpty, +) +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.metrics.diffusion import ( + RegressionLoss, + ResLoss, + SFMLoss, + # SFMLossSigmaPerChannel, + SFMEncoderLoss, +) +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint + +# Load utilities from corrdiff examples, make the corrdiff path absolute to avoid issues +# sys.path.append(sys.path.append(os.path.join(os.path.dirname(__file__), "../corrdiff")) ) +from datasets.dataset import init_train_valid_datasets_from_config +from helpers.train_helpers import ( + set_patch_shape, + set_seed, + configure_cuda_for_consistent_precision, + compute_num_accumulation_rounds, + handle_and_clip_gradients, + is_time_for_periodic_task, +) +from helpers.sfm_utils import get_encoder + + +# Train the CorrDiff model using the configurations in "conf/config_training.yaml" +@hydra.main(version_base="1.2", config_path="conf", config_name="config_training") +def main(cfg: DictConfig) -> None: + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize loggers + if dist.rank == 0: + writer = SummaryWriter(log_dir="tensorboard") + logger = PythonLogger("main") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger + + # Resolve and parse configs + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if hasattr(cfg, "validation"): + 
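+        # A `validation` section in the config enables a train/validation split
+        # and the periodic validation pass further below.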
train_test_split = True + validation_dataset_cfg = OmegaConf.to_container(cfg.validation) + else: + train_test_split = False + validation_dataset_cfg = None + + fp_optimizations = cfg.training.perf.fp_optimizations + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + + logger.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + if cfg.training.hp.batch_size_per_gpu == "auto": + cfg.training.hp.batch_size_per_gpu = ( + cfg.training.hp.total_batch_size // dist.world_size + ) + + # Set seeds and configure CUDA and cuDNN settings to ensure consistent precision + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": cfg.training.perf.dataloader_workers, + } + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=0, + validation_dataset_cfg=validation_dataset_cfg, + train_test_split=train_test_split, + ) + + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + patch_shape = (None, None) + + # Instantiate the model and move to device. + if cfg.model.name not in ( + "sfm_encoder", + "sfm", + "sfm_two_stage", + ): + raise ValueError("Invalid model") + model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + } + ## remaining defaults are what we want + standard_model_cfgs = { # default parameters for different network types + "sfm": { + "gridtype": "sinusoidal", + "N_grid_channels": 4, + }, + "sfm_two_stage": { + "gridtype": "sinusoidal", + "N_grid_channels": 4, + }, + "sfm_encoder": {}, # empty preconditioner + } + + model_args.update(standard_model_cfgs[cfg.model.name]) + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + if cfg.model.name == "sfm_encoder": + # should this be set to no_grad? 
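+        # The encoder-only configuration pairs the encoder with an empty
+        # placeholder preconditioner instead of a full denoiser network.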
+        denoiser_net = SFMPrecondEmpty()
+    else:  # sfm or sfm_two_stage
+        denoiser_net = SFMPrecondSR(
+            img_in_channels=img_in_channels + model_args["N_grid_channels"],
+            **model_args,
+        )
+
+    denoiser_net.train().requires_grad_(True).to(dist.device)
+    denoiser_ema = copy.deepcopy(denoiser_net).eval().requires_grad_(False)
+    ema_halflife_nimg = int(cfg.training.hp.ema * 1000000)
+    if hasattr(cfg.training.hp, "ema_rampup_ratio"):
+        ema_rampup_ratio = float(cfg.training.hp.ema_rampup_ratio)
+    else:
+        ema_rampup_ratio = 0.5
+
+    # Create or load the encoder:
+    if cfg.model.name in ["sfm", "sfm_encoder"]:
+        encoder_net = get_encoder(cfg)
+        encoder_net.train().requires_grad_(True).to(dist.device)
+        logger0.success("Constructed encoder network successfully")
+    else:  # "sfm_two_stage"
+        if not hasattr(cfg.training.io, "encoder_checkpoint_path"):
+            raise KeyError(
+                "Need to provide encoder_checkpoint_path when using sfm_two_stage"
+            )
+        encoder_checkpoint_path = to_absolute_path(
+            cfg.training.io.encoder_checkpoint_path
+        )
+        if not os.path.exists(encoder_checkpoint_path):
+            raise FileNotFoundError(
+                f"Encoder checkpoint not found: {encoder_checkpoint_path}"
+            )
+        encoder_net = Module.from_checkpoint(encoder_checkpoint_path)
+        encoder_net.eval().requires_grad_(False).to(dist.device)
+        logger0.success("Loaded the pre-trained encoder network")
+
+    # Instantiate the loss function(s)
+    if cfg.model.name in ("sfm", "sfm_two_stage"):
+        loss_fn = SFMLoss(
+            encoder_loss_type=cfg.model.encoder_loss_type,
+            encoder_loss_weight=cfg.model.encoder_loss_weight,
+            sigma_min=cfg.model.sigma_min,
+        )
+        # with sfm the encoder and diffusion model are trained together
+        if cfg.model.name == "sfm":
+            loss_fn_encoder = SFMEncoderLoss(encoder_loss_type="l2")
+    elif cfg.model.name == "sfm_encoder":
+        loss_fn = SFMEncoderLoss(
+            encoder_loss_type=cfg.model.encoder_loss_type,
+        )
+    else:
+        raise NotImplementedError(f"Model {cfg.model.name} not supported.")
+
+    # Instantiate the optimizer
+    if cfg.model.name == "sfm_two_stage":
+        params = denoiser_net.parameters()
+    else:
+        params = list(denoiser_net.parameters()) + list(encoder_net.parameters())
+
+    optimizer = torch.optim.Adam(
+        params=params, lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8
+    )
+
+    # Enable distributed data parallel if applicable
+    if dist.world_size > 1:
+        ddp_denoiser_net = DistributedDataParallel(
+            denoiser_net,
+            device_ids=[dist.local_rank],
+            broadcast_buffers=True,
+            output_device=dist.device,
+            find_unused_parameters=dist.find_unused_parameters,
+        )
+        if cfg.model.name != "sfm_two_stage":
+            encoder_net = DistributedDataParallel(
+                encoder_net,
+                device_ids=[dist.local_rank],
+                broadcast_buffers=True,
+                output_device=dist.device,
+                find_unused_parameters=dist.find_unused_parameters,
+            )
+    else:
+        # for convenience when updating the denoiser sigma
+        ddp_denoiser_net = denoiser_net
+
+    # Record the current time to measure the duration of subsequent operations.
+ start_time = time.time() + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + ## Resume training from previous checkpoints if exists + ### TODO needs to be redone, need to store model + encoder + optimizer + if dist.world_size > 1: + torch.distributed.barrier() + try: + cur_nimg = load_checkpoint( + path=checkpoint_dir, + models=[denoiser_net, encoder_net], + optimizer=optimizer, + device=dist.device, + ) + except: + cur_nimg = 0 + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") + done = False + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for _ in range(num_accumulation_rounds): + img_clean, img_lr, labels = next(dataset_iterator) + img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() + img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() + with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): + loss = loss_fn( + denoiser_net=ddp_denoiser_net, + encoder_net=encoder_net, + img_clean=img_clean, + img_lr=img_lr, + ) + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + loss.backward() + + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) + writer.add_scalar( + "training_loss_running_mean", average_loss_running_mean, cur_nimg + ) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. 
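+        # Learning-rate schedule: linear warm-up over the first `lr_rampup`
+        # samples, followed by a step decay of `lr_decay` every 5e6 samples.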
+ lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + + # clear any nans from the denoiser + for param in denoiser_net.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad + ) + handle_and_clip_gradients( + denoiser_net, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + handle_and_clip_gradients( + encoder_net, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + optimizer.step() + + # Update EMA. + if ema_rampup_ratio is not None: + ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) + ema_beta = 0.5 ** ( + cfg.training.hp.total_batch_size / max(ema_halflife_nimg, 1e-8) + ) + for p_ema, p_net in zip(denoiser_ema.parameters(), denoiser_net.parameters()): + p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + rmse_encoder_valid_accum_mean = 0 + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + img_clean_valid, img_lr_valid, labels_valid = next( + validation_dataset_iterator + ) + + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(torch.float32) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device).to(torch.float32).contiguous() + ) + loss_valid = loss_fn( + denoiser_net=ddp_denoiser_net, + encoder_net=encoder_net, + img_clean=img_clean_valid, + img_lr=img_lr_valid, + ) + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu).cpu().item() + ) + valid_loss_accum += ( + loss_valid / cfg.training.io.validation_steps + ) + + if cfg.model.name == "sfm": + rmse_encoder_valid = loss_fn_encoder( + denoiser_net=ddp_denoiser_net, + encoder_net=encoder_net, + img_clean=img_clean_valid, + img_lr=img_lr_valid, + ) + rmse_encoder_valid_accum_mean += ( + rmse_encoder_valid.mean((0, 2, 3)) + / cfg.training.io.validation_steps + ) + + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if dist.rank == 0: + if ( + cfg.model.name == "sfm" + and cfg.model.model_args["sigma_max"]["learnable"] + ): + denoiser_net.update_sigma_max(rmse_encoder_valid_accum_mean) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec 
{(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + logger0.info(" ".join(fields)) + torch.cuda.reset_peak_memory_stats() + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + if cfg.model.name in ["sfm", "sfm_two_stage"]: + save_list = [denoiser_net, encoder_net] + else: + save_list = [encoder_net] + save_checkpoint( + path=checkpoint_dir, + models=save_list, + optimizer=optimizer, + epoch=cur_nimg, + ) + + # Done. + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/launch/logging/wandb.py b/physicsnemo/launch/logging/wandb.py index 9cb4879c8e..24e5a38974 100644 --- a/physicsnemo/launch/logging/wandb.py +++ b/physicsnemo/launch/logging/wandb.py @@ -23,9 +23,8 @@ from typing import Literal import wandb -from wandb import AlertLevel - from physicsnemo.distributed import DistributedManager +from wandb import AlertLevel from .utils import create_ddp_group_tag diff --git a/physicsnemo/metrics/diffusion/__init__.py b/physicsnemo/metrics/diffusion/__init__.py index 8673ce8eb2..d96c11deb0 100644 --- a/physicsnemo/metrics/diffusion/__init__.py +++ b/physicsnemo/metrics/diffusion/__init__.py @@ -25,3 +25,4 @@ VELoss_dfsr, VPLoss, ) +from .sfm_loss import SFMEncoderLoss, SFMLoss diff --git a/physicsnemo/metrics/diffusion/sfm_loss.py b/physicsnemo/metrics/diffusion/sfm_loss.py new file mode 100644 index 0000000000..7eb83035f7 --- /dev/null +++ b/physicsnemo/metrics/diffusion/sfm_loss.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from physicsnemo.models.diffusion import SongUNetPosEmbd + + +class SFMLoss: + """ + Loss function corresponding to Stochastic Flow matching + + Parameters + ---------- + encoder_loss_type: str + Type of loss to use ["l1", "l2", None] + encoder_loss_weight: float + Regularizer loss weights, by defaults 0.1. + sigma_min: Union[List[float], float] + Minimum value of noise sigma, default 2e-3 + Protects against values near zero that result in loss explosion. 
+    sigma_data: float
+        EDM weighting, default 0.5
+    """
+
+    def __init__(
+        self,
+        encoder_loss_type: str = "l2",
+        encoder_loss_weight: float = 0.1,
+        sigma_min: Union[List[float], float] = 0.002,
+        sigma_data: float = 0.5,
+    ):
+        """
+        Loss function corresponding to Stochastic Flow Matching
+
+        Parameters
+        ----------
+        encoder_loss_type: str, optional
+            Type of loss to use ["l1", "l2", None], defaults to 'l2'
+        encoder_loss_weight: float, optional
+            Regularizer loss weight, by default 0.1.
+        sigma_min: Union[List[float], float], optional
+            Minimum value of noise sigma, default 2e-3
+            Protects against values near zero that result in loss explosion.
+        sigma_data: float, optional
+            EDM weighting, default 0.5
+        """
+        self.encoder_loss_type = encoder_loss_type
+        self.encoder_loss_weight = encoder_loss_weight
+        self.sigma_min = sigma_min
+        self.sigma_data = sigma_data
+
+        if encoder_loss_type not in ["l1", "l2", None]:
+            raise ValueError(
+                f"encoder_loss_type should be one of ['l1', 'l2', None] not {encoder_loss_type}"
+            )
+
+    def __call__(
+        self,
+        denoiser_net: torch.nn.Module,
+        encoder_net: torch.nn.Module,
+        img_clean: torch.Tensor,
+        img_lr: torch.Tensor,
+    ):
+        """
+        Calculate the loss corresponding to stochastic flow matching
+
+        Parameters
+        ----------
+        denoiser_net: torch.nn.Module
+            The denoiser network making the predictions
+        encoder_net: torch.nn.Module
+            The encoder network making the predictions
+        img_clean: torch.Tensor
+            Input images (high resolution) to the neural network.
+        img_lr: torch.Tensor
+            Input images (low resolution) to the neural network.
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor representing the combined loss calculated based on the flow matching
+            encoder and denoiser networks
+        """
+        # uniformly samples from 0 to 1 in torch
+        if isinstance(denoiser_net, torch.nn.parallel.DistributedDataParallel):
+            sigma_max_per_channel = denoiser_net.module.get_sigma_max().to(
+                device=img_clean.device
+            )
+        else:
+            sigma_max_per_channel = denoiser_net.get_sigma_max().to(
+                device=img_clean.device
+            )
+
+        # clamp to min value
+        if not isinstance(self.sigma_min, float) and len(self.sigma_min) > 1:
+            sigma_max_per_channel = torch.maximum(
+                sigma_max_per_channel,
+                torch.tensor(self.sigma_min, device=img_clean.device),
+            )
+            # Normalize from 0 to 1
+            sigma_max = 1.0
+        else:
+            # just use the first value, ignore the rest
+            sigma_max = torch.maximum(
+                sigma_max_per_channel,
+                torch.tensor(self.sigma_min, device=img_clean.device),
+            )[0]
+
+        rnd_uniform = torch.rand([img_clean.shape[0], 1, 1, 1], device=img_clean.device)
+
+        sampled_sigma = rnd_uniform * sigma_max
+        weight = (sampled_sigma**2 + self.sigma_data**2) / (
+            sampled_sigma * self.sigma_data
+        ) ** 2
+
+        # augment for conditional generation
+        x_tot = torch.cat((img_clean, img_lr), dim=1)
+
+        x_1 = x_tot[:, : img_clean.shape[1], :, :]  # x_1 - target
+        x_low = x_tot[:, img_clean.shape[1] :, :, :]  # x_low - upsampled ERA5
+
+        # encode the low resolution data x_low to x_0
+        # handle the case where encoder_net is wrapped in DistributedDataParallel
+
+        if isinstance(encoder_net, SongUNetPosEmbd) or (
+            isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+            and isinstance(encoder_net.module, SongUNetPosEmbd)
+        ):
+            x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+        else:
+            x_0 = encoder_net(x_low)
+
+        # convert sigma to time
+        # sampled_sigma = (1-t)*sigma_max
+        time = 1 - sampled_sigma / sigma_max  # this is the time from 1 to 0
+
+        # we don't subtract x_0, this will be done in the sampler
+        x_t = ((1 - time) * x_0) + (time * x_1)
+        if not isinstance(self.sigma_min, float) and len(self.sigma_min) > 1:
+            sigma_t = (
+                sigma_max_per_channel.unsqueeze(0).unsqueeze(2).unsqueeze(2)
+                * sampled_sigma
+            )
+        else:
+            sigma_t = sampled_sigma
+        x_t_noisy = x_t + torch.randn_like(x_0) * sigma_t
+
+        D_x_t = denoiser_net(
+            x=x_t_noisy,
+            sigma=sampled_sigma,
+            condition=x_low,
+        )
+        # time_weight = lambda t: 1/(1 - torch.clamp(t, 0.9))
+        # time_weight = lambda t: 1
+        def time_weight(t):
+            return 1
+
+        sfm_loss = weight * ((time_weight(time) * (D_x_t - x_1)) ** 2)
+
+        if self.encoder_loss_type == "l1":
+            encoder_loss = F.l1_loss(x_1, x_0, reduction="none")
+            weighted_encoder_loss = self.encoder_loss_weight * encoder_loss
+        elif self.encoder_loss_type == "l2":
+            encoder_loss = F.mse_loss(x_1, x_0, reduction="none")
+            weighted_encoder_loss = self.encoder_loss_weight * encoder_loss
+        elif self.encoder_loss_type is None:
+            encoder_loss = torch.tensor(
+                0.0
+            )  # This covers the case where there is no encoder_loss
+            weighted_encoder_loss = torch.tensor(0.0)
+
+        return sfm_loss + weighted_encoder_loss
+
+
+class SFMEncoderLoss:
+    """
+    Loss function corresponding to Stochastic Flow Matching for the encoder portion
+
+    Parameters
+    ----------
+    encoder_loss_type: str
+        Type of loss to use ["l1", "l2"]
+    """
+
+    def __init__(self, encoder_loss_type: str = "l2", **kwargs):
+        if encoder_loss_type not in ["l1", "l2"]:
+            raise ValueError(
+                f"encoder_loss_type should be either l1 or l2 not {encoder_loss_type}"
+            )
+        self.encoder_loss_type = encoder_loss_type
+
+    def __call__(
+        self,
+        denoiser_net: torch.nn.Module,
+        encoder_net: torch.nn.Module,
+        img_clean: torch.Tensor,
+        img_lr: torch.Tensor,
+    ):
+        """
+        Calculate the loss for the encoder used in stochastic flow matching
+
+        Parameters
+        ----------
+        denoiser_net: torch.nn.Module
+            The denoiser network (unused here; kept for API compatibility with SFMLoss)
+        encoder_net: torch.nn.Module
+            The encoder network making the predictions
+        img_clean: torch.Tensor
+            Input images (high resolution) to the neural network.
+        img_lr: torch.Tensor
+            Input images (low resolution) to the neural network.
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor representing the loss calculated based on the encoder's
+            predictions
+        """
+        x_1 = img_clean
+        x_low = img_lr
+
+        if isinstance(encoder_net, SongUNetPosEmbd) or (
+            isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+            and isinstance(encoder_net.module, SongUNetPosEmbd)
+        ):
+            x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+        else:
+            x_0 = encoder_net(x_low)
+
+        if self.encoder_loss_type == "l1":
+            encoder_loss = F.l1_loss(x_1, x_0, reduction="none")
+        elif self.encoder_loss_type == "l2":
+            encoder_loss = F.mse_loss(x_1, x_0, reduction="none")
+
+        return encoder_loss
diff --git a/physicsnemo/models/diffusion/__init__.py b/physicsnemo/models/diffusion/__init__.py
index 3984bffd42..64349700f7 100644
--- a/physicsnemo/models/diffusion/__init__.py
+++ b/physicsnemo/models/diffusion/__init__.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# ruff: noqa + +from .encoders import Conv2dSerializable from .utils import weight_init from .layers import ( AttentionOp, @@ -36,3 +38,4 @@ VEPrecond_dfsr_cond, VEPrecond_dfsr, ) +from .sfm_preconditioning import SFMPrecondSR, SFMPrecondEmpty diff --git a/physicsnemo/models/diffusion/encoders.py b/physicsnemo/models/diffusion/encoders.py new file mode 100644 index 0000000000..001728fdc0 --- /dev/null +++ b/physicsnemo/models/diffusion/encoders.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import nvtx + +import physicsnemo.models.diffusion as diffusion +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "Conv2dSerializable" + # Optimization + jit: bool = True + cuda_graphs: bool = True + amp_cpu: bool = True + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = True + + +class Conv2dSerializable(Module): + """ + A serializable version of a 2d convolution + + Parameters + ---------- + in_channels: int + Number of input channels + out_channels: int + Number of output channels + kernel_size: int + Size of the convolution kernel + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + ): + super().__init__(meta=MetaData()) + self.out_channels = out_channels + self.in_channels = in_channels + self.kernel_size = kernel_size + self.net = diffusion.Conv2d( + in_channels, + out_channels, + kernel_size, + ) + + @nvtx.annotate(message="Conv2dSerializable", color="blue") + def forward(self, x): + """forward pass""" + return self.net(x) diff --git a/physicsnemo/models/diffusion/sfm_preconditioning.py b/physicsnemo/models/diffusion/sfm_preconditioning.py new file mode 100644 index 0000000000..3143b7e9bb --- /dev/null +++ b/physicsnemo/models/diffusion/sfm_preconditioning.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
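+
+# Preconditioners for Stochastic Flow Matching (SFM): SFMPrecondSR wraps a
+# denoiser network and tracks a (possibly per-channel, EMA-updated) maximum
+# noise level sigma_max; SFMPrecondEmpty is a placeholder used when only the
+# encoder is trained.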
+
+
+import importlib
+from dataclasses import dataclass
+from typing import List, Union
+
+import nvtx
+import torch
+
+from physicsnemo.models.diffusion import (  # noqa: F401 for globals
+    Conv2dSerializable,
+    DhariwalUNet,
+    SongUNet,
+)
+from physicsnemo.models.meta import ModelMetaData
+from physicsnemo.models.module import Module
+
+network_module = importlib.import_module("physicsnemo.models.diffusion")
+
+
+@dataclass
+class SFMPrecondSRMetaData(ModelMetaData):
+    """SFMPrecondSR meta data"""
+
+    name: str = "SFMPrecondSR"
+    # Optimization
+    jit: bool = False
+    cuda_graphs: bool = False
+    amp_cpu: bool = False
+    amp_gpu: bool = True
+    torch_fx: bool = False
+    # Data type
+    bf16: bool = False
+    # Inference
+    onnx: bool = False
+    # Physics informed
+    func_torch: bool = False
+    auto_grad: bool = False
+
+
+class SFMPrecondSR(Module):
+    """
+    Super-resolution preconditioning based on the stochastic flow matching (SFM)
+    approach
+
+    Parameters
+    ----------
+    img_resolution : Union[List[int], int]
+        Image resolution.
+    img_in_channels : int
+        Number of input color channels.
+    img_out_channels : int
+        Number of output color channels.
+    use_fp16 : bool
+        If True, execute the underlying model at FP16 precision, by default False.
+    N_grid_channels : int
+        Number of additional positional grid channels used by the underlying
+        model, by default 0.
+    sigma_max : Union[dict, float]
+        Maximum supported noise level. Either a float, or a dict with
+        "initial_values", "ema_weight", and "min_values" entries for adaptive,
+        per-channel noise scaling. By default inf.
+    sigma_data : float
+        Expected standard deviation of the training data, by default 0.5.
+    model_type : str
+        Class name of the underlying model, by default "SongUNetPosEmbd".
+    use_x_low_conditioning : bool, optional
+        If True, concatenate the low-resolution conditioning to the model input,
+        by default None (no conditioning).
+    **model_kwargs : dict
+        Keyword arguments for the underlying model.
+    """
+
+    def __init__(
+        self,
+        img_resolution: Union[List[int], int],
+        img_in_channels: int,
+        img_out_channels: int,
+        use_fp16: bool = False,
+        N_grid_channels: int = 0,
+        sigma_max: Union[dict, float] = float("inf"),
+        sigma_data: float = 0.5,
+        model_type: str = "SongUNetPosEmbd",
+        use_x_low_conditioning=None,
+        **model_kwargs,
+    ) -> None:
+        Module.__init__(self, meta=SFMPrecondSRMetaData)
+        model_class = getattr(network_module, model_type)
+
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        self.img_shape_y = img_resolution[0]
+        self.img_shape_x = img_resolution[1]
+        self.use_x_low_conditioning = use_x_low_conditioning
+
+        if isinstance(sigma_max, float):
+            self.sigma_max_current = torch.tensor(sigma_max)
+            # disable weighted updates, let that be handled externally
+            self.ema_weight = torch.tensor(0.0)
+            self.min_values = torch.tensor(0.0)
+        else:
+            sigma_max_current = torch.tensor(sigma_max["initial_values"])
+            self.register_buffer("sigma_max_current", sigma_max_current)
+            self.ema_weight = torch.tensor(sigma_max["ema_weight"])
+            self.min_values = torch.tensor(sigma_max["min_values"])
+
+        # drop the SFM-specific "encoder_type" option so it is not forwarded
+        # to the underlying model (e.g. SongUNetPosEmbd)
+        if "encoder_type" in model_kwargs:
+            del model_kwargs["encoder_type"]
+        lr_channels = img_in_channels if self.use_x_low_conditioning else 0
+        self.denoiser_net = model_class(
+            img_resolution=img_resolution,
+            in_channels=img_out_channels + lr_channels + N_grid_channels,
+            out_channels=img_out_channels,
+            **model_kwargs,
+        )
+
+    def update_sigma_max(self, sigma_max: float):
+        """
+        Updates the maximum noise level
+
+        sigma_max is updated externally so that any needed accumulation
+        and reductions can be handled by the training loop
+
+        Parameters
+        ----------
+        sigma_max : float
+            New maximum noise level (scalar or per-channel list) to fold into
+            the EMA estimate
+        """
+        ema_weight = self.ema_weight.to(self.sigma_max_current.device)
+        sigma_max = torch.tensor(sigma_max).to(self.sigma_max_current.device)
+        # Update sigma_max_current without gradients
+        new_sigma_max_current = (
+            ema_weight * self.sigma_max_current + (1 - ema_weight) * sigma_max
+        )
+        self.sigma_max_current = torch.max(
+            new_sigma_max_current, self.min_values.to(self.sigma_max_current.device)
+        )
+
+    def get_sigma_max(self):
+        """Returns the current maximum noise level"""
+        return self.sigma_max_current
+
+    @nvtx.annotate(message="SFMPrecondSR", color="orange")
+    def forward(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        condition: torch.Tensor,
+        force_fp32: bool = False,
+        **model_kwargs,
+    ):
+        """
+        Forward pass of the SFM preconditioner
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The partially noised input image
+
+        sigma : torch.Tensor
+            The noise level(s) used to compute the preconditioning coefficients
+
+        condition : torch.Tensor
+            The low-resolution conditioning image
+
+        force_fp32 : bool
+            Whether to force float32 computation, by default False
+
+        model_kwargs: dict
+            Keyword arguments for the underlying model.
+
+        Returns
+        -------
+        torch.Tensor
+            The denoised image
+        """
+        x = x.to(torch.float32)
+        dtype = (
+            torch.float16
+            if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
+            else torch.float32
+        )
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.log() / 4
+
+        if not self.use_x_low_conditioning:
+            scaled_x = c_in * x
+        else:
+            condition = condition.to(torch.float32)
+            scaled_x = torch.cat([c_in * x, condition], dim=1)
+
+        F_x = self.denoiser_net(
+            scaled_x.to(dtype), noise_labels=c_noise.flatten(), class_labels=None
+        )
+
+        D_x = c_skip * x + c_out * F_x.to(torch.float32)
+        return D_x
+
+    def round_sigma(self, sigma: Union[List[float], torch.Tensor]):
+        """
+        Convert given sigma value(s) to a tensor representation in the
+        same precision as the model
+
+        Parameters
+        ----------
+        sigma : Union[float list, torch.Tensor]
+            The sigma value(s) to convert.
+
+        Returns
+        -------
+        torch.Tensor
+            The tensor representation of the provided sigma value(s).
+        """
+        return torch.as_tensor(sigma)
+
+
+class SFMPrecondEmpty(Module):
+    """
+    A preconditioner that does nothing
+
+    Parameters
+    ----------
+    **kwargs : dict
+        Keyword arguments, not used
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.param = torch.nn.Parameter(torch.tensor(0.0))
+        self.label_dim = None
diff --git a/physicsnemo/utils/generative/__init__.py b/physicsnemo/utils/generative/__init__.py
index a708ccb3d6..c5017a0219 100644
--- a/physicsnemo/utils/generative/__init__.py
+++ b/physicsnemo/utils/generative/__init__.py
@@ -15,6 +15,11 @@
 # limitations under the License.
 
 from .deterministic_sampler import deterministic_sampler
+from .sfm_samplers import (
+    SFM_encoder_sampler,
+    SFM_Euler_sampler,
+    SFM_Euler_sampler_Adaptive_Sigma,
+)
 from .stochastic_sampler import image_batching, image_fuse, stochastic_sampler
 from .utils import (
     EasyDict,
diff --git a/physicsnemo/utils/generative/sfm_samplers.py b/physicsnemo/utils/generative/sfm_samplers.py
new file mode 100644
index 0000000000..abb87a721a
--- /dev/null
+++ b/physicsnemo/utils/generative/sfm_samplers.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from typing import Dict
+
+import nvtx
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+
+from physicsnemo.models.diffusion import SongUNetPosEmbd
+
+
+def sigma(t):
+    return t
+
+
+def sigma_inv(sigma):
+    return sigma
+
+
+@nvtx.annotate(message="SFM_encoder_sampler", color="red")
+def SFM_encoder_sampler(
+    networks: Dict[str, torch.nn.Module],
+    img_lr: torch.Tensor,
+    randn_like: Callable = None,
+    cfg: DictConfig = None,
+):
+    """
+    Sampler for the SFM encoder; just runs the encoder
+
+    networks: Dict
+        A dictionary containing "encoder_net" and "denoiser_net" entries
+        for the encoder and denoiser networks.
+        Note: denoiser_net is not used by SFM_encoder_sampler
+    img_lr: torch.Tensor
+        The low resolution image used as encoder input
+    randn_like: StackedRandomGenerator
+        The random noise generator.
+        Note: not used by SFM_encoder_sampler
+    cfg: DictConfig
+        The configuration used for sampling
+        Note: not used by SFM_encoder_sampler
+    """
+    encoder_net = networks["encoder_net"]
+    x_low = img_lr
+    # in V1 the encoder net was inside the denoiser
+    if isinstance(encoder_net, SongUNetPosEmbd) or (
+        isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+        and isinstance(encoder_net.module, SongUNetPosEmbd)
+    ):
+        x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+    else:
+        x_0 = encoder_net(x_low)  # MODULUS
+
+    return x_0
+
+
+@nvtx.annotate(message="SFM_Euler_sampler", color="red")
+def SFM_Euler_sampler(
+    networks: Dict[str, torch.nn.Module],
+    img_lr: torch.Tensor,
+    randn_like: Callable,
+    cfg: DictConfig,
+):
+    """
+    Euler sampler for SFM: runs the encoder to obtain the latent base sample,
+    perturbs it with noise at the current maximum noise level, and integrates
+    the flow with explicit Euler steps using the denoiser network.
+
+    networks: Dict
+        A dictionary containing "encoder_net" and "denoiser_net" entries
+        for the encoder and denoiser networks.
+    img_lr: torch.Tensor
+        The low resolution image used to condition the denoiser
+    randn_like: StackedRandomGenerator
+        The random noise generator used to initialize the noisy latent state.
+    cfg: DictConfig
+        The configuration used for sampling (num_steps, rho, sigma_min, t_min)
+    """
+    denoiser_net = networks["denoiser_net"]
+    encoder_net = networks["encoder_net"]
+
+    x_low = img_lr
+
+    # Define time steps in terms of noise level.
+    step_indices = torch.arange(cfg.num_steps, device=denoiser_net.device)
+    # TODO: this is a temporary simplification; sigma_max should be treated per channel
+    sigma_max = (
+        denoiser_net.get_sigma_max()[0]
+        if len(denoiser_net.get_sigma_max().shape) > 0
+        else denoiser_net.get_sigma_max()
+    )
+    sigma_steps = (
+        sigma_max ** (1 / cfg.rho)
+        + step_indices
+        / (cfg.num_steps - 1)
+        * (cfg.sigma_min ** (1 / cfg.rho) - sigma_max ** (1 / cfg.rho))
+    ) ** cfg.rho
+
+    # The noise level schedule is linear: sigma(t) = t, sigma_inv(sigma) = sigma
+    # (see the module-level helpers above).
+
+    # Compute final time steps based on the corresponding noise levels.
+    t_steps = sigma_inv(denoiser_net.round_sigma(sigma_steps))
+    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0
+
+    # Main sampling loop.
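+    # The sampling state starts from the encoder prediction x_0 perturbed by
+    # Gaussian noise at the maximum noise level. Each Euler step estimates the
+    # flow velocity u_t = (D(x_t) - x_t) / t from the denoiser output (with t
+    # clamped by t_min) and moves the state along it by dt = t_cur - t_next,
+    # reaching the clean sample at t = 0.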
+    t_next = t_steps[0]
+
+    if isinstance(encoder_net, SongUNetPosEmbd) or (
+        isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+        and isinstance(encoder_net.module, SongUNetPosEmbd)
+    ):
+        x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+    else:
+        x_0 = encoder_net(x_low)  # MODULUS
+
+    x_t = x_0 + sigma_max * randn_like(x_0)
+    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
+        x_t_cur = x_t
+        t_hat = t_cur
+        x_t_hat = x_t_cur
+
+        # Euler step.
+        x_t_denoised = denoiser_net(x_t_hat, sigma(t_hat), condition=x_low).to(
+            torch.float64
+        )
+
+        u_t = (x_t_denoised - x_t_hat) / (torch.clamp(t_hat, min=cfg.t_min))
+
+        dt = t_hat - t_next  # time runs backwards, so dt = t_cur - t_next > 0
+        x_t = x_t_hat + u_t * dt
+
+    return x_t
+
+
+@nvtx.annotate(message="SFM_Euler_sampler_Adaptive_Sigma", color="red")
+def SFM_Euler_sampler_Adaptive_Sigma(
+    networks: Dict[str, torch.nn.Module],
+    img_lr: torch.Tensor,
+    randn_like: Callable,
+    cfg: DictConfig,
+):
+    """
+    Euler sampler for SFM with adaptive noise scaling: the latent state is
+    initialized with the per-channel sigma_max maintained during training,
+    while the integration time is normalized to [0, 1].
+
+    networks: Dict
+        A dictionary containing "encoder_net" and "denoiser_net" entries
+        for the encoder and denoiser networks.
+    img_lr: torch.Tensor
+        The low resolution image used to condition the denoiser
+    randn_like: StackedRandomGenerator
+        The random noise generator used to initialize the noisy latent state.
+    cfg: DictConfig
+        The configuration used for sampling (num_steps, rho, sigma_min, t_min)
+    """
+    denoiser_net = networks["denoiser_net"]
+    encoder_net = networks["encoder_net"]
+
+    x_low = img_lr
+
+    # Define time steps in terms of noise level.
+    step_indices = torch.arange(cfg.num_steps, device=denoiser_net.device)
+    sigma_max_adaptive = denoiser_net.get_sigma_max()
+
+    # set sigma_max for sampling purposes to 1.0; this normalizes time to [0, 1]
+    sigma_max = 1.0
+    sigma_steps = (
+        sigma_max ** (1 / cfg.rho)
+        + step_indices
+        / (cfg.num_steps - 1)
+        * (cfg.sigma_min ** (1 / cfg.rho) - sigma_max ** (1 / cfg.rho))
+    ) ** cfg.rho
+
+    # The noise level schedule is linear: sigma(t) = t, sigma_inv(sigma) = sigma
+    # (see the module-level helpers above).
+
+    # Compute final time steps based on the corresponding noise levels.
+    t_steps = sigma_inv(denoiser_net.round_sigma(sigma_steps))
+    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0
+
+    # Main sampling loop.
+    t_next = t_steps[0]
+
+    if isinstance(encoder_net, SongUNetPosEmbd) or (
+        isinstance(encoder_net, nn.parallel.DistributedDataParallel)
+        and isinstance(encoder_net.module, SongUNetPosEmbd)
+    ):
+        x_0 = encoder_net(x_low, noise_labels=torch.tensor([0]), class_labels=None)
+    else:
+        x_0 = encoder_net(x_low)  # MODULUS
+
+    x_t = x_0 + sigma_max_adaptive.view(1, -1, 1, 1) * randn_like(x_0)
+    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
+        x_t_cur = x_t
+
+        t_hat = t_cur
+        x_t_hat = x_t_cur
+
+        # Euler step.
+        x_t_denoised = denoiser_net(x_t_hat, sigma(t_hat), condition=x_low).to(
+            torch.float64
+        )
+
+        u_t = (x_t_denoised - x_t_hat) / (torch.clamp(t_hat, min=cfg.t_min))
+
+        dt = t_hat - t_next  # time runs backwards, so dt = t_cur - t_next > 0
+        x_t = x_t_hat + u_t * dt
+
+    return x_t
diff --git a/test/metrics/diffusion/test_sfm_losses.py b/test/metrics/diffusion/test_sfm_losses.py
new file mode 100644
index 0000000000..37eaa0a830
--- /dev/null
+++ b/test/metrics/diffusion/test_sfm_losses.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import pytest +import torch + +from physicsnemo.metrics.diffusion import ( + SFMEncoderLoss, + SFMLoss, +) +from physicsnemo.models.diffusion import SongUNetPosEmbd + + +class fake_net(torch.nn.Module): + """dummy class to test sfm encoder""" + + def get_sigma_max(self): + return torch.tensor([1.0]) + + def forward(self, x, *args, **kwargs): + return x + + +def get_songunet(): + """helper that creates a songunet for testing""" + songunet_kwargs = { + "img_resolution": 8, + "in_channels": 2, + "out_channels": 4, + "embedding_type": "zero", + "label_dim": 0, + "encoder_type": "standard", + "decoder_type": "standard", + "channel_mult_noise": 1, + "resample_filter": [1, 1], + "channel_mult": [1, 2, 2], + "attn_resolutions": [28], + "N_grid_channels": 0, + "model_channels": 4, + } + return SongUNetPosEmbd(**songunet_kwargs) + + +def test_sfmloss_initialization(): + """checks SFMLoss __init__""" + + loss_fn = SFMLoss() + + assert loss_fn.encoder_loss_type == "l2" + assert loss_fn.encoder_loss_weight == 0.1 + assert loss_fn.sigma_min == 0.002 + assert loss_fn.sigma_data == 0.5 + + loss_fn = SFMLoss( + encoder_loss_type="l1", + encoder_loss_weight=[0.1, 0.2], + sigma_min=5e-4, + sigma_data=0.1, + ) + + assert loss_fn.encoder_loss_type == "l1" + assert loss_fn.encoder_loss_weight == [0.1, 0.2] + assert loss_fn.sigma_min == 5e-4 + assert loss_fn.sigma_data == 0.1 + + # test for invalid loss type + with pytest.raises( + ValueError, + match=re.escape( + "encoder_loss_type should be one of ['l1', 'l2', None] not bogus" + ), + ): + loss_fn = SFMLoss( + encoder_loss_type="bogus", + ) + + +def test_sfmloss_call(): + """checks SFMLoss __call__""" + + # dummy network for loss + dummy_denoiser = fake_net() + dummy_encoder = get_songunet() + dummy_net = torch.nn.Identity() + + image_zeros = torch.zeros((2, 4, 8, 8)) + image_rnd = torch.rand((2, 2, 8, 8)) + + # test defaults, encoder l2 loss, sigma_min is float + loss_fn = SFMLoss() + + loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros) + assert isinstance(loss_value, torch.Tensor) + + loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros) + assert isinstance(loss_value, torch.Tensor) + + # test encoder l1 loss, sigma_min is list + loss_fn = SFMLoss(encoder_loss_type="l1", sigma_min=[0.001, 0.001, 0.001, 0.001]) + loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros) + assert isinstance(loss_value, torch.Tensor) + + # test no encoder loss, sigma_min is list + loss_fn = SFMLoss(encoder_loss_type=None) + loss_value = loss_fn(dummy_denoiser, dummy_net, image_zeros, image_zeros) + assert isinstance(loss_value, torch.Tensor) + + # testing using songunet and no encoder loss + loss_value = loss_fn(dummy_denoiser, dummy_encoder, image_zeros, image_rnd) + assert isinstance(loss_value, torch.Tensor) + + +def test_sfmencoderloss_initialization(): + """checks 
SFMEncoderLoss __init__""" + loss_fn = SFMEncoderLoss() + + assert loss_fn.encoder_loss_type == "l2" + + # test for invalid loss type + with pytest.raises( + ValueError, + match="encoder_loss_type should be either l1 or l2 not bogus", + ): + loss_fn = SFMEncoderLoss( + encoder_loss_type="bogus", + ) + + +def test_sfmencoderloss_call(): + """checks SFMEncoderLoss __call__""" + # dummy network for loss + dummy_net = torch.nn.Identity() + dummy_encoder = get_songunet() + + image_zeros = torch.zeros((2, 4, 8, 8)) + image_twos = torch.ones((2, 4, 8, 8)) * 2 + image_rnd = torch.rand((2, 2, 8, 8)) + + # test l2 loss + loss_fn = SFMEncoderLoss() + + # encoder loss is deterministic + loss_value = loss_fn(dummy_net, dummy_net, image_twos, image_twos) + assert torch.equal(loss_value, image_zeros) + + loss_value = loss_fn(dummy_net, dummy_net, image_zeros, image_twos) + assert torch.equal(loss_value, image_twos * image_twos) + + # test l1 loss + loss_fn = SFMEncoderLoss("l1") + + # encoder loss is deterministic + loss_value = loss_fn(dummy_net, dummy_net, image_twos, image_twos) + assert torch.equal(loss_value, image_zeros) + + loss_value = loss_fn(dummy_net, dummy_net, image_zeros, image_twos) + assert torch.equal(loss_value, image_twos) + + # using songunet + loss_value = loss_fn(dummy_net, dummy_encoder, image_twos, image_rnd) + assert isinstance(loss_value, torch.Tensor) diff --git a/test/models/diffusion/test_encoders.py b/test/models/diffusion/test_encoders.py new file mode 100644 index 0000000000..52a5173df3 --- /dev/null +++ b/test/models/diffusion/test_encoders.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from physicsnemo.models.diffusion import Conv2dSerializable + + +def test_conv2dserializable_initialization(): + conv_net = Conv2dSerializable(in_channels=8, out_channels=4, kernel_size=3) + + assert conv_net.in_channels == 8 + assert conv_net.out_channels == 4 + assert conv_net.kernel_size == 3 + + +def test_conv2dserializable_forward(): + test_data = torch.zeros(1, 8, 16, 16) + conv_net = Conv2dSerializable(in_channels=8, out_channels=4, kernel_size=3) + + out_tensor = conv_net(test_data) + assert out_tensor.shape == (1, 4, 16, 16) diff --git a/test/models/diffusion/test_sfm_preconditioning.py b/test/models/diffusion/test_sfm_preconditioning.py new file mode 100644 index 0000000000..d8a7279e78 --- /dev/null +++ b/test/models/diffusion/test_sfm_preconditioning.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from physicsnemo.models.diffusion import SFMPrecondEmpty, SFMPrecondSR + + +# SFMPrecondEmpty test +def test_sfmprecondempty_initialzation(): + """checks SFMPrecondEmpty __init__""" + precond = SFMPrecondEmpty() + + assert isinstance(precond, SFMPrecondEmpty) + assert precond.label_dim is None + assert precond.param == torch.tensor(0.0) + + +# SFMPrecondSR tests +def test_sfmprecondsr_initialization(): + """checks SFMPrecondSR __init__""" + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=4, + img_out_channels=6, + sigma_data=0.3, + ) + + assert isinstance(precond, SFMPrecondSR) + assert precond.img_shape_y == 32 + assert precond.img_shape_x == 32 + assert precond.sigma_data == 0.3 + + # test with sigma_max as a dict + sigma_max = { + "initial_values": [0.5, 0.5, 0.5], + "ema_weight": 0.2, + "min_values": 0.1, + } + + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=4, + img_out_channels=6, + sigma_max=sigma_max, + sigma_data=0.3, + encoder_type="l1", + ) + + assert isinstance(precond, SFMPrecondSR) + assert precond.img_shape_y == 32 + assert precond.img_shape_x == 32 + assert torch.equal(precond.sigma_max_current, torch.Tensor([0.5, 0.5, 0.5])) + assert torch.equal(precond.ema_weight, torch.tensor(0.2)) + assert torch.equal(precond.min_values, torch.tensor(0.1)) + assert precond.sigma_data == 0.3 + + +# SFMPrecondSR tests +def test_sfmprecondsr_sigma(): + """checks update_sigma_max and get_sigma_max""" + + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=4, + img_out_channels=6, + sigma_data=0.3, + sigma_max=8.0, + ) + + assert precond.get_sigma_max() == 8.0 + + precond.update_sigma_max(16.0) + assert precond.sigma_max_current == 16.0 + + # test with ema as a dict + ema_weight = 0.5 + init_vals = torch.tensor([16.0, 1.0, 0.5]) + updated_vals = torch.tensor([8.0, 2.0, 2.0]) + expected_vals = init_vals * 0.5 + updated_vals * 0.5 + + sigma_max = { + "initial_values": init_vals.tolist(), + "ema_weight": ema_weight, + "min_values": 0.1, + } + + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=4, + img_out_channels=6, + sigma_max=sigma_max, + sigma_data=0.3, + encoder_type="l1", + ) + + assert torch.equal(precond.sigma_max_current, init_vals) + + precond.update_sigma_max(updated_vals.tolist()) + assert torch.equal(precond.sigma_max_current, expected_vals) + + assert torch.equal(precond.round_sigma(init_vals.tolist()), init_vals) + + +def test_sfmprecondsr_forward(): + """checks forward""" + + image_in = torch.zeros((2, 4, 32, 32)) + image_cond = torch.zeros((2, 6, 32, 32)) + sigma = torch.rand((2, 1, 1, 1)) + + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=6, + img_out_channels=4, + N_grid_channels=4, + ) + + out_val = precond(image_in, sigma=sigma, condition=None) + assert isinstance(out_val, torch.Tensor) + + precond = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=6, + img_out_channels=4, + N_grid_channels=4, + use_x_low_conditioning=True, + ) + + out_val = precond(image_in, sigma=sigma, condition=image_cond) + assert isinstance(out_val, torch.Tensor) 
diff --git a/test/utils/generative/test_sfm_sampler.py b/test/utils/generative/test_sfm_sampler.py new file mode 100644 index 0000000000..abee9ffdcf --- /dev/null +++ b/test/utils/generative/test_sfm_sampler.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from omegaconf import DictConfig + +from physicsnemo.models.diffusion import SFMPrecondSR, SongUNetPosEmbd +from physicsnemo.utils.generative import ( + SFM_encoder_sampler, + SFM_Euler_sampler, + SFM_Euler_sampler_Adaptive_Sigma, + StackedRandomGenerator, +) + + +def get_songunet(): + """helper that creates a songunet for testing""" + songunet_kwargs = { + "img_resolution": 32, + "in_channels": 6, + "out_channels": 6, + "embedding_type": "zero", + "label_dim": 0, + "encoder_type": "standard", + "decoder_type": "standard", + "channel_mult_noise": 1, + "resample_filter": [1, 1], + "channel_mult": [1, 2, 2], + "attn_resolutions": [28], + "N_grid_channels": 0, + "model_channels": 4, + } + return SongUNetPosEmbd(**songunet_kwargs) + + +class fake_net(torch.nn.Module): + """dummy class to test sfm encoder""" + + def __init__(self, sigma_max=[0.5]): + self.sigma_max = torch.tensor(sigma_max) + + def get_sigma_max(self): + return self.sigma_max + + def round_sigma(self, x): + return torch.tensor(x) + + def forward(self, x, *args, **kwargs): + return x + + +def test_sfm_encoder_sampler(): + """SFM_encoder_sampler""" + dummy_net = torch.nn.Identity() + encoder = get_songunet() + + image_rnd = torch.rand((2, 6, 8, 8)) + + networks = { + "encoder_net": dummy_net, + "denoiser_net": None, + } + + out_val = SFM_encoder_sampler(networks, image_rnd) + assert torch.equal(out_val, image_rnd) + + networks = { + "encoder_net": encoder, + "denoiser_net": None, + } + out_val = SFM_encoder_sampler(networks, image_rnd) + assert isinstance(out_val, torch.Tensor) + + +def test_sfm_euler_sampler(): + """SFM_Euler_sampler""" + encoder = get_songunet() + denoiser = SFMPrecondSR( + img_resolution=[32, 32], + img_in_channels=4, + img_out_channels=6, + N_grid_channels=4, + ).to("cpu") + dummy_net = torch.nn.Identity() + + batch_seeds = torch.as_tensor([1, 2]).to("cpu") + image_rnd = torch.rand((2, 6, 32, 32)).to("cpu") + rnd = StackedRandomGenerator(image_rnd.device, batch_seeds) + + cfg = {"rho": 7, "num_steps": 5, "sigma_min": 0.01, "t_min": 0.002} + cfg = DictConfig(cfg) + + networks = { + "encoder_net": dummy_net, + "denoiser_net": denoiser, + } + + out_val = SFM_Euler_sampler( + networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg + ) + assert isinstance(out_val, torch.Tensor) + + networks = { + "encoder_net": encoder, + "denoiser_net": denoiser, + } + + out_val = SFM_Euler_sampler( + networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg + ) + assert isinstance(out_val, torch.Tensor) + + +def 
test_sfm_euler_sampler_adaptive_sigma():
+    """SFM_Euler_sampler_Adaptive_Sigma"""
+    encoder = get_songunet()
+    denoiser = SFMPrecondSR(
+        img_resolution=[32, 32],
+        img_in_channels=4,
+        img_out_channels=6,
+        N_grid_channels=4,
+    ).to("cpu")
+    dummy_net = torch.nn.Identity()
+
+    batch_seeds = torch.as_tensor([1, 2]).to("cpu")
+    image_rnd = torch.rand((2, 6, 32, 32)).to("cpu")
+    rnd = StackedRandomGenerator(image_rnd.device, batch_seeds)
+
+    cfg = {"rho": 7, "num_steps": 5, "sigma_min": 0.01, "t_min": 0.002}
+    cfg = DictConfig(cfg)
+
+    networks = {
+        "encoder_net": dummy_net,
+        "denoiser_net": denoiser,
+    }
+
+    out_val = SFM_Euler_sampler_Adaptive_Sigma(
+        networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+    )
+    assert isinstance(out_val, torch.Tensor)
+
+    networks = {
+        "encoder_net": encoder,
+        "denoiser_net": denoiser,
+    }
+
+    out_val = SFM_Euler_sampler_Adaptive_Sigma(
+        networks=networks, img_lr=image_rnd, randn_like=rnd.randn_like, cfg=cfg
+    )
+    assert isinstance(out_val, torch.Tensor)
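For orientation, the sketch below shows how the SFM pieces added in this change are typically wired together at inference time: the encoder produces the deterministic large-scale prediction, and the adaptive-sigma Euler sampler refines it into a stochastic high-resolution sample. It deliberately mirrors the shapes and sampler settings used in the test fixtures above and uses an identity encoder, so the numbers are illustrative assumptions rather than a recommended configuration; a real run would instead load trained encoder and denoiser checkpoints.

```python
import torch
from omegaconf import DictConfig

from physicsnemo.models.diffusion import SFMPrecondSR
from physicsnemo.utils.generative import (
    SFM_encoder_sampler,
    SFM_Euler_sampler_Adaptive_Sigma,
    StackedRandomGenerator,
)

# Denoiser preconditioner; channel counts follow the test fixtures above.
# The dict form of sigma_max enables per-channel adaptive noise scaling.
denoiser = SFMPrecondSR(
    img_resolution=[32, 32],
    img_in_channels=4,
    img_out_channels=6,
    N_grid_channels=4,
    sigma_max={"initial_values": [0.5] * 6, "ema_weight": 0.2, "min_values": 0.1},
)

# A trained encoder network would be used here; an identity map is enough
# to show the wiring.
encoder = torch.nn.Identity()
networks = {"encoder_net": encoder, "denoiser_net": denoiser}

cfg = DictConfig({"rho": 7, "num_steps": 5, "sigma_min": 0.01, "t_min": 0.002})
img_lr = torch.rand((2, 6, 32, 32))  # low-resolution conditioning batch
rnd = StackedRandomGenerator(img_lr.device, torch.as_tensor([1, 2]))

# Deterministic large-scale reconstruction from the encoder alone ...
x_encoder = SFM_encoder_sampler(networks, img_lr)

# ... and a stochastic SFM sample refined by the flow-matching denoiser.
x_sample = SFM_Euler_sampler_Adaptive_Sigma(
    networks=networks, img_lr=img_lr, randn_like=rnd.randn_like, cfg=cfg
)
print(x_encoder.shape, x_sample.shape)
```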