Skip to content

Fix Stormcast example#1490

Open
albertocarpentieri wants to merge 2 commits intoNVIDIA:mainfrom
albertocarpentieri:acarpentieri/stormcast
Open

Fix Stormcast example#1490
albertocarpentieri wants to merge 2 commits intoNVIDIA:mainfrom
albertocarpentieri:acarpentieri/stormcast

Conversation

@albertocarpentieri
Copy link
Contributor

Fix Stormcast example

Description

  • Added a safety check in examples/weather/stormcast/utils/parallel.py to skip nested_scatter when use_shard_tensor=False, returning tensors unchanged instead of sharding unnecessarily.
  • Refactored UNet wiring in examples/weather/stormcast/utils/nn.py and examples/weather/stormcast/utils/trainer.py to rely on model.hyperparameters (rather than duplicated top-level fields) and to pass use_apex_gn explicitly.
  • Updated StormCast configs to align with the new hyperparameter layout
  • Expanded examples/weather/stormcast/test_training.py coverage to parametrize num_invariant_channels and ensure diffusion condition lists only include "invariant" when invariants are actually provided

Checklist

root added 2 commits March 11, 2026 09:13
Signed-off-by: root <root@pool0-01605.cm.cluster>
… into hyperparameters model section

Signed-off-by: root <root@pool0-01605.cm.cluster>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR fixes several issues in the StormCast training example. The changes consolidate model hyperparameter wiring to use model.hyperparameters (removing duplicated top-level fields like spatial_pos_embed, channel_mult, attn_resolutions), fix a crash in ParallelHelper.distribute_tensor when use_shard_tensor=False, explicitly thread use_apex_gn through the UNet construction path, align configs with the new layout, and improve test coverage by parametrizing num_invariant_channels.

Key changes:

  • utils/parallel.py: Fixes a bug where distribute_tensor called nested_scatter unconditionally — now correctly returns the tensor unchanged when use_shard_tensor=False.
  • utils/nn.py: Removes spatial_embedding, channel_mult, and attn_resolutions as explicit parameters (these are now passed via **model_kwargs from the hyperparameters config block) and adds use_apex_gn as an explicit argument.
  • utils/trainer.py: Wires use_apex_gn into get_preconditioned_unet and passes hyperparameters through cleanly. Adds a log when invariant conditions are configured but the dataset provides none — though this uses .info instead of .warning, which may cause the misconfiguration to go unnoticed.
  • test_training.py: Parametrizes num_invariant_channels and correctly gates "invariant" on the condition list based on whether invariants are actually provided.
  • Config files: diffusion.yaml moves spatial_pos_embed: True into hyperparameters.additive_pos_embed: True; stormcast.yaml removes the now-obsolete top-level spatial_pos_embed field.

Important Files Changed

Filename Overview
examples/weather/stormcast/utils/parallel.py Bug fix: distribute_tensor now correctly skips nested_scatter when use_shard_tensor=False, returning the tensor unchanged instead of attempting to shard it unnecessarily.
examples/weather/stormcast/utils/nn.py Refactored get_preconditioned_unet to remove hardcoded spatial_embedding, channel_mult, and attn_resolutions defaults (now passed via **model_kwargs from hyperparameters config) and added explicit use_apex_gn parameter. Changes look correct.
examples/weather/stormcast/utils/trainer.py Updated UNet wiring to use model_cfg.hyperparameters and pass use_apex_gn explicitly. Added a log when invariant conditions are configured but the dataset provides no invariants — however this uses .info instead of .warning, risking the misconfiguration being missed. The condition_list is also not cleaned up after the message, causing a subsequent "Model conditions" log to still list "invariant" as active.
examples/weather/stormcast/test_training.py Added num_invariant_channels parametrization to test_model_types, correctly conditionally appending "invariant" to diffusion_conditions only when invariants are present. Condition lists are now set to conservative defaults per model type with invariant added separately.
examples/weather/stormcast/datasets/mock.py Added num_invariant_channels parameter and implemented get_invariants() returning a deterministic random array when channels > 0, or None otherwise. Implementation is clean and consistent with the dataset interface.

Last reviewed commit: 558a869

Comment on lines +350 to +352
self.logger.info(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use warning level for misconfiguration log

When invariant conditions are specified in the model config but the dataset provides no invariants, the mismatch is silently logged at info level. This is a configuration issue the user should be explicitly alerted to — the specified conditions are being silently ignored, which could produce surprising results (e.g., a diffusion model trained without invariants while the config claims it uses them). A warning would be more appropriate here.

Suggested change
self.logger.info(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)
self.logger.warning(
"Invariant conditions specified in model configuration, but dataset provides no invariants. Ignoring invariant conditions."
)

Additionally, note that self.condition_list is not updated to remove "invariant" after this point. The subsequent "Model conditions" log in _setup_model will still include "invariant" as an active condition, which contradicts the "Ignoring invariant conditions" message and may mislead users inspecting the logs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant