diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dd4f90809..a2acbbf312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Safe API to override `__init__`'s arguments saved in checkpoint file with `Module.from_checkpoint("chkpt.mdlus", models_args)`. - PyTorch Geometric MeshGraphNet backend. +- Non-regression tests for CorrDiff train/generate. ### Changed diff --git a/physicsnemo/models/diffusion/preconditioning.py b/physicsnemo/models/diffusion/preconditioning.py index 269b6abb1c..1229026469 100644 --- a/physicsnemo/models/diffusion/preconditioning.py +++ b/physicsnemo/models/diffusion/preconditioning.py @@ -22,7 +22,7 @@ import importlib import warnings from dataclasses import dataclass -from typing import List, Literal, Tuple, Union +from typing import Any, List, Literal, Tuple, Union import numpy as np import torch @@ -797,7 +797,7 @@ def __init__( sigma_data: float = 0.5, sigma_min=0.0, sigma_max=float("inf"), - **model_kwargs: dict, + **model_kwargs: Any, ): super().__init__(meta=EDMPrecondSuperResolutionMetaData) diff --git a/physicsnemo/models/diffusion/unet.py b/physicsnemo/models/diffusion/unet.py index 10e469f51c..4e4954f5e1 100644 --- a/physicsnemo/models/diffusion/unet.py +++ b/physicsnemo/models/diffusion/unet.py @@ -171,7 +171,7 @@ def __init__( model_type: Literal[ "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" ] = "SongUNetPosEmbd", - **model_kwargs: dict, + **model_kwargs: Any, ): super().__init__(meta=MetaData) diff --git a/test/models/diffusion/test_corrdiff_train.py b/test/models/diffusion/test_corrdiff_train.py new file mode 100644 index 0000000000..cd44094b78 --- /dev/null +++ b/test/models/diffusion/test_corrdiff_train.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest +import torch +from utils_models import ( + setup_model_lt_aware_ce_regression, + setup_model_lt_aware_patched_diffusion, +) +from validate_utils import validate_accuracy + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_corrdiff_lt_aware_ce_regression(device): + """ + Non-regression test for lead-time aware regression model (with CE loss) in + a CorrDiff training loop. Compares results from v1.0.1 checkpoint with + current model. + """ + + from physicsnemo.metrics.diffusion import RegressionLossCE + from physicsnemo.models import Module + + reg_model_chkpt = Module.from_checkpoint( + str( + Path(__file__).parents[1].resolve() + / Path("data") + / Path("corrdiff_lt_aware_ce_regression_UNet.mdlus") + ) + ).to(device) + reg_model = setup_model_lt_aware_ce_regression() + loss_fn = RegressionLossCE(prob_channels=[3]) + + # Generate data + torch.manual_seed(0) + B, H, W = 4, 48, 32 + C_lr, C_hr = 3, 4 + lead_time_steps = 3 + x_lr = torch.randn(B, C_lr, H, W).to(device) + x_hr = torch.randn(B, C_hr, H, W).to(device) + lead_time_label = torch.randint(0, lead_time_steps, (B,)).to(device) + + # Compute loss with v1.0.1 checkpoint + loss_chkpt = loss_fn(reg_model_chkpt, x_hr, x_lr, lead_time_label=lead_time_label) + assert validate_accuracy( + loss_chkpt, file_name="lt_aware_ce_regression_loss_v1.1.1.pth" + ) + loss_chkpt = loss_chkpt.sum() / B + loss_chkpt.backward() + + # Compute loss with current model + loss = loss_fn(reg_model, x_hr, x_lr, lead_time_label=lead_time_label) + assert validate_accuracy(loss, file_name="lt_aware_ce_regression_loss_v1.1.1.pth") + loss = loss.sum() / B + loss.backward() + + return + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_corrdiff_lt_aware_patched_diffusion(device): + """ + Non-regression test for lead-time aware patched diffusion model with + ResidualLoss in a CorrDiff training loop. Compares results from v1.0.1 + checkpoint with current model. + """ + + from physicsnemo.metrics.diffusion import ResidualLoss + from physicsnemo.models import Module + from physicsnemo.utils.patching import RandomPatching2D + + patching = RandomPatching2D(img_shape=(48, 32), patch_shape=(12, 12), patch_num=5) + patch_nums_iter = [2, 2, 1] + + reg_model_chkpt = Module.from_checkpoint( + str( + Path(__file__).parents[1].resolve() + / Path("data") + / Path("corrdiff_lt_aware_ce_regression_UNet.mdlus") + ) + ).to(device) + reg_model = setup_model_lt_aware_ce_regression() + + diff_model_chkpt = Module.from_checkpoint( + str( + Path(__file__).parents[1].resolve() + / Path("data") + / Path("corrdiff_lt_aware_patched_diffusion_EDMPrecondSR.mdlus") + ) + ).to(device) + diff_model = setup_model_lt_aware_patched_diffusion() + + loss_fn_chkpt = ResidualLoss( + regression_net=reg_model_chkpt, hr_mean_conditioning=True + ) + loss_fn = ResidualLoss(regression_net=reg_model, hr_mean_conditioning=True) + + # Generate data + torch.manual_seed(0) + B, H, W = 4, 48, 32 + C_lr, C_hr = 3, 4 + lead_time_steps = 3 + x_lr = torch.randn(B, C_lr, H, W).to(device) + x_hr = torch.randn(B, C_hr, H, W).to(device) + lead_time_label = torch.randint(0, lead_time_steps, (B,)).to(device) + + # Compute loss with v1.0.1 checkpoint + loss_fn_chkpt.y_mean = None + for patch_num_per_iter in patch_nums_iter: + patching.set_patch_num(patch_num_per_iter) + loss_chkpt = loss_fn_chkpt( + diff_model_chkpt, + x_hr, + x_lr, + patching=patching, + lead_time_label=lead_time_label, + use_patch_grad_acc=True, + ) + assert validate_accuracy( + loss_chkpt, + file_name=f"lt_aware_patched_diffusion_loss_iter_{patch_num_per_iter}_v1.1.1.pth", + ) + loss_chkpt = loss_chkpt.sum() / B + loss_chkpt.backward() + + # Compute loss with current model + loss_fn.y_mean = None + for patch_num_per_iter in patch_nums_iter: + patching.set_patch_num(patch_num_per_iter) + loss = loss_fn( + diff_model, + x_hr, + x_lr, + patching=patching, + lead_time_label=lead_time_label, + use_patch_grad_acc=True, + ) + assert validate_accuracy( + loss, + file_name=f"lt_aware_patched_diffusion_loss_iter_{patch_num_per_iter}_v1.1.1.pth", + ) + loss = loss.sum() / B + loss.backward() + + return diff --git a/test/models/diffusion/utils_models.py b/test/models/diffusion/utils_models.py new file mode 100644 index 0000000000..130dbd2cd0 --- /dev/null +++ b/test/models/diffusion/utils_models.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: E402 + +import torch + + +def setup_model_lt_aware_ce_regression(): + + from physicsnemo.models.diffusion import UNet + + torch.manual_seed(0) + H, W = 48, 32 + C_lr, C_hr = 3, 4 + N_grid_channels, lead_time_channels = 4, 7 + + model = UNet( + img_resolution=(H, W), + img_in_channels=C_lr + N_grid_channels + lead_time_channels, + img_out_channels=C_hr, + model_type="SongUNetPosLtEmbd", + model_channels=16, + channel_mult=[1, 2, 2], + channel_mult_emb=2, + num_blocks=2, + attn_resolutions=[8], + N_grid_channels=N_grid_channels, + embedding_type="zero", + lead_time_channels=lead_time_channels, + lead_time_steps=3, + prob_channels=[3], + ) + return model + + +def setup_model_lt_aware_patched_diffusion(): + + from physicsnemo.models.diffusion import EDMPrecondSuperResolution + + torch.manual_seed(0) + H, W = 48, 32 + C_lr, C_hr = 3, 4 + N_grid_channels, lead_time_channels = 6, 7 + + model = EDMPrecondSuperResolution( + img_resolution=(H, W), + img_in_channels=2 * C_lr + N_grid_channels + lead_time_channels + C_hr, + img_out_channels=C_hr, + model_type="SongUNetPosLtEmbd", + model_channels=16, + channel_mult=[1, 2, 2], + channel_mult_emb=2, + num_blocks=2, + attn_resolutions=[8], + N_grid_channels=N_grid_channels, + gridtype="learnable", + lead_time_channels=lead_time_channels, + lead_time_steps=3, + prob_channels=[3], + ) + return model diff --git a/test/models/diffusion/validate_utils.py b/test/models/diffusion/validate_utils.py new file mode 100644 index 0000000000..a80daec657 --- /dev/null +++ b/test/models/diffusion/validate_utils.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import Tuple, Union + +import torch + +Tensor = torch.Tensor +logger = logging.getLogger("__name__") + + +def compare_output( + output_1: Union[Tensor, Tuple[Tensor, ...]], + output_2: Union[Tensor, Tuple[Tensor, ...]], + rtol: float = 1e-5, + atol: float = 1e-5, +) -> bool: + """Compares model outputs and returns if they are the same + + Parameters + ---------- + output_1 : Union[Tensor, Tuple[Tensor, ...]] + Output one + output_2 : Union[Tensor, Tuple[Tensor, ...]] + Output two + rtol : float, optional + Relative tolerance of error allowed, by default 1e-5 + atol : float, optional + Absolute tolerance of error allowed, by default 1e-5 + + Returns + ------- + bool + If outputs are the same + """ + # Output of tensor + if isinstance(output_1, Tensor) and isinstance(output_2, Tensor): + return torch.allclose(output_1, output_2, rtol, atol) + # Output of tuple of tensors + elif isinstance(output_1, tuple): + # Loop through tuple of outputs + for i, (out_1, out_2) in enumerate(zip(output_1, output_2)): + # If tensor use allclose + if isinstance(out_1, Tensor): + if not torch.allclose(out_1, out_2, rtol, atol): + logger.warning(f"Failed comparison between outputs {i}") + logger.warning( + f"Max Difference: {torch.amax(torch.abs(out_1 - out_2))}" + ) + logger.warning(f"Difference: {out_1 - out_2}") + return False + # Otherwise assume primative + else: + if not out_1 == out_2: + return False + # Unsupported output type + else: + logger.error( + "Model returned invalid type for unit test, should be Tensor or Tuple[Tensor]" + ) + return False + + return True + + +def save_output(output: Union[Tensor, Tuple[Tensor, ...]], file_name: Path): + """Saves output of model to file + + Parameters + ---------- + output : Union[Tensor, Tuple[Tensor, ...]] + Output from netwrok model + file_name : Path + File path + + Raises + ------ + IOError + If file path has a parent directory that does not exist + ValueError + If model outputs are larger than 10mb + """ + if not file_name.parent.is_dir(): + raise IOError( + f"Folder path, {file_name.parent}, for output accuracy data not found" + ) + + # Check size of outputs + output_size = 0 + for out_tensor in output: + out_tensor = out_tensor.detach().contiguous().cpu() + output_size += out_tensor.element_size() * out_tensor.nelement() + + if output_size > 10**7: + raise ValueError( + "Outputs are greater than 10mb which is too large for this test" + ) + + output_dict = {i: data.detach().contiguous().cpu() for i, data in enumerate(output)} + torch.save(output_dict, file_name) + + +@torch.no_grad() +def validate_accuracy( + output: Tensor, + rtol: float = 1e-3, + atol: float = 1e-3, + file_name: Union[str, None] = None, +) -> bool: + """Validates the accuracy of a tensor with a reference output + + Parameters + ---------- + output : Tensor + Output tensor + rtol : float, optional + Relative tolerance of error allowed, by default 1e-3 + atol : float, optional + Absolute tolerance of error allowed, by default 1e-3 + file_name : Union[str, None], optional + Override the default file name of the stored target output, by default None + + Returns + ------- + bool + Test passed + + Raises + ------ + IOError + Target output tensor file for this model was not found + """ + # File name / path + # Output files should live in test/utils/data + + # Always use tuples for this comparison / saving + if isinstance(output, Tensor): + device = output.device + output = (output,) + else: + device = output[0].device + + file_name = ( + Path(__file__).parents[1].resolve() / Path("data") / Path(file_name.lower()) + ) + # If file does not exist, we will create it then error + # Model should then reproduce it on next pytest run + if not file_name.exists(): + save_output(output, file_name) + raise IOError( + f"Output check file {str(file_name)} wasn't found so one was created. Please re-run the test." + ) + # Load tensor dictionary and check + else: + tensor_dict = torch.load(str(file_name)) + if isinstance(tensor_dict, dict): + output_target = tuple([value.to(device) for value in tensor_dict.values()]) + elif isinstance(tensor_dict, Tensor): + output_target = tuple(tensor_dict.to(device)) + else: + raise ValueError(f"Invalid tensor dictionary: {tensor_dict}") + + return compare_output(output, output_target, rtol, atol)