
Commit 92fed52

Add ViT MAE pretrain ablation
Add a CNS-only deterministic ViT MAE pretraining preset and matching Slurm timing and 24h submitters. The run disables ensemble expansion, logs all progress checkpoints, and documents how the checkpoints can seed later CRPS fine-tuning.
1 parent b16852e commit 92fed52

5 files changed

Lines changed: 276 additions & 0 deletions

local_hydra/local_experiment/ablations/vit_mae_pretrain/conditioned_navier_stokes/vit_azula_large_mae_no_ensemble.yaml

Lines changed: 54 additions & 0 deletions
```yaml
# @package _global_
defaults:
  - /distributed: ddp_4gpu_slurm
  - override /datamodule: conditioned_navier_stokes
  - override /encoder@model.encoder: permute_concat
  - override /decoder@model.decoder: channels_last
  - override /processor@model.processor: vit_azula_large
  - override /optimizer: adamw_half
  - _self_

experiment_name: ablation_vit_mae_pretrain_conditioned_navier_stokes

datamodule:
  use_normalization: true
  # Match the CRPS ViT effective per-GPU batch:
  # baseline CRPS uses batch_size=32 x n_members=8 = 256 examples.
  batch_size: 256

float32_matmul_precision: high

logging:
  wandb:
    enabled: true

output:
  skip_test: true

optimizer:
  learning_rate: 2e-4
  warmup: 0

model:
  train_in_latent_space: false
  # n_members=1 keeps the setup on the deterministic EncoderProcessorDecoder.
  n_members: 1
  encoder:
    with_constants: true
  processor:
    hidden_dim: 568
    num_heads: 8
    n_layers: 12
    n_noise_channels: 1024
  loss_func:
    _target_: torch.nn.L1Loss
  train_metrics:
    mae:
      _target_: autocast.metrics.MAE
    rmse:
      _target_: autocast.metrics.RMSE
  val_metrics:
    mae:
      _target_: autocast.metrics.MAE
    rmse:
      _target_: autocast.metrics.RMSE
```
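The commit message notes that checkpoints from this preset can seed a later CRPS fine-tune. A hypothetical override sketch for that follow-up (not part of this commit; `n_members` and `resume_weights_only` are the knobs named in the accompanying doc, and the value `8` mirrors the CRPS baseline's `n_members=8`):

```
# Hypothetical CRPS fine-tune overrides (sketch only, not in this commit):
model:
  n_members: 8            # back to the ensemble path (EncoderProcessorDecoderEnsemble)
resume_weights_only: true # load pretrained MAE weights; fresh optimizer/scheduler
# checkpoint path override intentionally left unspecified here
```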

slurm_scripts/ablations/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -24,6 +24,7 @@ small edit.
 | fm_vs_diffusion | comparison | CNS | 1 | stub |
 | arch_unet_fno_vit | comparison | CNS | 2 | stub |
 | model_size | sweep | CNS | 2 active (+2 staged) | in progress |
+| vit_mae_pretrain | pretrain | CNS | 1 | staged |
 | cached_latent_crps | comparison | CNS | 1 (done, 2026-04-20) | stub |
 | cond_global_vs_permute | comparison | CNS | 1 (done for CRPS-ViT, 2026-04-18) | stub |
 | eval_only/ode_steps | eval-only | FM runs | 0 | stub |
```
Lines changed: 63 additions & 0 deletions
# ViT MAE pretraining

Deterministic MAE pretraining run for the CNS ambient ViT baseline. The model
keeps the CRPS ViT architecture from
`local_hydra/local_experiment/epd/conditioned_navier_stokes/crps_vit_azula_large.yaml`
but trains without the ensemble path:

- `model.n_members=1`, which instantiates `EncoderProcessorDecoder` instead of
  `EncoderProcessorDecoderEnsemble`.
- `model.loss_func=torch.nn.L1Loss`, so `train_loss` and `val_loss` are MAE in
  normalized space.
- Deterministic `MAE` and `RMSE` metrics replace the ensemble CRPS metrics.
- `datamodule.batch_size=256` preserves the CRPS baseline's effective per-GPU
  batch size (`32 x 8 = 256`) now that there is no ensemble expansion.

**Status:** staged - run timing first, then launch the 24h production script.

## Files

| file | purpose |
|---|---|
| `local_hydra/local_experiment/ablations/vit_mae_pretrain/conditioned_navier_stokes/vit_azula_large_mae_no_ensemble.yaml` | CNS deterministic MAE preset |
| `submit_vit_mae_pretrain_timing.sh` | 5-epoch timing run -> `timing.ckpt` |
| `submit_vit_mae_pretrain_large.sh` | 24h production run; keeps and W&B-logs all progress checkpoints |

## Workflow

1. Submit timing:

   ```bash
   bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_pretrain_timing.sh
   ```

2. After the timing job finishes, collect the schedule:

   ```bash
   uv run autocast time-epochs --from-checkpoint <path>/timing.ckpt -b 24
   ```

3. Paste the emitted `trainer.max_epochs` value into
   `COSINE_EPOCHS_BY_DATASET` in `submit_vit_mae_pretrain_large.sh`, or leave it
   blank and let the script derive it from the newest matching timing checkpoint.

4. Submit the 24h pretraining run:

   ```bash
   bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_pretrain_large.sh
   ```

The production script intentionally runs a dry-run first and then submits the
real job, matching the other ablation submitters.

## Checkpoints and CRPS fine-tuning

The 24h script saves local progress checkpoints every ~5% of optimizer-step
progress with `save_top_k=-1`, keeps `last.ckpt`, and sets
`logging.wandb.log_model=all` so W&B logs every checkpoint artifact emitted by
the checkpoint callbacks.

For the follow-up shortened CRPS fine-tune, start from one of these checkpoints
with `resume_weights_only=true` rather than a full-state resume. That loads the
deterministic MAE weights into the CRPS ensemble model while starting a fresh
optimizer, scheduler, and time budget.
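The production submitter names each snapshot with the Lightning-style template `snapshot-{progress_token}-{epoch:04d}-{step:08d}`. A quick `printf` mimic shows the zero-padding those format specs produce (the token value `p05` and the numbers here are made up for illustration):

```shell
# printf mimic of the checkpoint filename template
# snapshot-{progress_token}-{epoch:04d}-{step:08d}; "p05", 3, 1200 are made up.
printf 'snapshot-%s-%04d-%08d.ckpt\n' "p05" 3 1200
# -> snapshot-p05-0003-00001200.ckpt
```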
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_pretrain_large.sh

Lines changed: 110 additions & 0 deletions
```bash
#!/bin/bash

set -euo pipefail
# 24h deterministic ViT MAE pretraining on CNS.
#
# Populate COSINE_EPOCHS_BY_DATASET after running
# submit_vit_mae_pretrain_timing.sh and extracting timing.ckpt with:
#   uv run autocast time-epochs --from-checkpoint <path>/timing.ckpt -b 24
# If left blank, the script falls back to the newest matching timing.ckpt
# under outputs/*/timing_vit_mae_pretrain/.

declare -A EXPERIMENTS=(
  ["conditioned_navier_stokes"]="ablations/vit_mae_pretrain/conditioned_navier_stokes/vit_azula_large_mae_no_ensemble"
)

declare -A COSINE_EPOCHS_BY_DATASET=(
  # ["conditioned_navier_stokes"]=...
)

BUDGET_MAX_TIME="00:23:59:00"
TIMEOUT_MIN=1439
RUN_DRY_STATES=("true" "false")
RUN_GROUP="$(date +%Y-%m-%d)/vit_mae_pretrain"

find_timing_checkpoint() {
  local run_id="$1"

  if [[ ! -d outputs ]]; then
    return 0
  fi

  find outputs -path "*/timing_vit_mae_pretrain/${run_id}/timing.ckpt" | sort | tail -n 1
}

derive_cosine_epochs_from_timing() {
  local timing_ckpt="$1"
  local result

  result="$(
    uv run autocast time-epochs --from-checkpoint "${timing_ckpt}" -b 24
  )"

  sed -n 's/.*trainer.max_epochs=\([0-9][0-9]*\).*/\1/p' <<< "${result}" | tail -n 1
}

resolve_cosine_epochs() {
  local datamodule="$1"
  local cached="${COSINE_EPOCHS_BY_DATASET[$datamodule]:-}"

  if [[ -n "${cached}" ]]; then
    printf '%s\n' "${cached}"
    return 0
  fi

  local run_id="vit_mae_pretrain_${datamodule}"
  local timing_ckpt
  timing_ckpt="$(find_timing_checkpoint "${run_id}")"

  if [[ -z "${timing_ckpt}" ]]; then
    return 1
  fi

  derive_cosine_epochs_from_timing "${timing_ckpt}"
}

for datamodule in "${!EXPERIMENTS[@]}"; do
  experiment="${EXPERIMENTS[$datamodule]}"
  if ! cosine_epochs="$(resolve_cosine_epochs "${datamodule}")"; then
    echo "Skipping ${datamodule}: no timing-derived cosine_epochs available" >&2
    continue
  fi
  if [[ -z "${cosine_epochs}" ]]; then
    echo "Skipping ${datamodule}: could not parse trainer.max_epochs from timing output" >&2
    continue
  fi

  wandb_name="vit_mae_pretrain_no_ensemble"

  for run_dry in "${RUN_DRY_STATES[@]}"; do
    dry_run_arg=()
    run_label="slurm"
    if [[ "${run_dry}" == "true" ]]; then
      dry_run_arg=(--dry-run)
      run_label="slurm --dry-run"
    fi

    echo "Submitting ViT MAE pretraining"
    echo "  mode: ${run_label}"
    echo "  datamodule: ${datamodule}"
    echo "  local_experiment: ${experiment}"
    echo "  cosine_epochs: ${cosine_epochs}"
    echo "  wandb.name: ${wandb_name}"

    uv run autocast epd --mode slurm "${dry_run_arg[@]}" \
      --run-group "${RUN_GROUP}" \
      datamodule="${datamodule}" \
      local_experiment="${experiment}" \
      logging.wandb.enabled=true \
      logging.wandb.name="${wandb_name}" \
      logging.wandb.log_model=all \
      optimizer.cosine_epochs="${cosine_epochs}" \
      hydra.launcher.timeout_min="${TIMEOUT_MIN}" \
      trainer.max_time="${BUDGET_MAX_TIME}" \
      +trainer.max_epochs="${cosine_epochs}" \
      trainer.callbacks.0.every_n_train_steps_fraction=0.05 \
      trainer.callbacks.0.every_n_epochs=0 \
      trainer.callbacks.0.save_top_k=-1 \
      trainer.callbacks.0.filename=\"snapshot-{progress_token}-{epoch:04d}-{step:08d}\"
  done
done
```
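The `derive_cosine_epochs_from_timing` helper scrapes `trainer.max_epochs=<N>` out of the `time-epochs` output with `sed`. The extraction can be checked standalone against a made-up output line (only the `trainer.max_epochs=` token is load-bearing; the surrounding text is invented):

```shell
# Made-up sample of time-epochs output; only trainer.max_epochs= matters.
result='Timing summary: 5 epochs in 0.9h. Suggested: trainer.max_epochs=133 for a 24h budget.'
# Same sed used by the submitter: capture the trailing integer, keep the last match.
epochs="$(printf '%s\n' "${result}" | sed -n 's/.*trainer.max_epochs=\([0-9][0-9]*\).*/\1/p' | tail -n 1)"
echo "${epochs}"
# -> 133
```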
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_pretrain_timing.sh

Lines changed: 48 additions & 0 deletions
```bash
#!/bin/bash

set -euo pipefail
# Timing run for deterministic ViT MAE pretraining on CNS.
#
# This starts from the CRPS ViT ambient architecture but disables the ensemble
# path (n_members=1) and trains with torch.nn.L1Loss. Run this first, then
# derive the 24h cosine schedule from timing.ckpt:
#   uv run autocast time-epochs --from-checkpoint <path>/timing.ckpt -b 24

declare -A EXPERIMENTS=(
  ["conditioned_navier_stokes"]="ablations/vit_mae_pretrain/conditioned_navier_stokes/vit_azula_large_mae_no_ensemble"
)

BUDGET_HOURS=24
NUM_TIMING_EPOCHS=5
RUN_GROUP="$(date +%Y-%m-%d)/timing_vit_mae_pretrain"

for datamodule in "${!EXPERIMENTS[@]}"; do
  experiment="${EXPERIMENTS[$datamodule]}"
  run_id="vit_mae_pretrain_${datamodule}"

  echo "Submitting ViT MAE pretrain timing run"
  echo "  datamodule: ${datamodule}"
  echo "  local_experiment: ${experiment}"
  echo "  run_id: ${run_id}"
  echo "  timing epochs: ${NUM_TIMING_EPOCHS}"
  echo "  budget: ${BUDGET_HOURS}h"
  echo "  run_group: ${RUN_GROUP}"
  echo ""

  uv run autocast time-epochs --kind epd --mode slurm \
    --run-group "${RUN_GROUP}" \
    --run-id "${run_id}" \
    -n "${NUM_TIMING_EPOCHS}" \
    -b "${BUDGET_HOURS}" \
    datamodule="${datamodule}" \
    local_experiment="${experiment}"

  echo ""
  echo "---"
  echo ""
done

echo "All ViT MAE pretrain timing jobs submitted."
echo ""
echo "Once SLURM jobs complete, collect all results with:"
echo "  for f in outputs/${RUN_GROUP}/vit_mae_pretrain_*/retrieve.sh; do bash \"\$f\"; done"
```
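The production submitter's fallback (`find ... | sort | tail -n 1`) picks the newest `timing.ckpt` by lexicographic path sort, which works because run groups are date-prefixed (`YYYY-MM-DD` sorts chronologically). A sandboxed sketch of that selection, with invented dates and a shortened run id:

```shell
# Sandbox for the find | sort | tail -n 1 newest-checkpoint selection.
# Dates and the "vit_mae_pretrain_cns" run id are invented for the demo.
tmp="$(mktemp -d)"
mkdir -p "${tmp}/outputs/2026-04-18/timing_vit_mae_pretrain/vit_mae_pretrain_cns" \
         "${tmp}/outputs/2026-04-21/timing_vit_mae_pretrain/vit_mae_pretrain_cns"
touch "${tmp}/outputs/2026-04-18/timing_vit_mae_pretrain/vit_mae_pretrain_cns/timing.ckpt" \
      "${tmp}/outputs/2026-04-21/timing_vit_mae_pretrain/vit_mae_pretrain_cns/timing.ckpt"

# Date-prefixed paths sort lexicographically == chronologically, so tail -n 1
# is the most recent timing checkpoint.
newest="$(cd "${tmp}" && find outputs -path '*/timing_vit_mae_pretrain/vit_mae_pretrain_cns/timing.ckpt' | sort | tail -n 1)"
echo "${newest}"
# -> outputs/2026-04-21/timing_vit_mae_pretrain/vit_mae_pretrain_cns/timing.ckpt
rm -rf "${tmp}"
```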
