|
| 1 | +#!/bin/bash |
| 2 | + |
| 3 | +set -euo pipefail |
| 4 | +# Short MAE-initialized CRPS fine-tuning run on CNS. |
| 5 | +# |
| 6 | +# Usage: |
| 7 | +# MAE_CHECKPOINT=/path/to/mae/encoder_processor_decoder.ckpt \ |
| 8 | +# bash slurm_scripts/ablations/vit_mae_pretrain/submit_vit_mae_to_crps_large.sh |
| 9 | +# |
| 10 | +# Run submit_vit_mae_to_crps_timing.sh first. If COSINE_EPOCHS_BY_DATASET is |
| 11 | +# left blank, this script derives max_epochs from the newest matching timing.ckpt. |
| 12 | + |
# Pick the MAE checkpoint path: a non-empty MAE_CHECKPOINT environment
# variable wins, otherwise the first positional argument is used.
MAE_CHECKPOINT="${MAE_CHECKPOINT:-${1:-}}"

# Guard: a checkpoint path must have been supplied one way or the other.
[[ -n "${MAE_CHECKPOINT}" ]] || {
  echo "FATAL: set MAE_CHECKPOINT or pass the MAE checkpoint path as argv[1]" >&2
  exit 1
}

# Guard: the supplied path must point at an existing regular file.
[[ -f "${MAE_CHECKPOINT}" ]] || {
  echo "FATAL: MAE_CHECKPOINT does not exist: ${MAE_CHECKPOINT}" >&2
  exit 1
}
| 23 | + |
# Datamodule name -> local experiment config submitted for that datamodule.
declare -A EXPERIMENTS=(
  ["conditioned_navier_stokes"]="epd/conditioned_navier_stokes/crps_vit_azula_large"
)

# Optional per-datamodule epoch counts. A datamodule without an entry here
# gets its epoch count derived from a timing checkpoint instead.
declare -A COSINE_EPOCHS_BY_DATASET=(
  # ["conditioned_navier_stokes"]=...
)

# Wall-clock budget in whole hours. The pattern rejects 0, signs, decimals and
# leading zeros, so the arithmetic below never sees an octal-looking literal.
CRPS_BUDGET_HOURS="${CRPS_BUDGET_HOURS:-4}"
if ! [[ "${CRPS_BUDGET_HOURS}" =~ ^[1-9][0-9]*$ ]]; then
  echo "FATAL: CRPS_BUDGET_HOURS must be a positive integer number of hours" >&2
  exit 1
fi

#######################################
# Format "budget minus one minute" as DD:HH:MM:SS.
# Hours beyond 23 are carried into the days field so that the hours field
# always stays within two digits.
# Arguments: $1 - budget in whole hours (positive integer)
# Outputs:   the formatted duration on stdout
#######################################
format_budget_max_time() {
  local budget_hours="$1"
  local full_hours=$((budget_hours - 1))

  printf '%02d:%02d:59:00\n' "$((full_hours / 24))" "$((full_hours % 24))"
}

# Both limits equal the budget minus one minute:
# (H - 1) hours + 59 minutes == H * 60 - 1 minutes.
BUDGET_MAX_TIME="$(format_budget_max_time "${CRPS_BUDGET_HOURS}")"
TIMEOUT_MIN=$((CRPS_BUDGET_HOURS * 60 - 1))

# Each run is issued once per entry, in order: first with --dry-run, then for
# real.
RUN_DRY_STATES=("true" "false")
RUN_GROUP="$(date +%Y-%m-%d)/vit_mae_to_crps"
N_MEMBERS=16
BS_PER_GPU=16
| 44 | + |
#######################################
# Locate the timing checkpoint for a run under ./outputs.
# When several paths match, the one that sorts last in byte order is chosen.
# NOTE(review): this equals "newest" only if the output directories are named
# so that later runs sort later (e.g. date-prefixed) - confirm against the
# timing script's output layout.
# Arguments: $1 - run id (the directory that contains timing.ckpt)
# Outputs:   the chosen path on stdout; nothing when there is no match or no
#            outputs directory
# Returns:   0 when outputs/ is missing; otherwise the pipeline's status
#######################################
find_timing_checkpoint() {
  local run_id="$1"

  if [[ ! -d outputs ]]; then
    return 0
  fi

  # LC_ALL=C forces byte-wise collation so the selected path does not depend
  # on the locale of whoever runs the script.
  find outputs -path "*/timing_vit_mae_to_crps/${run_id}/timing.ckpt" \
    | LC_ALL=C sort \
    | tail -n 1
}
| 54 | + |
#######################################
# Ask `autocast time-epochs` for the epoch count that fits the budget and
# extract it from the tool's output.
# Globals:   CRPS_BUDGET_HOURS (read)
# Arguments: $1 - path to the timing checkpoint
# Outputs:   the number from the last `trainer.max_epochs=<n>` occurrence on
#            stdout; nothing when the output contains no such token
# Returns:   the tool's exit status when it fails; otherwise the status of
#            the parsing pipeline
#######################################
derive_cosine_epochs_from_timing() {
  local timing_ckpt="$1"
  local result
  local status

  # The caller runs this inside `if ! var="$(...)"`, where `set -e` is
  # suspended, so a failing tool must be reported and propagated by hand.
  # NOTE(review): `-m 0.02` is presumably a safety margin - confirm against
  # the `autocast time-epochs` CLI help.
  result="$(
    uv run autocast time-epochs \
      --from-checkpoint "${timing_ckpt}" \
      -b "${CRPS_BUDGET_HOURS}" \
      -m 0.02
  )" || {
    status=$?
    echo "autocast time-epochs failed (exit ${status}) for ${timing_ckpt}" >&2
    return "${status}"
  }

  # The dot is escaped so only the literal `trainer.max_epochs=` key matches.
  sed -n 's/.*trainer\.max_epochs=\([0-9][0-9]*\).*/\1/p' <<< "${result}" | tail -n 1
}
| 68 | + |
#######################################
# Produce the epoch count to use for a datamodule.
# A non-empty entry in COSINE_EPOCHS_BY_DATASET is used as is; otherwise the
# value is derived from that datamodule's timing checkpoint.
# Globals:   COSINE_EPOCHS_BY_DATASET, N_MEMBERS (read)
# Arguments: $1 - datamodule name
# Outputs:   the epoch count on stdout
# Returns:   1 when no timing checkpoint is found; otherwise the status of
#            the command that produced the value
#######################################
resolve_cosine_epochs() {
  local datamodule="$1"
  local preset="${COSINE_EPOCHS_BY_DATASET[$datamodule]:-}"
  local timing_run
  local timing_path

  if [[ -z "${preset}" ]]; then
    # No configured value: fall back to the timing run for this datamodule.
    timing_run="vit_mae_to_crps_${datamodule}_m${N_MEMBERS}"
    timing_path="$(find_timing_checkpoint "${timing_run}")"

    [[ -n "${timing_path}" ]] || return 1

    derive_cosine_epochs_from_timing "${timing_path}"
    return
  fi

  printf '%s\n' "${preset}"
}
| 88 | + |
#######################################
# Print a summary of one fine-tuning run and issue it through
# `uv run autocast epd --mode slurm`.
# Globals:   RUN_GROUP, N_MEMBERS, BS_PER_GPU, MAE_CHECKPOINT,
#            CRPS_BUDGET_HOURS, TIMEOUT_MIN, BUDGET_MAX_TIME (read)
# Arguments: $1 - datamodule name
#            $2 - local experiment config
#            $3 - epoch count (used for both cosine_epochs and max_epochs)
#            $4 - "true" to add --dry-run, anything else for a real run
# Outputs:   the run summary on stdout, followed by the command's own output
# Returns:   the exit status of the `uv run` command
#######################################
submit_crps_finetune() {
  local datamodule="$1"
  local experiment="$2"
  local cosine_epochs="$3"
  local run_dry="$4"
  local wandb_name="vit_mae_to_crps_m${N_MEMBERS}"
  local run_label="slurm"

  # The command is accumulated in an array that is never empty. Expanding a
  # possibly-empty array (the former optional --dry-run array) is an
  # "unbound variable" error under `set -u` on Bash older than 4.4.
  local -a cmd=(uv run autocast epd --mode slurm)
  if [[ "${run_dry}" == "true" ]]; then
    cmd+=(--dry-run)
    run_label="slurm --dry-run"
  fi

  echo "Submitting MAE-initialized CRPS fine-tuning"
  echo " mode: ${run_label}"
  echo " datamodule: ${datamodule}"
  echo " local_experiment: ${experiment}"
  echo " mae checkpoint: ${MAE_CHECKPOINT}"
  echo " n_members: ${N_MEMBERS}"
  echo " bs_per_gpu: ${BS_PER_GPU}"
  echo " budget: ${CRPS_BUDGET_HOURS}h"
  echo " cosine_epochs: ${cosine_epochs}"
  echo " wandb.name: ${wandb_name}"

  cmd+=(
    --run-group "${RUN_GROUP}"
    datamodule="${datamodule}"
    local_experiment="${experiment}"
    model.n_members="${N_MEMBERS}"
    datamodule.batch_size="${BS_PER_GPU}"
    +resume_from_checkpoint="${MAE_CHECKPOINT}"
    +resume_weights_only=true
    logging.wandb.enabled=true
    logging.wandb.name="${wandb_name}"
    logging.wandb.log_model=all
    optimizer.cosine_epochs="${cosine_epochs}"
    hydra.launcher.timeout_min="${TIMEOUT_MIN}"
    trainer.max_time="${BUDGET_MAX_TIME}"
    +trainer.max_epochs="${cosine_epochs}"
    trainer.callbacks.0.every_n_train_steps_fraction=0.05
    trainer.callbacks.0.every_n_epochs=0
    trainer.callbacks.0.save_top_k=-1
    # The double quotes are part of the argument itself, not shell quoting.
    'trainer.callbacks.0.filename="snapshot-{progress_token}-{epoch:04d}-{step:08d}"'
  )

  "${cmd[@]}"
}

# Submit every configured experiment. A datamodule is skipped (not fatal)
# when its epoch count cannot be resolved or parsed.
for datamodule in "${!EXPERIMENTS[@]}"; do
  if ! cosine_epochs="$(resolve_cosine_epochs "${datamodule}")"; then
    echo "Skipping ${datamodule}: no timing-derived cosine_epochs available" >&2
    continue
  fi
  if [[ -z "${cosine_epochs}" ]]; then
    echo "Skipping ${datamodule}: could not parse trainer.max_epochs from timing output" >&2
    continue
  fi

  # One pass per entry of RUN_DRY_STATES, in the order listed there.
  for run_dry in "${RUN_DRY_STATES[@]}"; do
    submit_crps_finetune \
      "${datamodule}" \
      "${EXPERIMENTS[$datamodule]}" \
      "${cosine_epochs}" \
      "${run_dry}"
  done
done
0 commit comments