Commit 1bad349

Add MAE to CRPS submitters
Add timing and short fine-tune Slurm submitters for CRPS runs initialized from deterministic MAE checkpoints. The scripts use n_members=16 with batch_size=16 to preserve the effective global batch used in the comparison runs while starting a fresh optimizer and time budget.

1 parent 92fed52

3 files changed

Lines changed: 234 additions & 6 deletions

File tree

slurm_scripts/ablations/vit_mae_pretrain/README.md
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_timing.sh

slurm_scripts/ablations/vit_mae_pretrain/README.md

Lines changed: 23 additions & 6 deletions
@@ -20,8 +20,10 @@ but trains without the ensemble path:
 | file | purpose |
 |---|---|
 | `local_hydra/local_experiment/ablations/vit_mae_pretrain/conditioned_navier_stokes/vit_azula_large_mae_no_ensemble.yaml` | CNS deterministic MAE preset |
-| `submit_vit_mae_pretrain_timing.sh` | 5-epoch timing run -> `timing.ckpt` |
-| `submit_vit_mae_pretrain_large.sh` | 24h production run, keeping and W&B-logging all progress checkpoints |
+| `submit_vit_mae_pretrain_timing.sh` | 5-epoch MAE timing run -> `timing.ckpt` |
+| `submit_vit_mae_pretrain_large.sh` | 24h MAE production run, keeping and W&B-logging all progress checkpoints |
+| `submit_vit_mae_to_crps_timing.sh` | 5-epoch timing for MAE-initialized CRPS fine-tuning with `n_members=16` |
+| `submit_vit_mae_to_crps_large.sh` | short MAE-initialized CRPS fine-tune, defaulting to a 4h budget |

 ## Workflow

@@ -57,7 +59,22 @@ progress with `save_top_k=-1`, keeps `last.ckpt`, and sets
 `logging.wandb.log_model=all` so W&B logs every checkpoint artifact emitted by
 the checkpoint callbacks.

-For the follow-up shortened CRPS fine-tune, start from one of these checkpoints
-with `resume_weights_only=true` rather than full-state resume. That loads the
-deterministic MAE weights into the CRPS ensemble model while starting a fresh
-optimizer, scheduler, and time budget.
+For the follow-up shortened CRPS fine-tune, use the `vit_mae_to_crps` scripts
+and point `MAE_CHECKPOINT` at one of the MAE checkpoints:
+
+```bash
+MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \
+bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_timing.sh
+
+MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \
+bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh
+```
+
+The CRPS fine-tune uses `n_members=16` with `datamodule.batch_size=16`, keeping
+the effective global batch at `16 x 16 x 4 GPUs = 1024`. It also uses
+`resume_weights_only=true` rather than full-state resume, so the deterministic
+MAE weights initialize the CRPS ensemble model while the optimizer, scheduler,
+and time budget start fresh.
+
+The default CRPS fine-tune budget is 4h. Override it for both timing and large
+runs with, for example, `CRPS_BUDGET_HOURS=6`.
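
For instance, a hypothetical 6-hour large run (the checkpoint path below is a placeholder) sets both variables in one invocation:

```bash
CRPS_BUDGET_HOURS=6 \
MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \
bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh
```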
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh (new file)

Lines changed: 141 additions & 0 deletions
#!/bin/bash

set -euo pipefail
# Short MAE-initialized CRPS fine-tuning run on CNS.
#
# Usage:
#   MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \
#   bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh
#
# Run submit_vit_mae_to_crps_timing.sh first. If COSINE_EPOCHS_BY_DATASET is
# left blank, this script derives max_epochs from the newest matching timing.ckpt.

MAE_CHECKPOINT="${MAE_CHECKPOINT:-${1:-}}"
if [[ -z "${MAE_CHECKPOINT}" ]]; then
  echo "FATAL: set MAE_CHECKPOINT or pass the MAE checkpoint path as argv[1]" >&2
  exit 1
fi

if [[ ! -f "${MAE_CHECKPOINT}" ]]; then
  echo "FATAL: MAE_CHECKPOINT does not exist: ${MAE_CHECKPOINT}" >&2
  exit 1
fi

declare -A EXPERIMENTS=(
  ["conditioned_navier_stokes"]="epd/conditioned_navier_stokes/crps_vit_azula_large"
)

declare -A COSINE_EPOCHS_BY_DATASET=(
  # ["conditioned_navier_stokes"]=...
)

CRPS_BUDGET_HOURS="${CRPS_BUDGET_HOURS:-4}"
if ! [[ "${CRPS_BUDGET_HOURS}" =~ ^[1-9][0-9]*$ ]]; then
  echo "FATAL: CRPS_BUDGET_HOURS must be a positive integer number of hours" >&2
  exit 1
fi

BUDGET_MAX_TIME="$(printf "00:%02d:59:00" "$((CRPS_BUDGET_HOURS - 1))")"
TIMEOUT_MIN=$((CRPS_BUDGET_HOURS * 60 - 1))
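
# Both limits land one minute under the requested budget: trainer.max_time takes
# Lightning's "DD:HH:MM:SS" string (4h -> 00:03:59:00) and hydra.launcher.timeout_min
# is the Slurm limit in minutes (4h -> 239). RUN_DRY_STATES below submits each run
# twice: a --dry-run sanity pass first, then the real submission.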
RUN_DRY_STATES=("true" "false")
RUN_GROUP="$(date +%Y-%m-%d)/vit_mae_to_crps"
N_MEMBERS=16
BS_PER_GPU=16

find_timing_checkpoint() {
  local run_id="$1"

  if [[ ! -d outputs ]]; then
    return 0
  fi

  find outputs -path "*/timing_vit_mae_to_crps/${run_id}/timing.ckpt" | sort | tail -n 1
}
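
# `sort | tail -n 1` above keeps the lexicographically last match; since run
# groups are prefixed with $(date +%Y-%m-%d), that is the newest timing checkpoint.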

derive_cosine_epochs_from_timing() {
  local timing_ckpt="$1"
  local result

  result="$(
    uv run autocast time-epochs \
      --from-checkpoint "${timing_ckpt}" \
      -b "${CRPS_BUDGET_HOURS}" \
      -m 0.02
  )"

  sed -n 's/.*trainer.max_epochs=\([0-9][0-9]*\).*/\1/p' <<< "${result}" | tail -n 1
}
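
# The sed above turns an output line like "... trainer.max_epochs=37 ..." into
# "37" (an illustrative value, not a real measurement); `tail -n 1` keeps the
# last such match.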

resolve_cosine_epochs() {
  local datamodule="$1"
  local cached="${COSINE_EPOCHS_BY_DATASET[$datamodule]:-}"

  if [[ -n "${cached}" ]]; then
    printf '%s\n' "${cached}"
    return 0
  fi

  local run_id="vit_mae_to_crps_${datamodule}_m${N_MEMBERS}"
  local timing_ckpt
  timing_ckpt="$(find_timing_checkpoint "${run_id}")"

  if [[ -z "${timing_ckpt}" ]]; then
    return 1
  fi

  derive_cosine_epochs_from_timing "${timing_ckpt}"
}

for datamodule in "${!EXPERIMENTS[@]}"; do
  experiment="${EXPERIMENTS[$datamodule]}"
  if ! cosine_epochs="$(resolve_cosine_epochs "${datamodule}")"; then
    echo "Skipping ${datamodule}: no timing-derived cosine_epochs available" >&2
    continue
  fi
  if [[ -z "${cosine_epochs}" ]]; then
    echo "Skipping ${datamodule}: could not parse trainer.max_epochs from timing output" >&2
    continue
  fi

  wandb_name="vit_mae_to_crps_m${N_MEMBERS}"

  for run_dry in "${RUN_DRY_STATES[@]}"; do
    dry_run_arg=()
    run_label="slurm"
    if [[ "${run_dry}" == "true" ]]; then
      dry_run_arg=(--dry-run)
      run_label="slurm --dry-run"
    fi

    echo "Submitting MAE-initialized CRPS fine-tuning"
    echo " mode: ${run_label}"
    echo " datamodule: ${datamodule}"
    echo " local_experiment: ${experiment}"
    echo " mae checkpoint: ${MAE_CHECKPOINT}"
    echo " n_members: ${N_MEMBERS}"
    echo " bs_per_gpu: ${BS_PER_GPU}"
    echo " budget: ${CRPS_BUDGET_HOURS}h"
    echo " cosine_epochs: ${cosine_epochs}"
    echo " wandb.name: ${wandb_name}"

    uv run autocast epd --mode slurm "${dry_run_arg[@]}" \
      --run-group "${RUN_GROUP}" \
      datamodule="${datamodule}" \
      local_experiment="${experiment}" \
      model.n_members="${N_MEMBERS}" \
      datamodule.batch_size="${BS_PER_GPU}" \
      +resume_from_checkpoint="${MAE_CHECKPOINT}" \
      +resume_weights_only=true \
      logging.wandb.enabled=true \
      logging.wandb.name="${wandb_name}" \
      logging.wandb.log_model=all \
      optimizer.cosine_epochs="${cosine_epochs}" \
      hydra.launcher.timeout_min="${TIMEOUT_MIN}" \
      trainer.max_time="${BUDGET_MAX_TIME}" \
      +trainer.max_epochs="${cosine_epochs}" \
      trainer.callbacks.0.every_n_train_steps_fraction=0.05 \
      trainer.callbacks.0.every_n_epochs=0 \
      trainer.callbacks.0.save_top_k=-1 \
      trainer.callbacks.0.filename=\"snapshot-{progress_token}-{epoch:04d}-{step:08d}\"
  done
done
slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_timing.sh (new file)

Lines changed: 70 additions & 0 deletions
#!/bin/bash

set -euo pipefail
# Timing run for MAE-initialized CRPS fine-tuning on CNS.
#
# Usage:
#   MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \
#   bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_timing.sh
#
# This fine-tune uses CRPS with n_members=16 and batch_size=16/GPU, keeping
# the effective global batch at 16 x 16 x 4 GPUs = 1024.

MAE_CHECKPOINT="${MAE_CHECKPOINT:-${1:-}}"
if [[ -z "${MAE_CHECKPOINT}" ]]; then
  echo "FATAL: set MAE_CHECKPOINT or pass the MAE checkpoint path as argv[1]" >&2
  exit 1
fi

if [[ ! -f "${MAE_CHECKPOINT}" ]]; then
  echo "FATAL: MAE_CHECKPOINT does not exist: ${MAE_CHECKPOINT}" >&2
  exit 1
fi

declare -A EXPERIMENTS=(
  ["conditioned_navier_stokes"]="epd/conditioned_navier_stokes/crps_vit_azula_large"
)

CRPS_BUDGET_HOURS="${CRPS_BUDGET_HOURS:-4}"
NUM_TIMING_EPOCHS=5
RUN_GROUP="$(date +%Y-%m-%d)/timing_vit_mae_to_crps"
N_MEMBERS=16
BS_PER_GPU=16

for datamodule in "${!EXPERIMENTS[@]}"; do
  experiment="${EXPERIMENTS[$datamodule]}"
  run_id="vit_mae_to_crps_${datamodule}_m${N_MEMBERS}"

  echo "Submitting MAE-initialized CRPS timing run"
  echo " datamodule: ${datamodule}"
  echo " local_experiment: ${experiment}"
  echo " mae checkpoint: ${MAE_CHECKPOINT}"
  echo " n_members: ${N_MEMBERS}"
  echo " bs_per_gpu: ${BS_PER_GPU}"
  echo " run_id: ${run_id}"
  echo " timing epochs: ${NUM_TIMING_EPOCHS}"
  echo " budget: ${CRPS_BUDGET_HOURS}h"
  echo " run_group: ${RUN_GROUP}"
  echo ""

  uv run autocast time-epochs --kind epd --mode slurm \
    --run-group "${RUN_GROUP}" \
    --run-id "${run_id}" \
    -n "${NUM_TIMING_EPOCHS}" \
    -b "${CRPS_BUDGET_HOURS}" \
    datamodule="${datamodule}" \
    local_experiment="${experiment}" \
    model.n_members="${N_MEMBERS}" \
    datamodule.batch_size="${BS_PER_GPU}" \
    +resume_from_checkpoint="${MAE_CHECKPOINT}" \
    +resume_weights_only=true

  echo ""
  echo "---"
  echo ""
done

echo "All MAE-initialized CRPS timing jobs submitted."
echo ""
echo "Once SLURM jobs complete, collect all results with:"
echo " for f in outputs/${RUN_GROUP}/vit_mae_to_crps_*/retrieve.sh; do bash \"\$f\"; done"
