From 9065da9ebbdc0d10e508d2925afa10bb4b29bd4d Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 30 Mar 2026 20:47:31 +0200 Subject: [PATCH 01/31] Add script courtesy of @yawenzzzz --- ...ase_band_dropout_no_s1_drop_random_time.py | 359 ++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py new file mode 100644 index 000000000..7b2ae1ea7 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -0,0 +1,359 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + masked-negatives loss. + +- Single bandset S2 (all 12 bands) / Landsat (all 11 bands) +- Random band dropout (rate ~ Uniform(0, 0.3)) on S2 and Landsat only (no S1 dropout) +- Random time with decode masking +- Masked negatives patch discrimination loss +- InfoNCE weight 0.05 +- Rank microbatch size 64 +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.internal.common 
import ( + build_common_components as build_common_components_default, +) +from olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, + main, +) +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexi_vit import ( + PoolingType, +) +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 +RANDOM_BAND_DROPOUT_MAX_RATE = 0.2 + +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +ONLY_DECODE_MODALITIES = [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, +] + +# No S1 dropout — only apply band dropout to S2 and Landsat. 
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_time_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_masked_negatives", + "tau": 0.1, + "same_target_threshold": 0.999, + "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.05, + } + ), + token_exit_cfg={modality: 0 for modality in 
common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + 
embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + probe_lr=0.01, + epochs=50, + eval_interval=Duration.steps(4000), + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the 
visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, + ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) From 2069b72c8ec3ec2c43b6e8a45bbfee0064e45a1e Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 30 Mar 2026 20:54:50 +0200 Subject: [PATCH 02/31] Add more evals --- 
...ase_band_dropout_no_s1_drop_random_time.py | 69 ++++++++++++++++++- 1 file changed, 66 insertions(+), 3 deletions(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 7b2ae1ea7..5060f2cfd 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -31,6 +31,8 @@ from olmoearth_pretrain.data.constants import Modality from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.evals.metrics import EvalMetric from olmoearth_pretrain.internal.common import ( build_common_components as build_common_components_default, ) @@ -55,7 +57,10 @@ OlmoEarthSpeedMonitorCallback, OlmoEarthWandBCallback, ) -from olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig +from olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) from olmoearth_pretrain.train.loss import LossConfig from olmoearth_pretrain.train.masking import MaskingConfig from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( @@ -236,15 +241,22 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: num_workers=0, pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, eval_interval=Duration.steps(4000), ), "m_so2sat": DownstreamTaskConfig( dataset="m-so2sat", embedding_batch_size=128, - num_workers=8, + num_workers=4, pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, 
eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, ), "mados": DownstreamTaskConfig( dataset="mados", @@ -253,21 +265,72 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: num_workers=8, pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, probe_lr=0.01, - epochs=50, eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, ), "pastis": DownstreamTaskConfig( dataset="pastis", embedding_batch_size=32, probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MAX, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + embedding_batch_size=32, + probe_batch_size=8, num_workers=8, pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable 
but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, probe_lr=0.1, eval_interval=Duration.steps(20000), input_modalities=[Modality.SENTINEL2_L2A.name], epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, ), } trainer_config = ( From 4c9cae45a52728a6cd8d5020e8083a6815632963 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 30 Mar 2026 20:59:51 +0200 Subject: [PATCH 03/31] less random more time --- .../base_band_dropout_no_s1_drop_random_time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 5060f2cfd..447e904ed 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -131,6 +131,7 @@ def _masking_config( "type": "random_time_with_decode", "encode_ratio": 0.5, "decode_ratio": 0.5, + "random_ratio": 0.25, "only_decode_modalities": ONLY_DECODE_MODALITIES, }, tokenization_config=tokenization_config, From b66430976cddc757896b1b92306ca9b9ffc0c2db Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 30 Mar 2026 21:01:13 +0200 Subject: [PATCH 04/31] less time more random --- .../base_band_dropout_no_s1_drop_random_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 447e904ed..9ed0ea4db 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ 
b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -131,7 +131,7 @@ def _masking_config( "type": "random_time_with_decode", "encode_ratio": 0.5, "decode_ratio": 0.5, - "random_ratio": 0.25, + "random_ratio": 0.75, "only_decode_modalities": ONLY_DECODE_MODALITIES, }, tokenization_config=tokenization_config, From 06689bd3c725f9f2870472762d820737db6671d4 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 31 Mar 2026 15:47:03 +0200 Subject: [PATCH 05/31] hack --- olmoearth_pretrain/internal/full_eval_sweep.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/olmoearth_pretrain/internal/full_eval_sweep.py b/olmoearth_pretrain/internal/full_eval_sweep.py index 66e649b1f..1eaa66b99 100644 --- a/olmoearth_pretrain/internal/full_eval_sweep.py +++ b/olmoearth_pretrain/internal/full_eval_sweep.py @@ -1114,6 +1114,9 @@ def main() -> None: commands_to_run = build_commands(args, extra_cli) + logger.info(f"Running {len(commands_to_run)} commands") + logger.info("Actually only running the last 3") + commands_to_run = commands_to_run[-3:] logger.info(f"Running {len(commands_to_run)} commands") for cmd in commands_to_run: logger.info(cmd) From e499f8972b5b4cb6b1bca5711722dd4062d02560 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 31 Mar 2026 15:53:39 +0200 Subject: [PATCH 06/31] ctrl + z --- olmoearth_pretrain/internal/full_eval_sweep.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/olmoearth_pretrain/internal/full_eval_sweep.py b/olmoearth_pretrain/internal/full_eval_sweep.py index 1eaa66b99..66e649b1f 100644 --- a/olmoearth_pretrain/internal/full_eval_sweep.py +++ b/olmoearth_pretrain/internal/full_eval_sweep.py @@ -1114,9 +1114,6 @@ def main() -> None: commands_to_run = build_commands(args, extra_cli) - logger.info(f"Running {len(commands_to_run)} commands") - logger.info("Actually only running the last 3") - commands_to_run = commands_to_run[-3:] logger.info(f"Running {len(commands_to_run)} commands") for cmd in 
commands_to_run: logger.info(cmd) From 6318d6e5023e2eb38e3a5a04096e495f5b4ff38c Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 31 Mar 2026 20:50:18 +0200 Subject: [PATCH 07/31] random only --- .../base_band_dropout_no_s1_drop_random.py | 422 ++++++++++++++++++ 1 file changed, 422 insertions(+) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py new file mode 100644 index 000000000..6fdc7938a --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py @@ -0,0 +1,422 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + masked-negatives loss. + +- Single bandset S2 (all 12 bands) / Landsat (all 11 bands) +- Random band dropout (rate ~ Uniform(0, 0.3)) on S2 and Landsat only (no S1 dropout) +- Random time with decode masking +- Masked negatives patch discrimination loss +- InfoNCE weight 0.05 +- Rank microbatch size 64 +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.evals.datasets.normalize 
import NormMethod +from olmoearth_pretrain.evals.metrics import EvalMetric +from olmoearth_pretrain.internal.common import ( + build_common_components as build_common_components_default, +) +from olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, + main, +) +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexi_vit import ( + PoolingType, +) +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 +RANDOM_BAND_DROPOUT_MAX_RATE = 0.2 + +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +ONLY_DECODE_MODALITIES = [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, +] + +# No S1 dropout — only apply band dropout to S2 and Landsat. 
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_masked_negatives", + "tau": 0.1, + "same_target_threshold": 0.999, + "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.05, + } + ), + token_exit_cfg={modality: 0 for modality in 
common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + 
embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=4, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MAX, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + 
pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the visualize config for an experiment.""" + return 
OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, + ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) From eeae879a3a0e6b1d6e0c95d96bbdb2bc714c8168 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 1 Apr 2026 16:11:15 +0200 Subject: [PATCH 08/31] back to normal --- .../base_band_dropout_no_s1_drop_random_time.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 9ed0ea4db..23c5928d7 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -131,7 +131,7 @@ def _masking_config( "type": "random_time_with_decode", "encode_ratio": 0.5, "decode_ratio": 0.5, - "random_ratio": 0.75, + "random_ratio": 0.5, "only_decode_modalities": ONLY_DECODE_MODALITIES, }, tokenization_config=tokenization_config, From 87df9bf50752c88d283dc58c95408dfd4b9f4bed Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 1 Apr 2026 16:28:08 +0200 Subject: [PATCH 09/31] :facepalm: --- .../base_band_dropout_no_s1_drop_random_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 23c5928d7..198160cd9 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -278,7 +278,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: embedding_batch_size=32, probe_batch_size=8, num_workers=2, - pooling_type=PoolingType.MAX, + pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, probe_lr=0.1, eval_interval=Duration.steps(20000), From 09a130ddc16214f400c7a4796ec5f082b803bffa Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 1 Apr 2026 16:29:03 +0200 Subject: [PATCH 10/31] be consistent --- olmoearth_pretrain/internal/all_evals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmoearth_pretrain/internal/all_evals.py 
b/olmoearth_pretrain/internal/all_evals.py index 5e4c6b8d9..e2e29d554 100644 --- a/olmoearth_pretrain/internal/all_evals.py +++ b/olmoearth_pretrain/internal/all_evals.py @@ -181,7 +181,7 @@ def load_user_module(path: str) -> Any: embedding_batch_size=32, probe_batch_size=8, num_workers=2, - pooling_type=PoolingType.MAX, + pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, probe_lr=0.1, eval_interval=Duration.epochs(50), From 4375d6ec72456d11b9063dbb11732325bd4a3f74 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 1 Apr 2026 16:31:04 +0200 Subject: [PATCH 11/31] same fix --- .../base_band_dropout_no_s1_drop_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py index 6fdc7938a..76fb3575f 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random.py @@ -277,7 +277,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: embedding_batch_size=32, probe_batch_size=8, num_workers=2, - pooling_type=PoolingType.MAX, + pooling_type=PoolingType.MEAN, norm_stats_from_pretrained=True, probe_lr=0.1, eval_interval=Duration.steps(20000), From 0177de49aa1d380a7dd8f1a75be77bc18fa6ed99 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Thu, 2 Apr 2026 10:41:43 +0200 Subject: [PATCH 12/31] Update script --- scripts/estimate_token_ratios.py | 92 ++++++++++++++++++++++++++++---- 1 file changed, 83 insertions(+), 9 deletions(-) diff --git a/scripts/estimate_token_ratios.py b/scripts/estimate_token_ratios.py index a5b3c5718..4c2a29711 100644 --- a/scripts/estimate_token_ratios.py +++ b/scripts/estimate_token_ratios.py @@ -24,6 +24,7 @@ ) from olmoearth_pretrain.data.dataset import OlmoEarthSample from olmoearth_pretrain.datatypes import 
MaskedOlmoEarthSample, MaskValue +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig from olmoearth_pretrain.train.masking import MaskingConfig logger = logging.getLogger(__name__) @@ -56,6 +57,42 @@ ], } +# Single bandset tokenization configs (from single_bandset_band_dropout experiments) +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + + +def build_single_bandset_tokenization_config() -> TokenizationConfig: + """Build a TokenizationConfig with single bandset for S2 and Landsat.""" + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + @dataclass class TokenRatioResult: @@ -252,6 +289,7 @@ def estimate_token_ratios( missing_prob: float = 0.1, seed: int = 42, track_per_modality: bool = False, + tokenization_config: TokenizationConfig | None = None, ) -> list[TokenRatioResult]: """Estimate token encode/decode ratios by sampling many configurations. @@ -266,6 +304,7 @@ def estimate_token_ratios( missing_prob: Probability of a modality/timestep being missing. seed: Random seed. track_per_modality: Whether to track per-modality statistics. + tokenization_config: tokenization_config Returns: List of TokenRatioResult for each sample. 
@@ -279,7 +318,10 @@ def estimate_token_ratios( # Build the masking strategy config_copy = masking_config.copy() - masking_strategy = MaskingConfig(strategy_config=config_copy).build() + masking_strategy = MaskingConfig( + strategy_config=config_copy, + tokenization_config=tokenization_config, + ).build() for _ in tqdm(range(num_samples), desc="Sampling"): # Sample patch_size and hw_p @@ -291,9 +333,11 @@ def estimate_token_ratios( # Estimate max_t based on token budget (simplified version) # This mimics OlmoEarthSample._get_max_t_within_token_budget tokens_per_timestep = estimate_tokens_per_timestep( - training_modalities, sampled_hw_p + training_modalities, sampled_hw_p, tokenization_config + ) + static_tokens = estimate_static_tokens( + training_modalities, sampled_hw_p, tokenization_config ) - static_tokens = estimate_static_tokens(training_modalities, sampled_hw_p) available_budget = token_budget - static_tokens max_t = ( max(1, int(available_budget / tokens_per_timestep)) @@ -301,7 +345,6 @@ def estimate_token_ratios( else 12 ) max_t = min(max_t, 12) # Cap at MAX_SEQUENCE_LENGTH - # Create synthetic sample sample = create_synthetic_sample( training_modalities=training_modalities, @@ -336,32 +379,50 @@ def estimate_token_ratios( return results +def _get_num_bandsets( + modality_name: str, + modality_spec: ModalitySpec, + tokenization_config: TokenizationConfig | None, +) -> int: + """Get number of bandsets, respecting tokenization config overrides.""" + if tokenization_config is not None: + return tokenization_config.get_num_bandsets(modality_name) + return modality_spec.num_band_sets + + def estimate_tokens_per_timestep( training_modalities: list[str], sampled_hw_p: int, + tokenization_config: TokenizationConfig | None = None, ) -> int: """Estimate tokens per timestep for spatiotemporal modalities.""" tokens = 0 for modality_name in training_modalities: modality_spec = Modality.get(modality_name) if modality_spec.is_spacetime_varying: - # tokens = h_p * w_p * 
num_bandsets - tokens += sampled_hw_p * sampled_hw_p * modality_spec.num_band_sets + num_bandsets = _get_num_bandsets( + modality_name, modality_spec, tokenization_config + ) + tokens += sampled_hw_p * sampled_hw_p * num_bandsets return tokens def estimate_static_tokens( training_modalities: list[str], sampled_hw_p: int, + tokenization_config: TokenizationConfig | None = None, ) -> int: """Estimate tokens for static/space-only modalities.""" tokens = 0 for modality_name in training_modalities: modality_spec = Modality.get(modality_name) + num_bandsets = _get_num_bandsets( + modality_name, modality_spec, tokenization_config + ) if modality_spec.is_space_only_varying: - tokens += sampled_hw_p * sampled_hw_p * modality_spec.num_band_sets + tokens += sampled_hw_p * sampled_hw_p * num_bandsets elif modality_spec.is_static_in_space_and_time: - tokens += modality_spec.num_band_sets + tokens += num_bandsets return tokens @@ -592,6 +653,11 @@ def main() -> None: action="store_true", help="Show per-modality breakdown statistics", ) + parser.add_argument( + "--single_bandset", + action="store_true", + help="Use single bandset tokenization for S2 and Landsat", + ) args = parser.parse_args() logging.basicConfig(level=logging.WARNING) @@ -614,7 +680,7 @@ def main() -> None: Modality.CDL.name, Modality.WORLDCEREAL.name, ] - elif args.masking_type == "random_with_decode": + elif args.masking_type in ("random_with_decode", "random_time_with_decode"): masking_config["only_decode_modalities"] = [ Modality.WORLDCOVER.name, Modality.SRTM.name, @@ -624,8 +690,15 @@ def main() -> None: Modality.WORLDCEREAL.name, ] + # Build tokenization config + tokenization_config = None + if args.single_bandset: + tokenization_config = build_single_bandset_tokenization_config() + print(f"Running with masking config: {masking_config}") print(f"Training modalities: {DEFAULT_TRAINING_MODALITIES}") + if tokenization_config is not None: + print(f"Tokenization overrides: 
{list(tokenization_config.overrides.keys())}") print(f"Sampling {args.num_samples} configurations...") results = estimate_token_ratios( @@ -639,6 +712,7 @@ def main() -> None: missing_prob=args.missing_prob, seed=args.seed, track_per_modality=args.per_modality, + tokenization_config=tokenization_config, ) print_statistics(results, show_per_modality=args.per_modality) From 4bdaeb598ff2b9742827f6a3a23541b75c20a22d Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 3 Apr 2026 08:17:54 +0200 Subject: [PATCH 13/31] oil spill eval is massive leading to OOM - lets skip it --- olmoearth_pretrain/internal/all_evals.py | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/olmoearth_pretrain/internal/all_evals.py b/olmoearth_pretrain/internal/all_evals.py index e2e29d554..6168d427f 100644 --- a/olmoearth_pretrain/internal/all_evals.py +++ b/olmoearth_pretrain/internal/all_evals.py @@ -384,20 +384,20 @@ def load_user_module(path: str) -> Any: epochs=50, eval_mode=EvalMode.LINEAR_PROBE, ), - "oil_spill_detection": DownstreamTaskConfig( - dataset="oil_spill_detection", - embedding_batch_size=128, - probe_batch_size=8, - num_workers=8, - pooling_type=PoolingType.MEAN, - norm_stats_from_pretrained=True, - norm_method=NormMethod.NORM_NO_CLIP_2_STD, - probe_lr=0.01, - eval_interval=Duration.epochs(10), - input_modalities=[Modality.SENTINEL1.name], - epochs=50, - eval_mode=EvalMode.LINEAR_PROBE, - ), + # "oil_spill_detection": DownstreamTaskConfig( + # dataset="oil_spill_detection", + # embedding_batch_size=128, + # probe_batch_size=8, + # num_workers=8, + # pooling_type=PoolingType.MEAN, + # norm_stats_from_pretrained=True, + # norm_method=NormMethod.NORM_NO_CLIP_2_STD, + # probe_lr=0.01, + # eval_interval=Duration.epochs(10), + # input_modalities=[Modality.SENTINEL1.name], + # epochs=50, + # eval_mode=EvalMode.LINEAR_PROBE, + # ), } FT_EVAL_TASKS = { From 186ad56603bbd21ec50b70aeeb7b50c1cda90984 Mon Sep 17 00:00:00 2001 From: Gabriel 
Tseng Date: Fri, 3 Apr 2026 08:18:30 +0200 Subject: [PATCH 14/31] Add comment explaining --- olmoearth_pretrain/internal/all_evals.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/olmoearth_pretrain/internal/all_evals.py b/olmoearth_pretrain/internal/all_evals.py index 6168d427f..10e6182b4 100644 --- a/olmoearth_pretrain/internal/all_evals.py +++ b/olmoearth_pretrain/internal/all_evals.py @@ -384,6 +384,8 @@ def load_user_module(path: str) -> Any: epochs=50, eval_mode=EvalMode.LINEAR_PROBE, ), + # this eval is very large and can lead to + # OOM errors. Skipping for now. # "oil_spill_detection": DownstreamTaskConfig( # dataset="oil_spill_detection", # embedding_batch_size=128, From e4190e592b02f21ef33ca195f3c03b9d9b86f9d3 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 7 Apr 2026 15:43:27 +0100 Subject: [PATCH 15/31] Add single bandset models to flop calculations --- scripts/tools/20251111_flops.py | 81 +++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/scripts/tools/20251111_flops.py b/scripts/tools/20251111_flops.py index 7cf573d2c..20655b075 100644 --- a/scripts/tools/20251111_flops.py +++ b/scripts/tools/20251111_flops.py @@ -23,8 +23,41 @@ from olmoearth_pretrain.evals.models.dinov3.dinov3 import DINOv3, DinoV3Models from olmoearth_pretrain.nn.flexi_vit import Encoder from olmoearth_pretrain.nn.pooling import PoolingType +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +SINGLE_BANDSET_CONFIG = TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": 
LANDSAT_SINGLE_BANDSET, + } +) + def count_params(model: torch.nn.Module, trainable_only: bool = True): """count_params.""" @@ -165,6 +198,54 @@ def flops_per_model(model, samples: list[MaskedOlmoEarthSample, int, bool]) -> f supported_modalities=modalities, max_sequence_length=24, ), + Encoder( # large encoder, single bandset + embedding_size=1024, + min_patch_size=1, + max_patch_size=8, + num_heads=16, + mlp_ratio=4, + depth=24, + drop_path=0.1, + supported_modalities=modalities, + max_sequence_length=24, + tokenization_config=SINGLE_BANDSET_CONFIG, + ), + Encoder( # base encoder, single bandset + embedding_size=768, + min_patch_size=1, + max_patch_size=8, + num_heads=12, + mlp_ratio=4, + depth=12, + drop_path=0.1, + supported_modalities=modalities, + max_sequence_length=24, + tokenization_config=SINGLE_BANDSET_CONFIG, + ), + Encoder( # tiny encoder, single bandset + embedding_size=192, + min_patch_size=1, + max_patch_size=8, + num_heads=3, + mlp_ratio=4, + depth=12, + drop_path=0.1, + supported_modalities=modalities, + max_sequence_length=24, + tokenization_config=SINGLE_BANDSET_CONFIG, + ), + Encoder( # nano encoder, single bandset + embedding_size=128, + min_patch_size=1, + max_patch_size=8, + num_heads=8, + mlp_ratio=4, + depth=4, + drop_path=0.1, + supported_modalities=modalities, + max_sequence_length=24, + tokenization_config=SINGLE_BANDSET_CONFIG, + ), Terramind("base"), Terramind("large"), # for the models below, the paths will From 23a1a8e47c73764d8193c0dc7f1a5cbc0bef7b3a Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 13 Apr 2026 15:28:29 +0100 Subject: [PATCH 16/31] Add loss ablation --- ...ropout_no_s1_drop_random_time_base_loss.py | 421 ++++++++++++++++++ 1 file changed, 421 insertions(+) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py 
b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py new file mode 100644 index 000000000..8f23ee1b8 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py @@ -0,0 +1,421 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + base loss. + +- Single bandset S2 (all 12 bands) / Landsat (all 11 bands) +- Random band dropout (rate ~ Uniform(0, 0.3)) on S2 and Landsat only (no S1 dropout) +- Random time with decode masking +- Base patch discrimination loss (modality_patch_discrimination_new) — same as scripts/official/script.py +- InfoNCE weight 0.1 +- Rank microbatch size 64 +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.evals.metrics import EvalMetric +from olmoearth_pretrain.internal.common import ( + build_common_components as build_common_components_default, +) +from olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, + main, +) +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexi_vit import ( 
+ PoolingType, +) +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 +RANDOM_BAND_DROPOUT_MAX_RATE = 0.2 + +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +ONLY_DECODE_MODALITIES = [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, +] + +# No S1 dropout — only apply band dropout to S2 and Landsat. 
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_time_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "random_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_new", + "tau": 0.1, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.1, + } + ), + token_exit_cfg={modality: 0 for modality in common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + 
ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + 
norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=4, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, 
+ probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, 
+ ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) From 3d8ee10e9088723ba50cfe30d099375f5ddaabc7 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 13 Apr 2026 15:59:34 +0100 Subject: [PATCH 17/31] Add ablation --- ...dropout_no_s1_drop_random_time_base_loss.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git 
a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py index 8f23ee1b8..5929fd069 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py @@ -1,10 +1,10 @@ -"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + base loss. +"""Base script for single bandset + random band dropout (no S1) + modality_cross_random masking + masked-negatives loss. - Single bandset S2 (all 12 bands) / Landsat (all 11 bands) - Random band dropout (rate ~ Uniform(0, 0.3)) on S2 and Landsat only (no S1 dropout) -- Random time with decode masking -- Base patch discrimination loss (modality_patch_discrimination_new) — same as scripts/official/script.py -- InfoNCE weight 0.1 +- modality_cross_random masking (same as scripts/official/script.py) +- Masked negatives patch discrimination loss (same as base_band_dropout_no_s1_drop_random.py) +- InfoNCE weight 0.05 - Rank microbatch size 64 """ @@ -128,10 +128,10 @@ def _masking_config( ) -> MaskingConfig: return MaskingConfig( strategy_config={ - "type": "random_time_with_decode", + "type": "modality_cross_random", "encode_ratio": 0.5, "decode_ratio": 0.5, - "random_ratio": 0.5, + "allow_encoding_decoding_same_bandset": True, "only_decode_modalities": ONLY_DECODE_MODALITIES, }, tokenization_config=tokenization_config, @@ -168,14 +168,16 @@ def build_train_module_config( masking_config=_masking_config(common.tokenization_config), loss_config=LossConfig( loss_config={ - "type": "modality_patch_discrimination_new", + "type": "modality_patch_discrimination_masked_negatives", "tau": 0.1, + "same_target_threshold": 0.999, + "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, } ), contrastive_config=LossConfig( 
loss_config={ "type": "InfoNCE", - "weight": 0.1, + "weight": 0.05, } ), token_exit_cfg={modality: 0 for modality in common.training_modalities}, From 3e2d8fdb6ab652f426d6ead257e8e44c181035ff Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Mon, 13 Apr 2026 16:01:53 +0100 Subject: [PATCH 18/31] rename --- ..._base_loss.py => base_band_dropout_no_s1_drop_base_masking.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/vnext/single_bandset_band_dropout/{base_band_dropout_no_s1_drop_random_time_base_loss.py => base_band_dropout_no_s1_drop_base_masking.py} (100%) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_masking.py similarity index 100% rename from scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_base_loss.py rename to scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_masking.py From cafc95a6da5ff137166f3c787ed85078ca59719a Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 21 Apr 2026 07:59:14 +0200 Subject: [PATCH 19/31] update script defaults, add base sweeps --- scripts/estimate_token_ratios.py | 2 +- scripts/vnext/single_bandset_band_dropout/base_launch.sh | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_launch.sh diff --git a/scripts/estimate_token_ratios.py b/scripts/estimate_token_ratios.py index 4c2a29711..d07444a18 100644 --- a/scripts/estimate_token_ratios.py +++ b/scripts/estimate_token_ratios.py @@ -633,7 +633,7 @@ def main() -> None: parser.add_argument( "--masking_type", type=str, - default="modality_cross_random", + default="random_time_with_decode", help="Type of masking strategy", ) parser.add_argument( diff --git a/scripts/vnext/single_bandset_band_dropout/base_launch.sh b/scripts/vnext/single_bandset_band_dropout/base_launch.sh new file mode 100644 
index 000000000..927f7ab28 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_launch.sh @@ -0,0 +1,4 @@ +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent From 8ef6d2058460da9bfd54ceb981db8eea3cb936a8 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 21 Apr 2026 08:09:33 +0200 Subject: [PATCH 20/31] Add lr 0.0003 --- scripts/vnext/single_bandset_band_dropout/base_launch.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/vnext/single_bandset_band_dropout/base_launch.sh b/scripts/vnext/single_bandset_band_dropout/base_launch.sh index 927f7ab28..c9d75ad54 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/base_launch.sh @@ -2,3 +2,5 @@ python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_ra python 
scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent From c9ff50b7dfc3c2a828d4e2fb949eb632cefbc16d Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Tue, 21 Apr 2026 08:47:22 +0200 Subject: [PATCH 21/31] run on 8 gpus, dont repeat the original run --- .../vnext/single_bandset_band_dropout/base_launch.sh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_launch.sh b/scripts/vnext/single_bandset_band_dropout/base_launch.sh index c9d75ad54..248f06504 100644 
--- a/scripts/vnext/single_bandset_band_dropout/base_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/base_launch.sh @@ -1,6 +1,5 @@ -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent +python 
scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent From 04dc50b393e37a58b16061214ca8005822681edb Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 08:28:55 +0200 Subject: [PATCH 22/31] nano and tiny runs --- .../vnext/single_bandset_band_dropout/nano.py | 72 +++++++++++++++++++ .../nano_launch.sh | 4 ++ .../vnext/single_bandset_band_dropout/tiny.py | 72 +++++++++++++++++++ 
.../tiny_launch.sh | 6 ++ 4 files changed, 154 insertions(+) create mode 100644 scripts/vnext/single_bandset_band_dropout/nano.py create mode 100644 scripts/vnext/single_bandset_band_dropout/nano_launch.sh create mode 100644 scripts/vnext/single_bandset_band_dropout/tiny.py create mode 100644 scripts/vnext/single_bandset_band_dropout/tiny_launch.sh diff --git a/scripts/vnext/single_bandset_band_dropout/nano.py b/scripts/vnext/single_bandset_band_dropout/nano.py new file mode 100644 index 000000000..fdde25ec0 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/nano.py @@ -0,0 +1,72 @@ +"""Trying to prototype fitting everything into olmo core.""" + +import logging + +from base_band_dropout_no_s1_drop_random_time import ( + BAND_DROPOUT_MODALITIES, + RANDOM_BAND_DROPOUT_MAX_RATE, + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_trainer_config, +) + +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["nano"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + 
decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + ) diff --git a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh new file mode 100644 index 000000000..ef169a8c2 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh @@ -0,0 +1,4 @@ +python scripts/vnext/single_bandset_band_dropout/nano.py launch tiny_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' 
--launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 diff --git a/scripts/vnext/single_bandset_band_dropout/tiny.py b/scripts/vnext/single_bandset_band_dropout/tiny.py new file mode 100644 index 000000000..89124d8c4 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/tiny.py @@ -0,0 +1,72 @@ +"""Trying to prototype fitting everything into olmo core.""" + +import logging + +from base_band_dropout_no_s1_drop_random_time import ( + BAND_DROPOUT_MODALITIES, + RANDOM_BAND_DROPOUT_MAX_RATE, + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_trainer_config, +) + +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["tiny_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + 
decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + ) diff --git a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh new file mode 100644 index 000000000..2b51dcefa --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh @@ -0,0 +1,6 @@ +python scripts/official/tiny.py launch tiny_lr0.0002_wd0.20 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.20 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/official/tiny.py launch tiny_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/official/tiny.py launch tiny_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/official/tiny.py launch tiny_lr0.0005_wd0.20 
ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.20 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/official/tiny.py launch tiny_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/official/tiny.py launch tiny_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 From 6251859c335cb3d11822e49a183735d91a53bb4d Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 08:39:49 +0200 Subject: [PATCH 23/31] update to vectorized loss --- .../base_band_dropout_no_s1_drop_random_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py index 198160cd9..df9a99848 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py @@ -168,7 +168,7 @@ def build_train_module_config( masking_config=_masking_config(common.tokenization_config), loss_config=LossConfig( loss_config={ - "type": "modality_patch_discrimination_masked_negatives", + "type": "modality_patch_discrimination_masked_negatives_vec", "tau": 0.1, "same_target_threshold": 0.999, "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, From 6e4b591a05f701b3eccafc5c560a260a739b99b7 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 08:40:53 +0200 Subject: [PATCH 24/31] update to vec loss --- 
.../vnext/single_bandset_band_dropout/base_launch.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/vnext/single_bandset_band_dropout/base_launch.sh b/scripts/vnext/single_bandset_band_dropout/base_launch.sh index 248f06504..8fe7ab950 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/base_launch.sh @@ -1,5 +1,5 @@ -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent -python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002_8gpu ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 
--launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0001_wd0.002_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.02_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent From 7d11df1bed19cb1e33be2e967f648481641d00e8 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 08:44:32 +0200 Subject: [PATCH 25/31] update run names --- .../vnext/single_bandset_band_dropout/nano_launch.sh | 8 
++++---- .../vnext/single_bandset_band_dropout/tiny_launch.sh | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh index ef169a8c2..9d8ed22a5 100644 --- a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh @@ -1,4 +1,4 @@ -python scripts/vnext/single_bandset_band_dropout/nano.py launch tiny_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/nano.py launch tiny_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/nano.py launch nano_1.1_lr0.0002_wd0.002 
ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/nano.py launch nano_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch nano_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 diff --git a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh index 2b51dcefa..8f8505905 100644 --- a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh @@ -1,6 +1,4 @@ -python scripts/official/tiny.py launch tiny_lr0.0002_wd0.20 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.20 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/official/tiny.py launch tiny_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/official/tiny.py launch tiny_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/official/tiny.py launch tiny_lr0.0005_wd0.20 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 
--train_module.optim_config.weight_decay=0.20 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/official/tiny.py launch tiny_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/official/tiny.py launch tiny_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 From 1dee7a2d9a2883ed03e1f44af3c7349844be8da4 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 08:46:38 +0200 Subject: [PATCH 26/31] fix --- 
scripts/vnext/single_bandset_band_dropout/nano_launch.sh | 6 +++--- scripts/vnext/single_bandset_band_dropout/tiny_launch.sh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh index 9d8ed22a5..c68d2d897 100644 --- a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh @@ -1,4 +1,4 @@ python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/nano.py launch nano_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/nano.py launch nano_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch nano_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python 
scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch nano_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 diff --git a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh index 8f8505905..9129ccd66 100644 --- a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh @@ -1,4 +1,4 @@ python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0002_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 -python scripts/vnext/single_bandset_band_dropout/official/tiny.py launch tiny_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 
--launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 From d42fe6bed3707c9952cb4b5c8d3390d8e7e571ec Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 18:15:02 +0200 Subject: [PATCH 27/31] add loss ablation --- ...se_band_dropout_no_s1_drop_base_orgloss.py | 421 ++++++++++++++++++ 1 file changed, 421 insertions(+) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_orgloss.py diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_orgloss.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_orgloss.py new file mode 100644 index 000000000..f4ccc7613 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_base_orgloss.py @@ -0,0 +1,421 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + original base loss. 
+ +- Single bandset S2 (all 12 bands) / Landsat (all 11 bands) +- Random band dropout (rate ~ Uniform(0, 0.2)) on S2 and Landsat only (no S1 dropout) +- Random time with decode masking +- Original modality_patch_discrimination_new loss (as in official/base.py) +- InfoNCE weight 0.05 +- Rank microbatch size 64 +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.evals.metrics import EvalMetric +from olmoearth_pretrain.internal.common import ( + build_common_components as build_common_components_default, +) +from olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, + main, +) +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexi_vit import ( + PoolingType, +) +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from 
olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 +RANDOM_BAND_DROPOUT_MAX_RATE = 0.2 + +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +ONLY_DECODE_MODALITIES = [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, +] + +# No S1 dropout — only apply band dropout to S2 and Landsat. 
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_time_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "random_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_new", + "tau": 0.1, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.05, + } + ), + token_exit_cfg={modality: 0 for modality in common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + 
ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + 
norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=4, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, 
+ probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, 
+ ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) From 34b22d2fe695991491b72151034ad99aaf0c2b63 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Wed, 22 Apr 2026 21:21:42 +0200 Subject: [PATCH 28/31] add ndvi --- ...and_dropout_no_s1_drop_random_time_ndvi.py | 429 ++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 
scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_ndvi.py diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_ndvi.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_ndvi.py new file mode 100644 index 000000000..e54093059 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_ndvi.py @@ -0,0 +1,429 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + NDVI decode-only + masked-negatives loss. + +- Single bandset S2 (all 12 bands) / Landsat (all 11 bands) +- Random band dropout (rate ~ Uniform(0, 0.2)) on S2 and Landsat only (no S1 dropout) +- Random time with decode masking +- NDVI added as a decode-only modality +- Masked negatives patch discrimination loss +- InfoNCE weight 0.05 +- Rank microbatch size 64 +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.evals.metrics import EvalMetric +from olmoearth_pretrain.internal.common import ( + build_common_components as build_common_components_default, +) +from 
olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, + main, +) +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexi_vit import ( + PoolingType, +) +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig +from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 +RANDOM_BAND_DROPOUT_MAX_RATE = 0.2 + +S2_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + [ + "B02", + "B03", + "B04", + "B08", + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12", + "B01", + "B09", + ], + ] +) + +LANDSAT_SINGLE_BANDSET = ModalityTokenization( + band_groups=[ + ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], + ] +) + +NDVI_SINGLE_BANDSET = ModalityTokenization(band_groups=[["ndvi"]]) + +ONLY_DECODE_MODALITIES = [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + Modality.NDVI.name, +] + +# No S1 dropout — only apply band dropout to S2 and Landsat. 
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + "ndvi": NDVI_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_time_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "random_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + Modality.NDVI.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_masked_negatives_vec", + "tau": 0.1, + "same_target_threshold": 0.999, + "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + 
"weight": 0.05, + } + ), + token_exit_cfg={modality: 0 for modality in common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + 
"m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=4, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + 
embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build 
the visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, + ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) From 7221d2983ddf618eebf2491075ce1f135f9011af Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 24 Apr 2026 11:30:59 +0200 Subject: [PATCH 29/31] add 0.001 lrs for nano 
and tiny --- scripts/vnext/single_bandset_band_dropout/nano_launch.sh | 2 ++ scripts/vnext/single_bandset_band_dropout/tiny_launch.sh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh index c68d2d897..b257df66a 100644 --- a/scripts/vnext/single_bandset_band_dropout/nano_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/nano_launch.sh @@ -2,3 +2,5 @@ python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.000 python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 python scripts/vnext/single_bandset_band_dropout/tiny.py launch nano_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0001_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/nano.py launch nano_1.1_lr0.0001_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 
diff --git a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh index 9129ccd66..b94f73214 100644 --- a/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/tiny_launch.sh @@ -2,3 +2,5 @@ python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.000 python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0002_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0005_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0005_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0001_wd0.02 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 +python scripts/vnext/single_bandset_band_dropout/tiny.py launch tiny_1.1_lr0.0001_wd0.002 ai2/ceres-cirrascale --train_module.optim_config.lr=0.0001 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.priority=urgent --launch.num_gpus=4 From ede47580d40a70b7773705e3835a8dc4a75cafc4 Mon Sep 17 00:00:00 2001 From: Gabriel Tseng Date: Fri, 24 Apr 2026 19:32:54 +0200 Subject: [PATCH 30/31] more base, lower lr --- 
scripts/vnext/single_bandset_band_dropout/base_launch.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/vnext/single_bandset_band_dropout/base_launch.sh b/scripts/vnext/single_bandset_band_dropout/base_launch.sh index 8fe7ab950..995499312 100644 --- a/scripts/vnext/single_bandset_band_dropout/base_launch.sh +++ b/scripts/vnext/single_bandset_band_dropout/base_launch.sh @@ -3,3 +3,5 @@ python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_ra python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0002_wd0.002_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0002 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.02_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.0003_wd0.002_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.0003 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.00005_wd0.02_8gpu_vec ai2/ceres-cirrascale --train_module.optim_config.lr=0.00005 --train_module.optim_config.weight_decay=0.02 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent +python scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time.py launch v1.1_base_lr0.00005_wd0.002_8gpu_vec ai2/ceres-cirrascale 
--train_module.optim_config.lr=0.00005 --train_module.optim_config.weight_decay=0.002 --launch.clusters='[ai2/jupiter-cirrascale-2]' --launch.num_gpus=8 --launch.priority=urgent From 9670a4ef826504be5de127d2cf8836e564c862d0 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 24 Apr 2026 21:19:52 -0400 Subject: [PATCH 31/31] train with osm_sampling + eurocrops_sampling with latest single bandset config --- .../script.py | 430 ++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 scripts/archived/2026_04_24_single_bandset_eurocrops/script.py diff --git a/scripts/archived/2026_04_24_single_bandset_eurocrops/script.py b/scripts/archived/2026_04_24_single_bandset_eurocrops/script.py new file mode 100644 index 000000000..f548a999c --- /dev/null +++ b/scripts/archived/2026_04_24_single_bandset_eurocrops/script.py @@ -0,0 +1,430 @@ +"""Single bandset + random band dropout (no S1) + random time + eurocrops. + +Based on base_band_dropout_no_s1_drop_random_time.py but trains on +osm_sampling + eurocrops datasets via concat dataset config, and adds +eurocrops as a decode-only modality. 
import logging

from olmo_core.config import DType
from olmo_core.distributed.parallel.data_parallel import (
    DataParallelConfig,
    DataParallelType,
)
from olmo_core.optim import AdamWConfig
from olmo_core.optim.scheduler import CosWithWarmup
from olmo_core.train.callbacks import (
    BeakerCallback,
    CheckpointerCallback,
    ConfigSaverCallback,
    GarbageCollectorCallback,
    GPUMemoryMonitorCallback,
)
from olmo_core.train.checkpoint import CheckpointerConfig
from olmo_core.train.common import Duration, LoadStrategy
from olmo_core.train.config import TrainerConfig

from olmoearth_pretrain.data.concat import OlmoEarthConcatDatasetConfig
from olmoearth_pretrain.data.constants import Modality
from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig
from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig
from olmoearth_pretrain.evals.datasets.normalize import NormMethod
from olmoearth_pretrain.evals.metrics import EvalMetric
from olmoearth_pretrain.internal.common import (
    build_common_components as build_common_components_default,
)
from olmoearth_pretrain.internal.experiment import (
    CommonComponents,
    OlmoEarthVisualizeConfig,
    SubCmd,
    main,
)
from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS
from olmoearth_pretrain.nn.flexi_vit import (
    PoolingType,
)
from olmoearth_pretrain.nn.flexihelios import (
    EncoderConfig,
    PredictorConfig,
)
from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig
from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig
from olmoearth_pretrain.train.callbacks import (
    DownstreamEvaluatorCallbackConfig,
    OlmoEarthSpeedMonitorCallback,
    OlmoEarthWandBCallback,
)
from olmoearth_pretrain.train.callbacks.evaluator_callback import (
    DownstreamTaskConfig,
    EvalMode,
)
from olmoearth_pretrain.train.loss import LossConfig
from olmoearth_pretrain.train.masking import MaskingConfig
from olmoearth_pretrain.train.train_module.contrastive_latentmim import (
    ContrastiveLatentMIMTrainModuleConfig,
)

logger = logging.getLogger(__name__)

# Spatial patch-size bounds (pixels per token side) sampled by the dataloader
# and accepted by the encoder.
MAX_PATCH_SIZE = 8
MIN_PATCH_SIZE = 1
# Upper bound of the random band-dropout rate applied by the encoder
# (encoder is configured with random_band_dropout=True below).
RANDOM_BAND_DROPOUT_MAX_RATE = 0.2

# Single bandset: all 12 Sentinel-2 L2A bands tokenized as ONE band group
# (instead of several smaller band groups).
S2_SINGLE_BANDSET = ModalityTokenization(
    band_groups=[
        [
            "B02",
            "B03",
            "B04",
            "B08",
            "B05",
            "B06",
            "B07",
            "B8A",
            "B11",
            "B12",
            "B01",
            "B09",
        ],
    ]
)

# Single bandset: all 11 Landsat bands tokenized as one band group.
LANDSAT_SINGLE_BANDSET = ModalityTokenization(
    band_groups=[
        ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
    ]
)

# Modalities that are only ever decoded (never encoded as inputs) by the
# masking strategy; also used to mask negatives in the discrimination loss.
# Relative to the base script this adds EUROCROPS.
ONLY_DECODE_MODALITIES = [
    Modality.WORLDCOVER.name,
    Modality.SRTM.name,
    Modality.OPENSTREETMAP_RASTER.name,
    Modality.WRI_CANOPY_HEIGHT_MAP.name,
    Modality.CDL.name,
    Modality.WORLDCEREAL.name,
    Modality.EUROCROPS.name,
]

# No S1 dropout — only apply band dropout to S2 and Landsat.
BAND_DROPOUT_MODALITIES = [
    Modality.SENTINEL2_L2A.name,
    Modality.LANDSAT.name,
]


def _tokenization_config() -> TokenizationConfig:
    """Return a tokenization config using the single-bandset S2/Landsat overrides."""
    return TokenizationConfig(
        overrides={
            "sentinel2_l2a": S2_SINGLE_BANDSET,
            "landsat": LANDSAT_SINGLE_BANDSET,
        }
    )


def _masking_config(
    tokenization_config: TokenizationConfig | None = None,
) -> MaskingConfig:
    """Build the random-time-with-decode masking config.

    Shared between the train module and the dataloader so both sides agree
    on how tokens are split into encode/decode sets.
    """
    return MaskingConfig(
        strategy_config={
            "type": "random_time_with_decode",
            "encode_ratio": 0.5,
            "decode_ratio": 0.5,
            "random_ratio": 0.5,
            "only_decode_modalities": ONLY_DECODE_MODALITIES,
        },
        tokenization_config=tokenization_config,
    )


def build_common_components(
    script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str]
) -> CommonComponents:
    """Build the common components for an experiment.

    Starts from the default common components and narrows the training
    modalities to this experiment's set (adding EUROCROPS), then installs
    the single-bandset tokenization overrides.
    """
    config = build_common_components_default(script, cmd, run_name, cluster, overrides)
    config.training_modalities = [
        Modality.SENTINEL2_L2A.name,
        Modality.SENTINEL1.name,
        Modality.LANDSAT.name,
        Modality.WORLDCOVER.name,
        Modality.SRTM.name,
        Modality.OPENSTREETMAP_RASTER.name,
        Modality.WRI_CANOPY_HEIGHT_MAP.name,
        Modality.CDL.name,
        Modality.WORLDCEREAL.name,
        Modality.EUROCROPS.name,
    ]
    config.tokenization_config = _tokenization_config()
    return config


def build_train_module_config(
    common: CommonComponents,
) -> ContrastiveLatentMIMTrainModuleConfig:
    """Build the train module config for an experiment.

    Pairs the masked-negatives patch-discrimination loss with a lightly
    weighted InfoNCE contrastive loss (weight 0.05), trained under FSDP
    with bf16 params and fp32 gradient reduction.
    """
    return ContrastiveLatentMIMTrainModuleConfig(
        optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False),
        rank_microbatch_size=64,
        masking_config=_masking_config(common.tokenization_config),
        loss_config=LossConfig(
            loss_config={
                "type": "modality_patch_discrimination_masked_negatives_vec",
                "tau": 0.1,
                "same_target_threshold": 0.999,
                # Decode-only modalities are excluded from the negative pool.
                "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES,
            }
        ),
        contrastive_config=LossConfig(
            loss_config={
                "type": "InfoNCE",
                "weight": 0.05,
            }
        ),
        # Exit depth 0 for every modality (no early token exit).
        token_exit_cfg={modality: 0 for modality in common.training_modalities},
        max_grad_norm=1.0,
        scheduler=CosWithWarmup(warmup_steps=8000),
        # (1.0, 1.0) keeps the EMA target frozen to the online weights.
        ema_decay=(1.0, 1.0),
        dp_config=DataParallelConfig(
            name=DataParallelType.fsdp,
            param_dtype=DType.bfloat16,
            reduce_dtype=DType.float32,
        ),
    )


def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig:
    """Build the dataloader config for an experiment.

    Uses the same masking config as the train module so the two masked
    views produced per sample match what the loss expects.
    """
    return OlmoEarthDataLoaderConfig(
        num_workers=16,
        global_batch_size=512,
        token_budget=2250,
        prefetch_factor=4,
        sampled_hw_p_list=list(range(1, 13)),
        min_patch_size=MIN_PATCH_SIZE,
        max_patch_size=MAX_PATCH_SIZE,
        work_dir=common.save_folder,
        seed=3622,
        num_masked_views=2,
        masking_config=_masking_config(common.tokenization_config),
    )


def build_dataset_config(common: CommonComponents) -> OlmoEarthConcatDatasetConfig:
    """Build the dataset config for an experiment.

    Concatenates the osm_sampling and eurocrops_sampling h5py datasets;
    both are restricted to this experiment's training modalities.
    """
    dataset_configs = [
        OlmoEarthDatasetConfig(
            h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_era5l_day_10_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_wri_canopy_height_map/1138828",
            training_modalities=common.training_modalities,
        ),
        OlmoEarthDatasetConfig(
            h5py_dir="/weka/dfive-default/helios/dataset/eurocrops_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_eurocrops_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_wri_canopy_height_map/646212",
            training_modalities=common.training_modalities,
        ),
    ]
    return OlmoEarthConcatDatasetConfig(dataset_configs=dataset_configs)


def build_trainer_config(common: CommonComponents) -> TrainerConfig:
    """Build the trainer config for an experiment.

    Wires up checkpointing, W&B logging, speed/GPU monitoring, and the
    downstream KNN / linear-probe evaluation suite.
    """
    MAX_DURATION = Duration.epochs(300)
    METRICS_COLLECT_INTERVAL = 10
    CANCEL_CHECK_INTERVAL = 25
    LOAD_STRATEGY = LoadStrategy.if_available
    WANDB_USERNAME = "eai-ai2"  # nosec
    WANDB_PROJECT = "2026_04_24_single_bandset_eurocrops"
    # Permanent checkpoints are kept; ephemeral ones are rotated.
    PERMANENT_SAVE_INTERVAL = 5000
    EPHEMERAL_SAVE_INTERVAL = 250
    checkpointer_config = CheckpointerConfig(work_dir=common.save_folder)
    wandb_callback = OlmoEarthWandBCallback(
        name=common.run_name,
        project=WANDB_PROJECT,
        entity=WANDB_USERNAME,
        enabled=True,
    )
    # gc_interval=1: collect every step to keep host memory bounded.
    garbage_collector_callback = GarbageCollectorCallback(gc_interval=1)
    # Downstream eval tasks; all probe only on Sentinel-2 L2A inputs.
    EVAL_TASKS = {
        "m-eurosat": DownstreamTaskConfig(
            dataset="m-eurosat",
            embedding_batch_size=128,
            num_workers=0,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            norm_method=NormMethod.NORM_NO_CLIP_2_STD,
            input_modalities=[Modality.SENTINEL2_L2A.name],
            eval_mode=EvalMode.KNN,
            primary_metric=EvalMetric.ACCURACY,
            eval_interval=Duration.steps(4000),
        ),
        "m_so2sat": DownstreamTaskConfig(
            dataset="m-so2sat",
            embedding_batch_size=128,
            num_workers=4,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            eval_interval=Duration.steps(20000),
            input_modalities=[Modality.SENTINEL2_L2A.name],
            eval_mode=EvalMode.KNN,
            primary_metric=EvalMetric.ACCURACY,
        ),
        "mados": DownstreamTaskConfig(
            dataset="mados",
            embedding_batch_size=128,
            probe_batch_size=128,
            num_workers=8,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=False,
            norm_method=NormMethod.NORM_NO_CLIP_2_STD,
            probe_lr=0.01,
            eval_interval=Duration.steps(4000),
            input_modalities=[Modality.SENTINEL2_L2A.name],
            eval_mode=EvalMode.LINEAR_PROBE,
            primary_metric=EvalMetric.MICRO_F1,
        ),
        "pastis": DownstreamTaskConfig(
            dataset="pastis",
            embedding_batch_size=32,
            probe_batch_size=8,
            num_workers=2,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            probe_lr=0.1,
            eval_interval=Duration.steps(20000),
            input_modalities=[Modality.SENTINEL2_L2A.name],
            epochs=50,
            eval_mode=EvalMode.LINEAR_PROBE,
            primary_metric=EvalMetric.MIOU,
        ),
        "yemen_crop": DownstreamTaskConfig(
            dataset="yemen_crop",
            embedding_batch_size=32,
            probe_batch_size=8,
            num_workers=2,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            norm_method=NormMethod.NORM_NO_CLIP_2_STD,
            eval_interval=Duration.steps(20000),
            probe_lr=0.001,
            input_modalities=[Modality.SENTINEL2_L2A.name],
            epochs=50,
            eval_mode=EvalMode.LINEAR_PROBE,
        ),
        "geo_ecosystem_annual_test": DownstreamTaskConfig(
            dataset="geo_ecosystem_annual_test",
            embedding_batch_size=32,
            probe_batch_size=8,
            num_workers=8,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            norm_method=NormMethod.NORM_NO_CLIP_2_STD,
            probe_lr=0.01,
            eval_interval=Duration.steps(20000),
            input_modalities=[Modality.SENTINEL2_L2A.name],
            epochs=50,
            eval_mode=EvalMode.LINEAR_PROBE,
        ),
        "canada_wildfire_sat_eval_split": DownstreamTaskConfig(
            dataset="canada_wildfire_sat_eval_split",
            embedding_batch_size=32,
            probe_batch_size=16,
            patch_size=5,
            num_workers=2,
            pooling_type=PoolingType.MEAN,
            norm_stats_from_pretrained=True,
            norm_method=NormMethod.NORM_NO_CLIP_2_STD,
            probe_lr=0.1,
            eval_interval=Duration.steps(20000),
            input_modalities=[Modality.SENTINEL2_L2A.name],
            epochs=50,
            eval_mode=EvalMode.LINEAR_PROBE,
            use_dice_loss=True,
            primary_metric=EvalMetric.CLASS_F1,
            primary_metric_class=1,
        ),
    }
    trainer_config = (
        TrainerConfig(
            work_dir=common.save_folder,
            load_strategy=LOAD_STRATEGY,
            save_folder=common.save_folder,
            cancel_check_interval=CANCEL_CHECK_INTERVAL,
            metrics_collect_interval=METRICS_COLLECT_INTERVAL,
            max_duration=MAX_DURATION,
            checkpointer=checkpointer_config,
        )
        .with_callback("wandb", wandb_callback)
        .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback())
        .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback())
        .with_callback("config_saver", ConfigSaverCallback())
        .with_callback(
            "downstream_evaluator",
            DownstreamEvaluatorCallbackConfig(
                tasks=EVAL_TASKS,
            ),
        )
        .with_callback("garbage_collector", garbage_collector_callback)
        .with_callback("beaker", BeakerCallback())
        .with_callback(
            "checkpointer",
            CheckpointerCallback(
                save_interval=PERMANENT_SAVE_INTERVAL,
                ephemeral_save_interval=EPHEMERAL_SAVE_INTERVAL,
            ),
        )
    )
    return trainer_config


def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig:
    """Build the visualize config for an experiment."""
    return OlmoEarthVisualizeConfig(
        # num_samples=None: visualize all available samples.
        num_samples=None,
        output_dir=f"{common.save_folder}/visualizations",
        std_multiplier=2.0,
    )


def build_model_config(common: CommonComponents) -> LatentMIMConfig:
    """Build the model config for an experiment.

    Base-size encoder with a shallow decoder; the encoder applies random
    band dropout (up to RANDOM_BAND_DROPOUT_MAX_RATE) to S2 and Landsat
    only — Sentinel-1 is never band-dropped.
    """
    model_size = MODEL_SIZE_ARGS["base_shallow_decoder"]

    encoder_config = EncoderConfig(
        embedding_size=model_size["encoder_embedding_size"],
        num_heads=model_size["encoder_num_heads"],
        depth=model_size["encoder_depth"],
        mlp_ratio=model_size["mlp_ratio"],
        supported_modality_names=common.training_modalities,
        max_patch_size=MAX_PATCH_SIZE,
        drop_path=0.1,
        max_sequence_length=12,
        tokenization_config=common.tokenization_config,
        band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE,
        random_band_dropout=True,
        band_dropout_modalities=BAND_DROPOUT_MODALITIES,
    )
    decoder_config = PredictorConfig(
        encoder_embedding_size=model_size["encoder_embedding_size"],
        decoder_embedding_size=model_size["decoder_embedding_size"],
        depth=model_size["decoder_depth"],
        mlp_ratio=model_size["mlp_ratio"],
        num_heads=model_size["decoder_num_heads"],
        supported_modality_names=common.training_modalities,
        max_sequence_length=12,
        tokenization_config=common.tokenization_config,
    )
    model_config = LatentMIMConfig(
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )
    return model_config


if __name__ == "__main__":
    main(
        common_components_builder=build_common_components,
        model_config_builder=build_model_config,
        train_module_config_builder=build_train_module_config,
        dataset_config_builder=build_dataset_config,
        dataloader_config_builder=build_dataloader_config,
        trainer_config_builder=build_trainer_config,
        visualize_config_builder=build_visualize_config,
    )