Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/weather/stormcast/config/dataset/mock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ name: mock.MockDataset

num_state_channels: 3
num_background_channels: 4
num_invariant_channels: 2
num_scalar_cond_channels: 2
image_size: [256, 128]
num_samples: 100
Expand Down
3 changes: 2 additions & 1 deletion examples/weather/stormcast/config/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ model:
use_regression_net: True
regression_weights: "stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus"
previous_step_conditioning: True
spatial_pos_embed: True
hyperparameters:
additive_pos_embed: True

training:
loss:
Expand Down
2 changes: 1 addition & 1 deletion examples/weather/stormcast/config/model/stormcast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ regression_conditions: ["state", "background", "invariant"] # list consisting of
diffusion_conditions: ["state", "regression", "invariant"] # list consisting of "state", "regression", "background", "invariant" (default from StormCast paper)

# Model hyperparameters
spatial_pos_embed: False # Whether or not to add an additive position embed after the first conv in the U-Net
model_type: "SongUNetPosEmbd" # Model class to use
# Example overrides for architecture hyperparameters (uncomment to customize)
# hyperparameters:
Expand All @@ -33,6 +32,7 @@ model_type: "SongUNetPosEmbd" # Model class to use
# attn_resolutions: [] # Internal resolutions within the U-Net to apply self-attention
# bottleneck_attention: false
# checkpoint_level: 0
# additive_pos_embed: false # Whether or not to add an additive position embed after the first conv in the U-Net

# Pretrained regression model
regression_weights: "stormcast_checkpoints/regression/UNet.0.0.mdlus" # Path to pretrained regression network,
Expand Down
16 changes: 16 additions & 0 deletions examples/weather/stormcast/datasets/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
self,
num_state_channels: int = 3,
num_background_channels: int = 4,
num_invariant_channels: int = 2,
num_scalar_cond_channels: int = 2,
image_size: tuple[int, int] = (256, 128),
num_samples: int = 100,
Expand All @@ -45,6 +46,7 @@ def __init__(
):
self._num_state_channels = num_state_channels
self._num_background_channels = num_background_channels
self._num_invariant_channels = num_invariant_channels
self._num_scalar_cond_channels = num_scalar_cond_channels
self._image_size = image_size
self._num_samples = num_samples
Expand Down Expand Up @@ -111,6 +113,20 @@ def image_shape(self) -> tuple[int, int]:
"""Return the (height, width) of the data."""
return self._image_size

def get_invariants(self) -> np.ndarray | None:
"""Return invariants used for training."""
if self._num_invariant_channels > 0:
rng = np.random.default_rng(seed=42)
return rng.normal(
size=(
self._num_invariant_channels,
self._image_size[0],
self._image_size[1],
)
).astype(np.float32)
else:
return None


class MockDataset(_MockDataset):
def __init__(self, params, train):
Expand Down
14 changes: 10 additions & 4 deletions examples/weather/stormcast/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def test_checkpoint_integrity(
"model_type", ["hybrid", "nowcasting", "downscaling", "unconditional"]
)
@pytest.mark.parametrize("num_scalar_cond_channels", [0, 2])
@pytest.mark.parametrize("num_invariant_channels", [0, 2])
def test_model_types(
tmp_path: Path,
cfg_diffusion: DictConfig,
Expand All @@ -352,6 +353,7 @@ def test_model_types(
net_architecture: Literal["unet", "dit"],
model_type: Literal["hybrid", "nowcasting", "downscaling", "unconditional"],
num_scalar_cond_channels: int,
num_invariant_channels: int,
):
"""Test that training runs with different model configurations."""
dist = DistributedManager()
Expand All @@ -371,20 +373,24 @@ def test_model_types(
cfg_diffusion.training.rundir = rundir
cfg_diffusion.dataset.model_type = model_type
cfg_diffusion.dataset.num_scalar_cond_channels = num_scalar_cond_channels
cfg_diffusion.dataset.num_invariant_channels = num_invariant_channels

if model_type == "hybrid":
cfg_diffusion.model.diffusion_conditions = ["state", "background", "invariant"]
cfg_diffusion.model.diffusion_conditions = ["state", "background"]
elif model_type == "nowcasting":
cfg_diffusion.model.diffusion_conditions = ["state", "invariant"]
cfg_diffusion.model.diffusion_conditions = ["state"]
elif model_type == "downscaling":
cfg_diffusion.model.diffusion_conditions = ["background", "invariant"]
cfg_diffusion.model.diffusion_conditions = ["background"]
elif model_type == "unconditional":
cfg_diffusion.model.diffusion_conditions = ["invariant"]
cfg_diffusion.model.diffusion_conditions = []
else:
raise ValueError(
"Model_type must be one of ['hybrid', 'nowcasting', 'downscaling', 'unconditional']."
)

if num_invariant_channels > 0:
cfg_diffusion.model.diffusion_conditions.append("invariant")

unsupported_scalar_conds = (
num_scalar_cond_channels > 0 and net_architecture != "dit"
)
Expand Down
13 changes: 4 additions & 9 deletions examples/weather/stormcast/utils/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ def get_preconditioned_unet(
name: str,
target_channels: int,
conditional_channels: int = 0,
spatial_embedding: bool = True,
img_resolution: tuple = (512, 640),
model_type: str | None = None,
channel_mult: list = [1, 2, 2, 2, 2],
attn_resolutions: list = [],
lead_time_steps: int = 0,
lead_time_channels: int = 4,
amp_mode: bool = False,
use_apex_gn: bool = False,
**model_kwargs,
) -> EDMPrecond | StormCastUNet:
"""
Expand All @@ -54,13 +52,12 @@ def get_preconditioned_unet(
name: 'regression' or 'diffusion' to select between either model type
target_channels: The number of channels in the target
conditional_channels: The number of channels in the conditioning
spatial_embedding: whether or not to use the additive spatial embedding in the U-Net
img_resolution: resolution of the data (U-Net inputs/outputs)
model_type: the model class to use, or None to select it automatically
channel_mult: the channel multipliers for the different levels of the U-Net
attn_resolutions: resolution of internal U-Net stages to use self-attention
lead_time_steps: the number of possible lead time steps, if 0 lead time embedding will be disabled
lead_time_channels: the number of channels to use for each lead time embedding
amp_mode: whether to use automatic mixed precision
use_apex_gn: whether to use Apex GroupNorm
Returns:
EDMPrecond or StormCastUNet: a wrapped torch module net(x+n, sigma, condition, class_labels) -> x
"""
Expand All @@ -72,10 +69,8 @@ def get_preconditioned_unet(
"img_resolution": img_resolution,
"img_out_channels": target_channels,
"model_type": model_type,
"channel_mult": channel_mult,
"attn_resolutions": attn_resolutions,
"additive_pos_embed": spatial_embedding,
"amp_mode": amp_mode,
"use_apex_gn": use_apex_gn,
}
model_params.update(model_kwargs)

Expand Down
7 changes: 5 additions & 2 deletions examples/weather/stormcast/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,11 @@ def distribute_tensor(self, x: torch.Tensor) -> ShardTensor:
ShardTensor
Sharded or replicated tensor on domain mesh.
"""
source_rank = self.get_domain_group_zero_rank()
return self.nested_scatter(x, source_rank)
if self.use_shard_tensor:
source_rank = self.get_domain_group_zero_rank()
return self.nested_scatter(x, source_rank)
else:
return x

def distribute_model(self, model: torch.nn.Module) -> FSDP:
"""Shard model parameters across the domain mesh and wrap with FSDP.
Expand Down
10 changes: 8 additions & 2 deletions examples/weather/stormcast/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,13 @@ def _setup_data(self):
)
else:
self.invariant_tensor = None
if (
"invariant" in self.cfg.model.diffusion_conditions
or "invariant" in self.cfg.model.regression_conditions
):
self.logger.info(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)
Comment on lines +350 to +352
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use warning level for misconfiguration log

When invariant conditions are specified in the model config but the dataset provides no invariants, the mismatch is silently logged at info level. This is a configuration issue the user should be explicitly alerted to — the specified conditions are being silently ignored, which could produce surprising results (e.g., a diffusion model trained without invariants while the config claims it uses them). A warning would be more appropriate here.

Suggested change
self.logger.info(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)
self.logger.warning(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)

Additionally, note that self.condition_list is not updated to remove "invariant" after this point. The subsequent "Model conditions" log in _setup_model will still include "invariant" as an active condition, which contradicts the "Ignoring invariant conditions" message and may mislead users inspecting the logs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albertocarpentieri this suggestion seems sensible to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning level is not currently supported by the logger; I can add it, though.


if (
self.cfg.model.architecture != "dit"
Expand Down Expand Up @@ -394,10 +401,9 @@ def _setup_model(self) -> Module:
img_resolution=self.dataset_train.image_shape(),
target_channels=len(self.state_channels),
conditional_channels=num_condition_channels,
spatial_embedding=model_cfg.spatial_pos_embed,
attn_resolutions=model_cfg.attn_resolutions,
lead_time_steps=self.lead_time_steps,
amp_mode=self.enable_amp,
use_apex_gn=self.use_apex_gn,
**model_cfg.hyperparameters,
)
elif model_cfg.architecture == "dit":
Expand Down
Loading