@@ -0,0 +1,69 @@
# @package _global_
defaults:
- /distributed: ddp_4gpu_slurm
- override /datamodule: conditioned_navier_stokes
- override /encoder@model.encoder: permute_concat
- override /decoder@model.decoder: channels_last
- override /processor@model.processor: unet_azula_large
- override /optimizer: adamw_half
- _self_

experiment_name: ablation_arch_crps_unet_azula_80m_conditioned_navier_stokes

datamodule:
use_normalization: true
batch_size: 32

float32_matmul_precision: high

logging:
wandb:
enabled: true

output:
skip_test: true

optimizer:
learning_rate: 2e-4
warmup: 0

model:
train_in_latent_space: false
n_members: 8
encoder:
with_constants: true
processor:
# Both backbones at their Azula-canonical FFN ratios (UNet ffn_factor=1,
# ViT ffn_factor=4). hid_channels follows the canonical 1:2:4:8 doubling;
# base=62 lands at ~81.3M processor params, matching the 81.0M CRPS ViT
    # baseline within ~0.4%. periodic=false matches CNS Neumann BCs.
hid_channels: [62, 124, 248, 496]
hid_blocks: [3, 3, 3, 3]
norm: layer
ffn_factor: 1
dropout: 0.0
periodic: false
gradient_checkpointing: false
n_noise_channels: 1024
loss_func:
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
train_metrics:
afcrps:
_target_: autocast.metrics.ensemble.AlphaFairCRPS
afcrps_mae_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSMAETerm
afcrps_spread_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSSpreadTerm
val_metrics:
afcrps:
_target_: autocast.metrics.ensemble.AlphaFairCRPS
afcrps_mae_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSMAETerm
afcrps_spread_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSSpreadTerm
spread:
_target_: autocast.metrics.ensemble.EnsembleSpread
multicoverage:
_target_: autocast.metrics.MultiCoverage
multiwinkler:
_target_: autocast.metrics.ensemble.MultiWinkler
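A quick sanity check of the channel ladder described in the config comment above. The `azula_ladder` helper and the parameter counter are illustrative sketches, not autocast APIs:

```python
import torch.nn as nn

def azula_ladder(base: int, depth: int = 4) -> list[int]:
    # Canonical 1:2:4:8 channel doubling used for the Azula U-Net stages.
    return [base * 2**i for i in range(depth)]

def millions_of_params(module: nn.Module) -> float:
    # Generic processor-parameter count, in millions of weights.
    return sum(p.numel() for p in module.parameters()) / 1e6

# base=62 reproduces hid_channels above; per the measurement quoted in the
# config comment this lands at ~81.3M processor params vs. the 81.0M ViT
# CRPS baseline.
assert azula_ladder(62) == [62, 124, 248, 496]
```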
@@ -0,0 +1,61 @@
# @package _global_
defaults:
- /distributed: ddp_4gpu_slurm
- override /datamodule: conditioned_navier_stokes
- override /encoder@model.encoder: permute_concat
- override /decoder@model.decoder: channels_last
- override /processor@model.processor: vit_azula_large
- override /optimizer: adamw_half
- _self_

experiment_name: ablation_crps_variant_fair_conditioned_navier_stokes

datamodule:
use_normalization: true
batch_size: 32

float32_matmul_precision: high

logging:
wandb:
enabled: true

output:
skip_test: true

optimizer:
learning_rate: 2e-4
warmup: 0

model:
train_in_latent_space: false
n_members: 8
encoder:
with_constants: true
processor:
hidden_dim: 568
num_heads: 8
n_layers: 12
n_noise_channels: 1024
loss_func:
_target_: autocast.losses.ensemble.FairCRPSLoss
train_metrics:
fcrps:
_target_: autocast.metrics.ensemble.FairCRPS
fcrps_mae_term:
_target_: autocast.metrics.ensemble.FairCRPSMAETerm
fcrps_spread_term:
_target_: autocast.metrics.ensemble.FairCRPSSpreadTerm
val_metrics:
fcrps:
_target_: autocast.metrics.ensemble.FairCRPS
fcrps_mae_term:
_target_: autocast.metrics.ensemble.FairCRPSMAETerm
fcrps_spread_term:
_target_: autocast.metrics.ensemble.FairCRPSSpreadTerm
spread:
_target_: autocast.metrics.ensemble.EnsembleSpread
multicoverage:
_target_: autocast.metrics.MultiCoverage
multiwinkler:
_target_: autocast.metrics.ensemble.MultiWinkler
@@ -0,0 +1,61 @@
# @package _global_
defaults:
- /distributed: ddp_4gpu_slurm
- override /datamodule: conditioned_navier_stokes
- override /encoder@model.encoder: permute_concat
- override /decoder@model.decoder: channels_last
- override /processor@model.processor: vit_azula_large
- override /optimizer: adamw_half
- _self_

experiment_name: ablation_crps_variant_plain_conditioned_navier_stokes

datamodule:
use_normalization: true
batch_size: 32

float32_matmul_precision: high

logging:
wandb:
enabled: true

output:
skip_test: true

optimizer:
learning_rate: 2e-4
warmup: 0

model:
train_in_latent_space: false
n_members: 8
encoder:
with_constants: true
processor:
hidden_dim: 568
num_heads: 8
n_layers: 12
n_noise_channels: 1024
loss_func:
_target_: autocast.losses.ensemble.CRPSLoss
train_metrics:
crps:
_target_: autocast.metrics.ensemble.CRPS
crps_mae_term:
_target_: autocast.metrics.ensemble.CRPSMAETerm
crps_spread_term:
_target_: autocast.metrics.ensemble.CRPSSpreadTerm
val_metrics:
crps:
_target_: autocast.metrics.ensemble.CRPS
crps_mae_term:
_target_: autocast.metrics.ensemble.CRPSMAETerm
crps_spread_term:
_target_: autocast.metrics.ensemble.CRPSSpreadTerm
spread:
_target_: autocast.metrics.ensemble.EnsembleSpread
multicoverage:
_target_: autocast.metrics.MultiCoverage
multiwinkler:
_target_: autocast.metrics.ensemble.MultiWinkler
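The two configs above differ only in which CRPS estimator the loss uses, and the AlphaFair config earlier in this batch adds a third variant. Below is a minimal sample-based sketch of the plain and fair estimators; it is generic math, not autocast's `CRPSLoss`/`FairCRPSLoss` internals, and `AlphaFairCRPSLoss` presumably interpolates between the two:

```python
import torch

def crps_ensemble(members: torch.Tensor, obs: torch.Tensor,
                  fair: bool = False) -> torch.Tensor:
    # members: (m, *shape) ensemble forecast; obs: (*shape) observation.
    m = members.shape[0]
    # MAE term (cf. the *MAETerm metrics): mean |member - observation|.
    mae = (members - obs).abs().mean(dim=0)
    # Spread term (cf. the *SpreadTerm metrics): mean pairwise |x_i - x_j|.
    pairwise = (members[:, None] - members[None, :]).abs().sum(dim=(0, 1))
    # The i == j pairs are zero either way; dividing by m*(m-1) instead of
    # m*m gives the unbiased ("fair") spread estimate.
    denom = m * (m - 1) if fair else m * m
    return mae - 0.5 * pairwise / denom
```

With `n_members: 8`, the fair spread term exceeds the plain one by a factor of m/(m-1) = 8/7, which is why both configs log the MAE and spread terms as separate metrics.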
@@ -0,0 +1,44 @@
# @package _global_
defaults:
- /distributed: ddp_4gpu_slurm
- override /datamodule: conditioned_navier_stokes
- override /encoder@model.encoder: identity
- override /decoder@model.decoder: identity
- override /processor@model.processor: diffusion_vit
- override /backbone@model.processor.backbone: vit
- override /optimizer: adamw_half
- _self_

experiment_name: ablation_diffusion_vit_large_conditioned_navier_stokes

# Match the CNS FM ambient baseline as closely as possible: same identity
# conditioning path, batch size, optimizer, and ViT backbone.
datamodule:
use_normalization: true
batch_size: 256

float32_matmul_precision: high

logging:
wandb:
enabled: true

output:
skip_test: true

optimizer:
learning_rate: 1e-4
warmup: 0

model:
train_in_latent_space: true
processor:
denoiser_type: karras
sampler_steps: 50
sampler: euler
backbone:
hid_channels: 704
hid_blocks: 12
attention_heads: 8
patch_size: 4
val_metrics: []
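For orientation, `denoiser_type: karras` with `sampler: euler` and `sampler_steps: 50` points at the standard Euler integrator over a Karras noise schedule. Below is a self-contained sketch of generic EDM-style sampling; the sigma bounds and the `denoise` callable are assumptions, not autocast's actual implementation:

```python
import torch

def karras_sigmas(n: int, sigma_min: float = 0.002, sigma_max: float = 80.0,
                  rho: float = 7.0) -> torch.Tensor:
    # Karras et al. (2022) schedule: linear ramp in sigma^(1/rho) space.
    ramp = torch.linspace(0.0, 1.0, n)
    sigmas = (sigma_max ** (1 / rho)
              + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return torch.cat([sigmas, torch.zeros(1)])  # terminal sigma = 0

@torch.no_grad()
def euler_sample(denoise, shape, steps: int = 50) -> torch.Tensor:
    # denoise(x, sigma) -> denoised estimate D(x; sigma) from the
    # Karras-preconditioned network; `shape` is the sample shape.
    sigmas = karras_sigmas(steps)
    x = torch.randn(shape) * sigmas[0]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma  # probability-flow ODE slope
        x = x + (sigma_next - sigma) * d     # one Euler step
    return x
```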
@@ -0,0 +1,64 @@
# @package _global_
defaults:
- /distributed: ddp_4gpu_slurm
- override /datamodule: conditioned_navier_stokes
- override /encoder@model.encoder: permute_concat
- override /decoder@model.decoder: channels_last
- override /processor@model.processor: vit_azula_large
- override /optimizer: adamw_half
- _self_

experiment_name: ablation_noise_channels_crps_vit_256_conditioned_navier_stokes

datamodule:
use_normalization: true
batch_size: 32

float32_matmul_precision: high

logging:
wandb:
enabled: true

output:
skip_test: true

optimizer:
learning_rate: 2e-4
warmup: 0

model:
train_in_latent_space: false
n_members: 8
encoder:
with_constants: true
processor:
# With n_noise_channels=256, hidden_dim=568 drops the processor to ~53.4M
# params. Keep depth/heads fixed and use width as the single balancing knob;
# hidden_dim=704 gives ~79.9M processor params for CNS ambient shapes.
hidden_dim: 704
num_heads: 8
n_layers: 12
n_noise_channels: 256
loss_func:
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
train_metrics:
afcrps:
_target_: autocast.metrics.ensemble.AlphaFairCRPS
afcrps_mae_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSMAETerm
afcrps_spread_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSSpreadTerm
val_metrics:
afcrps:
_target_: autocast.metrics.ensemble.AlphaFairCRPS
afcrps_mae_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSMAETerm
afcrps_spread_term:
_target_: autocast.metrics.ensemble.AlphaFairCRPSSpreadTerm
spread:
_target_: autocast.metrics.ensemble.EnsembleSpread
multicoverage:
_target_: autocast.metrics.MultiCoverage
multiwinkler:
_target_: autocast.metrics.ensemble.MultiWinkler
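A back-of-envelope check of the width rebalancing described in the comment above; the ~53.4M and ~79.9M figures are the config's measured values, while the quadratic-in-width scaling is only a heuristic:

```python
# Transformer parameters grow roughly quadratically with width (attention and
# MLP weight matrices are width x width), so estimate the hidden_dim that
# restores the ~80M target after n_noise_channels=256 left ~53.4M at width 568.
base_width, base_params, target_params = 568, 53.4e6, 79.9e6
estimate = base_width * (target_params / base_params) ** 0.5
print(round(estimate))  # ~695; the config rounds up to 704 (8 heads x 88 dims)
```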
34 changes: 28 additions & 6 deletions slurm_scripts/ablations/README.md
@@ -19,14 +19,15 @@ small edit.
|---|---|---|---|---|
| ensemble_size (m=16, fixed bs=32) | sweep | CNS | 1 | ready |
| ensemble_size (m=16, fixed global eff. bs=1024) | sweep | GS / GPE / CNS / AD | 4 | timing ready |
-| noise_channels | sweep | CNS | 1+ | stub |
-| crps_variants (AlphaFair / Fair / CRPS) | comparison | CNS | 3 | stub |
-| fm_vs_diffusion | comparison | CNS | 1 | stub |
-| arch_unet_fno_vit | comparison | CNS | 2 | stub |
+| planned_cns batch | mixed | CNS | 8 | timing scripted |
+| noise_channels | sweep | CNS | 1 | config + planned |
+| crps_variants (AlphaFair / Fair / CRPS) | comparison | CNS | 2 new (+baseline) | config + planned |
+| fm_vs_diffusion | comparison | CNS | 1 | config + planned |
+| arch_unet_fno_vit | comparison | CNS | 1 U-Net (+ViT baseline) | config + planned |
| model_size | sweep | CNS | 2 active (+2 staged) | in progress |
| vit_mae_pretrain | pretrain | CNS | 1 | staged |
-| cached_latent_crps | comparison | CNS | 1 (done, 2026-04-20) | stub |
-| cond_global_vs_permute | comparison | CNS | 1 (done for CRPS-ViT, 2026-04-18) | stub |
+| cached_latent_crps | comparison | CNS | 1 (basis: 2026-04-20) | eval ready |
+| cond_global_vs_permute | comparison | CNS | 1 planned rerun (+old 2026-04-18 point) | config ready |
| eval_only/ode_steps | eval-only | FM runs | 0 | stub |
| eval_only/ema | eval-only | EMA ckpts | 0 | stub |

@@ -35,6 +36,27 @@ small edit.
ablation — no new training required, but they should be eval'd through
the same pipeline.

## Planned CNS batch

The current planned CNS batch is centralized in
`submit_planned_cns_timing.sh` and `submit_planned_cns_large.sh` so the
cross-ablation run list can be submitted consistently after timing. It covers:

| planned run | study folder | implementation |
|---|---|---|
| U-Net m=8 CRPS CNS | `arch_unet_fno_vit` | `crps_unet_azula_80m`, ~81.3M params |
| Diffusion CNS | `fm_vs_diffusion` | diffusion processor with the FM 704/12/8 ViT backbone |
| CNS m=8 fair CRPS | `crps_variants` | FairCRPS loss on the 80M CRPS ViT |
| CNS m=8 CRPS | `crps_variants` | plain CRPS loss on the 80M CRPS ViT |
| CNS ViT noise channels=256 | `noise_channels` | CRPS ViT with `n_noise_channels=256`, `hidden_dim=704` (~79.9M params) |
| CNS m=4 ViT | `ensemble_size` | canonical CRPS ViT plus `n_members=4`, `batch_size=64` |
| CNS m=8 latent CRPS | `cached_latent_crps` | 2026-04-20 cached-latent CRPS basis |
| CNS m=8 CRPS ViT global cond | `cond_global_vs_permute` | identity encoder + `include_global_cond=true` |

Use the 2026-04-24 CRPS ambient runs for current CRPS comparison numbers and
the 2026-04-20 `diff_*` cached-latent runs as the FM/diff basis. The comparison
eval scripts have those dates wired in.

## Design notes

- **Flexible by construction.** Each ablation is a self-contained
22 changes: 12 additions & 10 deletions slurm_scripts/ablations/arch_unet_fno_vit/README.md
@@ -3,7 +3,7 @@
Compare U-Net and FNO backbones against the ViT (Azula) baseline on the
CRPS ambient path.

-**Status:** stub — no scripts yet.
+**Status:** U-Net CNS config added; FNO remains unscheduled.

## Baseline

@@ -19,19 +19,21 @@ Swap `model.processor` backbone while trying to match parameter count
CRPS.
- `local_hydra/local_experiment/epd_crps_fno.yaml` — FNO + CRPS.

-Each will need per-CNS `local_experiment/ablations/arch/<arch>.yaml`
-that matches the ambient baseline's encoder/decoder/loss so only the
-backbone varies.
+The planned U-Net run uses
+`local_hydra/local_experiment/ablations/arch_unet_fno_vit/conditioned_navier_stokes/crps_unet_azula_80m.yaml`.
+It matches the ambient baseline's encoder/decoder/loss and uses the Azula U-Net
+channel ladder `[62, 124, 248, 496]` (base 62), measured at ~81.3M processor
+params for CNS ambient shapes.
+
+FNO still needs a matching per-CNS config before scheduling.

## Datasets

-CNS only for now. Table says 2 datasets × 2 non-ViT archs = 4 runs
-(CNS gives 2: U-Net and FNO).
+CNS only for now. Current planned coverage is U-Net only; FNO is held back
+until the parameter-matching decision is settled.

## Outstanding decisions

-- How to match parameter count across architectures — the comparison
-  table for the main study (see `slurm_scripts/comparison/README.md`)
-  locked ~80M for ViT variants; we need equivalent targets for U-Net
-  and FNO.
+- How to match FNO parameter count — the U-Net target is now fixed at ~81.3M
+  to match the 80M ViT variants.
- Whether FNO needs a different patch-size / token structure.