From acb0a57728e0c29a8067db1a207d60742c19bfb5 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Tue, 7 Apr 2026 04:55:55 +0000 Subject: [PATCH 1/4] Add static_temporal encoding mode Replace legacy time-index + month embeddings with a multi-frequency sinusoidal temporal encoding based on fractional year. The encoding uses geometric-spaced frequencies from ~128-year periods to daily resolution, making it slot-position invariant (same calendar date always gets the same encoding regardless of time slot index). - Add TimestampEncodingMode enum and get_static_temporal_encoding() - Thread timestamp_encoding_mode through CompositeEncodings, FlexiVitBase, Encoder, PredictorBase, EncoderConfig, PredictorConfig - Add experiment script (vnext base_band_dropout variant) - Add 8 unit tests covering shape, determinism, forward pass, and legacy compat Made-with: Cursor --- olmoearth_pretrain/nn/encodings.py | 43 ++++++++ olmoearth_pretrain/nn/flexi_vit.py | 85 ++++++++++----- .../base_band_dropout_static_temporal.py | 70 ++++++++++++ .../unit/nn/test_static_temporal_encoding.py | 103 ++++++++++++++++++ 4 files changed, 276 insertions(+), 25 deletions(-) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_static_temporal.py create mode 100644 tests/unit/nn/test_static_temporal_encoding.py diff --git a/olmoearth_pretrain/nn/encodings.py b/olmoearth_pretrain/nn/encodings.py index d152809cf..56e1b1a1c 100644 --- a/olmoearth_pretrain/nn/encodings.py +++ b/olmoearth_pretrain/nn/encodings.py @@ -7,12 +7,23 @@ - 2D sinusoidal position encoding (for spatial data) - 1D sinusoidal position encoding (for temporal data) - Month encoding (for temporal data) +- Static multi-frequency temporal encoding """ +import math +from enum import StrEnum + import numpy as np import torch +class TimestampEncodingMode(StrEnum): + """Mode for encoding temporal information.""" + + LEGACY = "legacy" + STATIC_TEMPORAL = "static_temporal" + + def get_1d_sincos_pos_encoding(pos: torch.Tensor, encoding_dim: int) -> torch.Tensor: """Get 1D sin cos position encoding for a given set of positions. @@ -119,3 +130,35 @@ def get_month_encoding_table(encoding_dim: int) -> torch.Tensor: month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1) return month_table # (M, D) + + +def get_static_temporal_encoding( + timestamps: torch.Tensor, encoding_dim: int +) -> torch.Tensor: + """Static multi-frequency sinusoidal temporal encoding. + + Converts timestamps to a fractional year and applies geometric-spaced + sinusoidal frequencies ranging from ~128-year periods to daily resolution. + + Args: + timestamps: Tensor of shape (B, T, 3) where [..., 0] is day (1-31), + [..., 1] is month (0-indexed, 0-11), [..., 2] is year. + encoding_dim: Output encoding dimension (must be even). + + Returns: + Tensor of shape (B, T, encoding_dim). 
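+
+    Example (illustrative values only):
+        >>> ts = torch.tensor([[[15, 6, 2021]]])  # day 15, month index 6, 2021
+        >>> get_static_temporal_encoding(ts, 64).shape
+        torch.Size([1, 1, 64])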
+ """ + assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}" + day = timestamps[..., 0].float() + month = timestamps[..., 1].float() + year = timestamps[..., 2].float() + + day_of_year = month * 30.4375 + day + frac_year = year + day_of_year / 365.25 - 2020.0 + + num_freqs = encoding_dim // 2 + exponents = torch.linspace(-7.0, 8.5, num_freqs, device=timestamps.device) + freqs = 2.0 * math.pi * (2.0**exponents) # (num_freqs,) + + angles = frac_year.unsqueeze(-1) * freqs # (B, T, num_freqs) + return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index e90ddd5c0..8ac7d07ed 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -24,9 +24,11 @@ ) from olmoearth_pretrain.nn.attention import Block from olmoearth_pretrain.nn.encodings import ( + TimestampEncodingMode, get_1d_sincos_pos_encoding, get_2d_sincos_pos_encoding_with_resolution, get_month_encoding_table, + get_static_temporal_encoding, ) from olmoearth_pretrain.nn.flexi_patch_embed import ( FlexiPatchEmbed, @@ -624,6 +626,7 @@ def __init__( learnable_channel_embeddings: bool = True, random_channel_embeddings: bool = False, tokenization_config: TokenizationConfig | None = None, + timestamp_encoding_mode: str = "legacy", ): """Initialize the composite encodings. @@ -635,8 +638,11 @@ def __init__( learnable_channel_embeddings: Whether to use learnable channel embeddings random_channel_embeddings: Initialize channel embeddings randomly (zeros if False) tokenization_config: Optional config for custom band groupings + timestamp_encoding_mode: "legacy" for time-index + month embeddings, + "static_temporal" for multi-frequency sinusoidal encoding """ super().__init__() + self.timestamp_encoding_mode = TimestampEncodingMode(timestamp_encoding_mode) self.embedding_size = embedding_size self.supported_modalities = supported_modalities self.supported_modality_names = [ @@ -647,22 +653,20 @@ def __init__( self.max_sequence_length = ( max_sequence_length # This max sequence length is a time dim thing ) - # TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get - - # we have 4 embeddings types (pos_in_time, pos_in_space, month, channel) so each get - # 0.25 of the dimension self.embedding_dim_per_embedding_type = int(embedding_size * 0.25) - # Position encodings for time dimension initialized to 1D sinusoidal encodings - self.pos_embed = nn.Parameter( - get_1d_sincos_pos_encoding( - torch.arange(max_sequence_length), - self.embedding_dim_per_embedding_type, - ), - requires_grad=False, - ) - # Month encodings - month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type) - self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True) + if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: + self.pos_embed = None + self.month_embed = None + else: + self.pos_embed = nn.Parameter( + get_1d_sincos_pos_encoding( + torch.arange(max_sequence_length), + self.embedding_dim_per_embedding_type, + ), + requires_grad=False, + ) + month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type) + self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True) if not learnable_channel_embeddings and not random_channel_embeddings: self.per_modality_channel_embeddings = nn.ParameterDict() for modality in self.supported_modalities: @@ -795,16 +799,27 @@ def _apply_encodings_per_modality( 
modality_embed[..., :n] += channel_embed if modality.is_multitemporal and use_temporal_encodings: - # Time position encodings - time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict) - modality_embed[..., n : n * 2] += time_embed.to(device) - - # Month encodings - assert timestamps is not None - months = timestamps[:, :, 1] - month_embed = self.month_embed(months) - month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict) - modality_embed[..., n * 2 : n * 3] += month_embed.to(device) + if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: + assert timestamps is not None + ts_embed = get_static_temporal_encoding(timestamps, 2 * n) + ts_view = repeat(ts_embed, f"b t d -> {ein_string}", **ein_dict).to( + device + ) + modality_embed[..., n : n * 3] += ts_view + else: + # Legacy: time-index position + month embeddings + assert self.pos_embed is not None + time_embed = repeat( + self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict + ) + modality_embed[..., n : n * 2] += time_embed.to(device) + + assert timestamps is not None + assert self.month_embed is not None + months = timestamps[:, :, 1] + month_embed = self.month_embed(months) + month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict) + modality_embed[..., n * 2 : n * 3] += month_embed.to(device) if modality.is_spatial: # Spatial encodings assert input_res is not None @@ -876,6 +891,7 @@ def __init__( use_flash_attn: bool = False, qk_norm: bool = False, tokenization_config: TokenizationConfig | None = None, + timestamp_encoding_mode: str = "legacy", ) -> None: """Initialize the FlexiVitBase class.""" super().__init__() @@ -915,6 +931,7 @@ def __init__( learnable_channel_embeddings, random_channel_embeddings, tokenization_config=self._base_tokenization_config, + timestamp_encoding_mode=timestamp_encoding_mode, ) self.apply(self._init_weights) @@ -1112,6 +1129,7 @@ def __init__( band_dropout_rate: float = 0.0, random_band_dropout: bool = False, band_dropout_modalities: list[str] | None = None, + timestamp_encoding_mode: str = "legacy", ): """Initialize the encoder. @@ -1145,6 +1163,7 @@ def __init__( random_band_dropout: If True, sample dropout rate from Uniform(0, band_dropout_rate). band_dropout_modalities: If provided, only apply band dropout to these modalities. If None, apply to all modalities. Default: None. + timestamp_encoding_mode: "legacy" or "static_temporal" """ self.tokenization_config = tokenization_config or TokenizationConfig() super().__init__( @@ -1160,6 +1179,7 @@ def __init__( random_channel_embeddings=random_channel_embeddings, qk_norm=qk_norm, tokenization_config=self.tokenization_config, + timestamp_encoding_mode=timestamp_encoding_mode, ) self.num_register_tokens = num_register_tokens self.has_register_tokens = num_register_tokens > 0 @@ -1680,6 +1700,7 @@ def __init__( use_flash_attn: bool = False, qk_norm: bool = False, tokenization_config: TokenizationConfig | None = None, + timestamp_encoding_mode: str = "legacy", ): """Initialize the predictor. 
@@ -1698,6 +1719,7 @@ def __init__( use_flash_attn: Whether to use flash attention qk_norm: Whether to apply normalization to Q and K in attention tokenization_config: Optional config for custom band groupings + timestamp_encoding_mode: "legacy" or "static_temporal" """ self.tokenization_config = tokenization_config or TokenizationConfig() super().__init__( @@ -1713,6 +1735,7 @@ def __init__( use_flash_attn=use_flash_attn, qk_norm=qk_norm, tokenization_config=self.tokenization_config, + timestamp_encoding_mode=timestamp_encoding_mode, ) self.learnable_channel_embeddings = learnable_channel_embeddings self.random_channel_embeddings = random_channel_embeddings @@ -2079,6 +2102,7 @@ class EncoderConfig(Config): band_dropout_rate: float = 0.0 random_band_dropout: bool = False band_dropout_modalities: list[str] | None = None + timestamp_encoding_mode: str = "legacy" def __post_init__(self) -> None: """Coerce raw dicts to TokenizationConfig for old checkpoint compatibility.""" @@ -2104,6 +2128,11 @@ def validate(self) -> None: ) if self.tokenization_config is not None: self.tokenization_config.validate() + if self.timestamp_encoding_mode not in ("legacy", "static_temporal"): + raise ValueError( + f"timestamp_encoding_mode must be 'legacy' or 'static_temporal', " + f"got '{self.timestamp_encoding_mode}'" + ) @property def supported_modalities(self) -> list[ModalitySpec]: @@ -2139,6 +2168,7 @@ class PredictorConfig(Config): use_flash_attn: bool = False qk_norm: bool = False tokenization_config: TokenizationConfig | None = None + timestamp_encoding_mode: str = "legacy" def __post_init__(self) -> None: """Coerce raw dicts to TokenizationConfig for old checkpoint compatibility.""" @@ -2155,6 +2185,11 @@ def validate(self) -> None: raise ValueError(f"Modality {modality} is not supported") if self.tokenization_config is not None: self.tokenization_config.validate() + if self.timestamp_encoding_mode not in ("legacy", "static_temporal"): + raise ValueError( + f"timestamp_encoding_mode must be 'legacy' or 'static_temporal', " + f"got '{self.timestamp_encoding_mode}'" + ) @property def supported_modalities(self) -> list[ModalitySpec]: diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_static_temporal.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_static_temporal.py new file mode 100644 index 000000000..a0d0eed61 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_static_temporal.py @@ -0,0 +1,70 @@ +"""Same as base_band_dropout.py but with static_temporal encoding. + +Replaces legacy time-index + month embeddings with multi-frequency sinusoidal +temporal encoding based on fractional year. All other settings identical. 
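+
+Concretely, the only delta is timestamp_encoding_mode="static_temporal" on both
+the EncoderConfig and the PredictorConfig built below.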
+""" + +from base_band_dropout import ( + BAND_DROPOUT_MODALITIES, + MAX_PATCH_SIZE, + RANDOM_BAND_DROPOUT_MAX_RATE, + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_trainer_config, + build_visualize_config, +) + +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexihelios import EncoderConfig, PredictorConfig +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build model config with static_temporal encoding.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + timestamp_encoding_mode="static_temporal", + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + timestamp_encoding_mode="static_temporal", + ) + return LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) diff --git a/tests/unit/nn/test_static_temporal_encoding.py b/tests/unit/nn/test_static_temporal_encoding.py new file mode 100644 index 000000000..deccba786 --- /dev/null +++ b/tests/unit/nn/test_static_temporal_encoding.py @@ -0,0 +1,103 @@ +"""Tests for static_temporal encoding mode and get_static_temporal_encoding.""" + +import torch + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.nn.encodings import get_static_temporal_encoding +from olmoearth_pretrain.nn.flexi_vit import CompositeEncodings + + +def _make_ce( + mode: str = "static_temporal", embedding_size: int = 768 +) -> CompositeEncodings: + return CompositeEncodings( + embedding_size=embedding_size, + supported_modalities=[Modality.SENTINEL2_L2A], + max_sequence_length=12, + timestamp_encoding_mode=mode, + ) + + +def _make_timestamps(b: int = 2, t: int = 4) -> torch.Tensor: + days = torch.randint(1, 28, (b, t, 1), dtype=torch.long) + months = torch.randint(0, 12, (b, t, 1), dtype=torch.long) + years = torch.randint(2018, 2023, (b, t, 1), dtype=torch.long) + return torch.cat([days, months, years], dim=-1) + + +def test_get_static_temporal_encoding_shape() -> None: + """Output shape should be (B, T, encoding_dim).""" + ts = _make_timestamps(3, 5) + enc = 
get_static_temporal_encoding(ts, 64) + assert enc.shape == (3, 5, 64) + + +def test_get_static_temporal_encoding_deterministic() -> None: + """Same input should produce same output.""" + ts = torch.tensor([[[15, 6, 2021], [1, 0, 2020]]]) + a = get_static_temporal_encoding(ts, 32) + b = get_static_temporal_encoding(ts, 32) + assert torch.allclose(a, b) + + +def test_get_static_temporal_encoding_different_dates_differ() -> None: + """Different dates should produce different encodings.""" + ts = torch.tensor([[[1, 0, 2020], [1, 6, 2021]]]) + enc = get_static_temporal_encoding(ts, 32) + assert not torch.allclose(enc[0, 0], enc[0, 1]) + + +def test_static_temporal_no_pos_embed() -> None: + """static_temporal mode should not create pos_embed or month_embed.""" + ce = _make_ce("static_temporal") + assert ce.pos_embed is None + assert ce.month_embed is None + + +def test_legacy_has_pos_embed() -> None: + """Legacy mode should still create pos_embed and month_embed.""" + ce = _make_ce("legacy") + assert ce.pos_embed is not None + assert ce.month_embed is not None + + +def test_static_temporal_forward_shape() -> None: + """Forward pass should preserve token shape.""" + ce = _make_ce("static_temporal", embedding_size=16) + B, H, W, T = 2, 2, 2, 4 + tokens = torch.randn(B, H, W, T, 3, 16) + timestamps = _make_timestamps(B, T) + out = ce.forward({"sentinel2_l2a": tokens}, timestamps, patch_size=4) + assert out["sentinel2_l2a"].shape == tokens.shape + + +def test_static_temporal_same_date_same_encoding() -> None: + """Same calendar date in different slots should get identical temporal encoding.""" + ce = _make_ce("static_temporal", embedding_size=16) + B, H, W, T = 1, 2, 2, 3 + tokens = torch.zeros(B, H, W, T, 3, 16) + ts = torch.tensor([[[15, 6, 2021]] * T]) + out = ce.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) + result = out["sentinel2_l2a"] + n = ce.embedding_dim_per_embedding_type + # All time slots should have identical temporal encoding + assert torch.allclose( + result[0, 0, 0, 0, 0, n : 3 * n], + result[0, 0, 0, 1, 0, n : 3 * n], + ) + + +def test_static_temporal_differs_from_legacy() -> None: + """static_temporal and legacy should produce different temporal embeddings.""" + ce_st = _make_ce("static_temporal", embedding_size=16) + ce_lg = _make_ce("legacy", embedding_size=16) + B, H, W, T = 1, 2, 2, 4 + tokens = torch.zeros(B, H, W, T, 3, 16) + ts = _make_timestamps(B, T) + out_st = ce_st.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) + out_lg = ce_lg.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) + n = ce_st.embedding_dim_per_embedding_type + assert not torch.allclose( + out_st["sentinel2_l2a"][..., n : 3 * n], + out_lg["sentinel2_l2a"][..., n : 3 * n], + ) From 9c07fccdb9be7924b896dedb635c0ae45c6013e8 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Tue, 7 Apr 2026 05:06:23 +0000 Subject: [PATCH 2/4] Address code review feedback - Precompute temporal frequencies as a registered buffer instead of recomputing linspace/exp on every forward pass - Use TimestampEncodingMode enum as single source of truth for config validation instead of duplicated string literals - Remove duplicate self.embedding_size assignment - Fix docstring for get_static_temporal_encoding (clarify approximate DOY) - Use fixed timestamps in test_static_temporal_differs_from_legacy - Add edge case tests: odd dim assertion, invalid mode, freqs buffer shape Made-with: Cursor --- olmoearth_pretrain/nn/encodings.py | 35 +++++++++++----- olmoearth_pretrain/nn/flexi_vit.py | 24 ++++++++--- 
.../unit/nn/test_static_temporal_encoding.py | 40 +++++++++++++++---- 3 files changed, 75 insertions(+), 24 deletions(-) diff --git a/olmoearth_pretrain/nn/encodings.py b/olmoearth_pretrain/nn/encodings.py index 56e1b1a1c..bf2a4cf81 100644 --- a/olmoearth_pretrain/nn/encodings.py +++ b/olmoearth_pretrain/nn/encodings.py @@ -132,23 +132,40 @@ def get_month_encoding_table(encoding_dim: int) -> torch.Tensor: return month_table # (M, D) +def build_static_temporal_freqs(encoding_dim: int) -> torch.Tensor: + """Precompute geometric-spaced angular frequencies for static temporal encoding. + + Args: + encoding_dim: Total encoding dimension (must be even). Uses encoding_dim/2 + distinct frequencies. + + Returns: + Tensor of shape (encoding_dim // 2,) containing angular frequencies. + """ + assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}" + num_freqs = encoding_dim // 2 + exponents = torch.linspace(-7.0, 8.5, num_freqs) + return 2.0 * math.pi * (2.0**exponents) + + def get_static_temporal_encoding( - timestamps: torch.Tensor, encoding_dim: int + timestamps: torch.Tensor, + freqs: torch.Tensor, ) -> torch.Tensor: """Static multi-frequency sinusoidal temporal encoding. - Converts timestamps to a fractional year and applies geometric-spaced + Converts timestamps to a fractional year and applies precomputed sinusoidal frequencies ranging from ~128-year periods to daily resolution. + Day-of-year is approximated as ``month * 30.4375 + day``. Args: - timestamps: Tensor of shape (B, T, 3) where [..., 0] is day (1-31), + timestamps: Tensor of shape (B, T, 3) where [..., 0] is day, [..., 1] is month (0-indexed, 0-11), [..., 2] is year. - encoding_dim: Output encoding dimension (must be even). + freqs: Precomputed angular frequencies from ``build_static_temporal_freqs``. Returns: - Tensor of shape (B, T, encoding_dim). + Tensor of shape (B, T, 2 * len(freqs)). 
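+
+    Example (illustrative; pairs with ``build_static_temporal_freqs``):
+        >>> freqs = build_static_temporal_freqs(32)
+        >>> ts = torch.tensor([[[1, 0, 2020]]])  # 1 Jan 2020; month is 0-indexed
+        >>> get_static_temporal_encoding(ts, freqs).shape
+        torch.Size([1, 1, 32])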
""" - assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}" day = timestamps[..., 0].float() month = timestamps[..., 1].float() year = timestamps[..., 2].float() @@ -156,9 +173,5 @@ def get_static_temporal_encoding( day_of_year = month * 30.4375 + day frac_year = year + day_of_year / 365.25 - 2020.0 - num_freqs = encoding_dim // 2 - exponents = torch.linspace(-7.0, 8.5, num_freqs, device=timestamps.device) - freqs = 2.0 * math.pi * (2.0**exponents) # (num_freqs,) - - angles = frac_year.unsqueeze(-1) * freqs # (B, T, num_freqs) + angles = frac_year.unsqueeze(-1) * freqs.to(timestamps.device) # (B, T, num_freqs) return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index 8ac7d07ed..5241b73ce 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -25,6 +25,7 @@ from olmoearth_pretrain.nn.attention import Block from olmoearth_pretrain.nn.encodings import ( TimestampEncodingMode, + build_static_temporal_freqs, get_1d_sincos_pos_encoding, get_2d_sincos_pos_encoding_with_resolution, get_month_encoding_table, @@ -649,7 +650,6 @@ def __init__( modality.name for modality in supported_modalities ] self.tokenization_config = tokenization_config or TokenizationConfig() - self.embedding_size = embedding_size self.max_sequence_length = ( max_sequence_length # This max sequence length is a time dim thing ) @@ -657,6 +657,10 @@ def __init__( if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: self.pos_embed = None self.month_embed = None + self.register_buffer( + "_static_temporal_freqs", + build_static_temporal_freqs(2 * self.embedding_dim_per_embedding_type), + ) else: self.pos_embed = nn.Parameter( get_1d_sincos_pos_encoding( @@ -801,7 +805,9 @@ def _apply_encodings_per_modality( if modality.is_multitemporal and use_temporal_encodings: if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: assert timestamps is not None - ts_embed = get_static_temporal_encoding(timestamps, 2 * n) + ts_embed = get_static_temporal_encoding( + timestamps, self._static_temporal_freqs + ) ts_view = repeat(ts_embed, f"b t d -> {ein_string}", **ein_dict).to( device ) @@ -2128,9 +2134,12 @@ def validate(self) -> None: ) if self.tokenization_config is not None: self.tokenization_config.validate() - if self.timestamp_encoding_mode not in ("legacy", "static_temporal"): + try: + TimestampEncodingMode(self.timestamp_encoding_mode) + except ValueError: + valid = [m.value for m in TimestampEncodingMode] raise ValueError( - f"timestamp_encoding_mode must be 'legacy' or 'static_temporal', " + f"timestamp_encoding_mode must be one of {valid}, " f"got '{self.timestamp_encoding_mode}'" ) @@ -2185,9 +2194,12 @@ def validate(self) -> None: raise ValueError(f"Modality {modality} is not supported") if self.tokenization_config is not None: self.tokenization_config.validate() - if self.timestamp_encoding_mode not in ("legacy", "static_temporal"): + try: + TimestampEncodingMode(self.timestamp_encoding_mode) + except ValueError: + valid = [m.value for m in TimestampEncodingMode] raise ValueError( - f"timestamp_encoding_mode must be 'legacy' or 'static_temporal', " + f"timestamp_encoding_mode must be one of {valid}, " f"got '{self.timestamp_encoding_mode}'" ) diff --git a/tests/unit/nn/test_static_temporal_encoding.py b/tests/unit/nn/test_static_temporal_encoding.py index deccba786..c9fdf108c 100644 --- a/tests/unit/nn/test_static_temporal_encoding.py 
+++ b/tests/unit/nn/test_static_temporal_encoding.py @@ -1,9 +1,13 @@ """Tests for static_temporal encoding mode and get_static_temporal_encoding.""" +import pytest import torch from olmoearth_pretrain.data.constants import Modality -from olmoearth_pretrain.nn.encodings import get_static_temporal_encoding +from olmoearth_pretrain.nn.encodings import ( + build_static_temporal_freqs, + get_static_temporal_encoding, +) from olmoearth_pretrain.nn.flexi_vit import CompositeEncodings @@ -28,25 +32,40 @@ def _make_timestamps(b: int = 2, t: int = 4) -> torch.Tensor: def test_get_static_temporal_encoding_shape() -> None: """Output shape should be (B, T, encoding_dim).""" ts = _make_timestamps(3, 5) - enc = get_static_temporal_encoding(ts, 64) + freqs = build_static_temporal_freqs(64) + enc = get_static_temporal_encoding(ts, freqs) assert enc.shape == (3, 5, 64) def test_get_static_temporal_encoding_deterministic() -> None: """Same input should produce same output.""" ts = torch.tensor([[[15, 6, 2021], [1, 0, 2020]]]) - a = get_static_temporal_encoding(ts, 32) - b = get_static_temporal_encoding(ts, 32) + freqs = build_static_temporal_freqs(32) + a = get_static_temporal_encoding(ts, freqs) + b = get_static_temporal_encoding(ts, freqs) assert torch.allclose(a, b) def test_get_static_temporal_encoding_different_dates_differ() -> None: """Different dates should produce different encodings.""" ts = torch.tensor([[[1, 0, 2020], [1, 6, 2021]]]) - enc = get_static_temporal_encoding(ts, 32) + freqs = build_static_temporal_freqs(32) + enc = get_static_temporal_encoding(ts, freqs) assert not torch.allclose(enc[0, 0], enc[0, 1]) +def test_build_static_temporal_freqs_odd_dim_raises() -> None: + """Odd encoding_dim should raise AssertionError.""" + with pytest.raises(AssertionError, match="encoding_dim must be even"): + build_static_temporal_freqs(33) + + +def test_invalid_timestamp_encoding_mode_raises() -> None: + """Invalid mode should raise ValueError.""" + with pytest.raises(ValueError): + _make_ce("nonexistent_mode") + + def test_static_temporal_no_pos_embed() -> None: """static_temporal mode should not create pos_embed or month_embed.""" ce = _make_ce("static_temporal") @@ -54,6 +73,14 @@ def test_static_temporal_no_pos_embed() -> None: assert ce.month_embed is None +def test_static_temporal_has_freqs_buffer() -> None: + """static_temporal mode should register precomputed frequencies buffer.""" + ce = _make_ce("static_temporal") + assert hasattr(ce, "_static_temporal_freqs") + n = ce.embedding_dim_per_embedding_type + assert ce._static_temporal_freqs.shape == (n,) + + def test_legacy_has_pos_embed() -> None: """Legacy mode should still create pos_embed and month_embed.""" ce = _make_ce("legacy") @@ -80,7 +107,6 @@ def test_static_temporal_same_date_same_encoding() -> None: out = ce.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) result = out["sentinel2_l2a"] n = ce.embedding_dim_per_embedding_type - # All time slots should have identical temporal encoding assert torch.allclose( result[0, 0, 0, 0, 0, n : 3 * n], result[0, 0, 0, 1, 0, n : 3 * n], @@ -93,7 +119,7 @@ def test_static_temporal_differs_from_legacy() -> None: ce_lg = _make_ce("legacy", embedding_size=16) B, H, W, T = 1, 2, 2, 4 tokens = torch.zeros(B, H, W, T, 3, 16) - ts = _make_timestamps(B, T) + ts = torch.tensor([[[1, 0, 2020], [15, 6, 2020], [1, 0, 2021], [15, 6, 2021]]]) out_st = ce_st.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) out_lg = ce_lg.forward({"sentinel2_l2a": tokens}, ts, patch_size=4) n = 
ce_st.embedding_dim_per_embedding_type From d6a5caaae961ac213db7f7df8922b3e3e0a5b8fc Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Wed, 15 Apr 2026 18:14:14 +0000 Subject: [PATCH 3/4] Remove precomputed freqs buffer, compute inline instead Simpler implementation: get_static_temporal_encoding now takes encoding_dim and computes frequencies inline rather than requiring a precomputed buffer. The computation is trivial (linspace on ~64 elements) so caching provides no meaningful benefit. Co-Authored-By: Claude Opus 4.6 (1M context) --- olmoearth_pretrain/nn/encodings.py | 37 +++++++------------ olmoearth_pretrain/nn/flexi_vit.py | 7 +--- .../unit/nn/test_static_temporal_encoding.py | 25 ++++--------- 3 files changed, 21 insertions(+), 48 deletions(-) diff --git a/olmoearth_pretrain/nn/encodings.py b/olmoearth_pretrain/nn/encodings.py index bf2a4cf81..78714fe6a 100644 --- a/olmoearth_pretrain/nn/encodings.py +++ b/olmoearth_pretrain/nn/encodings.py @@ -132,40 +132,25 @@ def get_month_encoding_table(encoding_dim: int) -> torch.Tensor: return month_table # (M, D) -def build_static_temporal_freqs(encoding_dim: int) -> torch.Tensor: - """Precompute geometric-spaced angular frequencies for static temporal encoding. - - Args: - encoding_dim: Total encoding dimension (must be even). Uses encoding_dim/2 - distinct frequencies. - - Returns: - Tensor of shape (encoding_dim // 2,) containing angular frequencies. - """ - assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}" - num_freqs = encoding_dim // 2 - exponents = torch.linspace(-7.0, 8.5, num_freqs) - return 2.0 * math.pi * (2.0**exponents) - - def get_static_temporal_encoding( - timestamps: torch.Tensor, - freqs: torch.Tensor, + timestamps: torch.Tensor, encoding_dim: int ) -> torch.Tensor: """Static multi-frequency sinusoidal temporal encoding. - Converts timestamps to a fractional year and applies precomputed + Converts timestamps to a fractional year and applies geometric-spaced sinusoidal frequencies ranging from ~128-year periods to daily resolution. - Day-of-year is approximated as ``month * 30.4375 + day``. + The 1-cycle/year frequency naturally produces identical values for the + same day-of-year across different years. Args: - timestamps: Tensor of shape (B, T, 3) where [..., 0] is day, + timestamps: Tensor of shape (B, T, 3) where [..., 0] is day (1-31), [..., 1] is month (0-indexed, 0-11), [..., 2] is year. - freqs: Precomputed angular frequencies from ``build_static_temporal_freqs``. + encoding_dim: Output encoding dimension (must be even). Returns: - Tensor of shape (B, T, 2 * len(freqs)). + Tensor of shape (B, T, encoding_dim). 
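+
+    Example (illustrative; the output depends only on the date, not the slot):
+        >>> ts = torch.tensor([[[15, 6, 2021], [15, 6, 2021]]])  # same date twice
+        >>> enc = get_static_temporal_encoding(ts, 32)
+        >>> bool(torch.allclose(enc[0, 0], enc[0, 1]))
+        True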
""" + assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}" day = timestamps[..., 0].float() month = timestamps[..., 1].float() year = timestamps[..., 2].float() @@ -173,5 +158,9 @@ def get_static_temporal_encoding( day_of_year = month * 30.4375 + day frac_year = year + day_of_year / 365.25 - 2020.0 - angles = frac_year.unsqueeze(-1) * freqs.to(timestamps.device) # (B, T, num_freqs) + num_freqs = encoding_dim // 2 + exponents = torch.linspace(-7.0, 8.5, num_freqs, device=timestamps.device) + freqs = 2.0 * math.pi * (2.0**exponents) # (num_freqs,) + + angles = frac_year.unsqueeze(-1) * freqs # (B, T, num_freqs) return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index 5241b73ce..a490f497b 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -25,7 +25,6 @@ from olmoearth_pretrain.nn.attention import Block from olmoearth_pretrain.nn.encodings import ( TimestampEncodingMode, - build_static_temporal_freqs, get_1d_sincos_pos_encoding, get_2d_sincos_pos_encoding_with_resolution, get_month_encoding_table, @@ -657,10 +656,6 @@ def __init__( if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: self.pos_embed = None self.month_embed = None - self.register_buffer( - "_static_temporal_freqs", - build_static_temporal_freqs(2 * self.embedding_dim_per_embedding_type), - ) else: self.pos_embed = nn.Parameter( get_1d_sincos_pos_encoding( @@ -806,7 +801,7 @@ def _apply_encodings_per_modality( if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL: assert timestamps is not None ts_embed = get_static_temporal_encoding( - timestamps, self._static_temporal_freqs + timestamps, 2 * self.embedding_dim_per_embedding_type ) ts_view = repeat(ts_embed, f"b t d -> {ein_string}", **ein_dict).to( device diff --git a/tests/unit/nn/test_static_temporal_encoding.py b/tests/unit/nn/test_static_temporal_encoding.py index c9fdf108c..9e2ef58e5 100644 --- a/tests/unit/nn/test_static_temporal_encoding.py +++ b/tests/unit/nn/test_static_temporal_encoding.py @@ -5,7 +5,6 @@ from olmoearth_pretrain.data.constants import Modality from olmoearth_pretrain.nn.encodings import ( - build_static_temporal_freqs, get_static_temporal_encoding, ) from olmoearth_pretrain.nn.flexi_vit import CompositeEncodings @@ -32,32 +31,30 @@ def _make_timestamps(b: int = 2, t: int = 4) -> torch.Tensor: def test_get_static_temporal_encoding_shape() -> None: """Output shape should be (B, T, encoding_dim).""" ts = _make_timestamps(3, 5) - freqs = build_static_temporal_freqs(64) - enc = get_static_temporal_encoding(ts, freqs) + enc = get_static_temporal_encoding(ts, 64) assert enc.shape == (3, 5, 64) def test_get_static_temporal_encoding_deterministic() -> None: """Same input should produce same output.""" ts = torch.tensor([[[15, 6, 2021], [1, 0, 2020]]]) - freqs = build_static_temporal_freqs(32) - a = get_static_temporal_encoding(ts, freqs) - b = get_static_temporal_encoding(ts, freqs) + a = get_static_temporal_encoding(ts, 32) + b = get_static_temporal_encoding(ts, 32) assert torch.allclose(a, b) def test_get_static_temporal_encoding_different_dates_differ() -> None: """Different dates should produce different encodings.""" ts = torch.tensor([[[1, 0, 2020], [1, 6, 2021]]]) - freqs = build_static_temporal_freqs(32) - enc = get_static_temporal_encoding(ts, freqs) + enc = get_static_temporal_encoding(ts, 32) assert not torch.allclose(enc[0, 0], enc[0, 1]) -def 
test_build_static_temporal_freqs_odd_dim_raises() -> None: +def test_get_static_temporal_encoding_odd_dim_raises() -> None: """Odd encoding_dim should raise AssertionError.""" + ts = _make_timestamps(1, 1) with pytest.raises(AssertionError, match="encoding_dim must be even"): - build_static_temporal_freqs(33) + get_static_temporal_encoding(ts, 33) def test_invalid_timestamp_encoding_mode_raises() -> None: @@ -73,14 +70,6 @@ def test_static_temporal_no_pos_embed() -> None: assert ce.month_embed is None -def test_static_temporal_has_freqs_buffer() -> None: - """static_temporal mode should register precomputed frequencies buffer.""" - ce = _make_ce("static_temporal") - assert hasattr(ce, "_static_temporal_freqs") - n = ce.embedding_dim_per_embedding_type - assert ce._static_temporal_freqs.shape == (n,) - - def test_legacy_has_pos_embed() -> None: """Legacy mode should still create pos_embed and month_embed.""" ce = _make_ce("legacy") From 41901bba442d411c5b33c9f80b9f03f1eb22f276 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Wed, 15 Apr 2026 18:59:15 +0000 Subject: [PATCH 4/4] Add random_time + static_temporal experiment script and fix randperm device bug Co-Authored-By: Claude Opus 4.6 (1M context) --- olmoearth_pretrain/train/masking.py | 4 +- ..._no_s1_drop_random_time_static_temporal.py | 424 ++++++++++++++++++ 2 files changed, 427 insertions(+), 1 deletion(-) create mode 100644 scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_static_temporal.py diff --git a/olmoearth_pretrain/train/masking.py b/olmoearth_pretrain/train/masking.py index e0a8f7e54..c4bb12a90 100644 --- a/olmoearth_pretrain/train/masking.py +++ b/olmoearth_pretrain/train/masking.py @@ -1917,7 +1917,9 @@ def apply_mask( else: use_random_masking = False not_missing_t = torch.argwhere(missing_per_time)[:, 0] - not_missing_t = not_missing_t[torch.randperm(len(not_missing_t))] + not_missing_t = not_missing_t[ + torch.randperm(len(not_missing_t), device=not_missing_t.device) + ] num_encode = math.ceil(len(not_missing_t) * self.encode_ratio) encode_timestamps = not_missing_t[:num_encode] decode_timestamps = not_missing_t[num_encode:] diff --git a/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_static_temporal.py b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_static_temporal.py new file mode 100644 index 000000000..8f90e3404 --- /dev/null +++ b/scripts/vnext/single_bandset_band_dropout/base_band_dropout_no_s1_drop_random_time_static_temporal.py @@ -0,0 +1,424 @@ +"""Base script for single bandset + random band dropout (no S1) + random time with decode masking + static temporal encoding. 
+
+Same as base_band_dropout_no_s1_drop_random_time.py but with
+timestamp_encoding_mode="static_temporal" so:
+- Temporal: multi-frequency sinusoidal (get_static_temporal_encoding), written
+  into the two quarter-dim slots that the legacy time-index + month embeddings use
+- Spatial: legacy 2D sincos pos encoding (not per-token lat/lon)
+- No latlon_dropout / spatial_dim_fraction / temporal_dim_fraction needed
+"""
+
+import logging
+
+from olmo_core.config import DType
+from olmo_core.distributed.parallel.data_parallel import (
+    DataParallelConfig,
+    DataParallelType,
+)
+from olmo_core.optim import AdamWConfig
+from olmo_core.optim.scheduler import CosWithWarmup
+from olmo_core.train.callbacks import (
+    BeakerCallback,
+    CheckpointerCallback,
+    ConfigSaverCallback,
+    GarbageCollectorCallback,
+    GPUMemoryMonitorCallback,
+)
+from olmo_core.train.checkpoint import CheckpointerConfig
+from olmo_core.train.common import Duration, LoadStrategy
+from olmo_core.train.config import TrainerConfig
+
+from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig
+from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig
+from olmoearth_pretrain.evals.datasets.normalize import NormMethod
+from olmoearth_pretrain.evals.metrics import EvalMetric
+from olmoearth_pretrain.internal.common import (
+    build_common_components as build_common_components_default,
+)
+from olmoearth_pretrain.internal.experiment import (
+    CommonComponents,
+    OlmoEarthVisualizeConfig,
+    SubCmd,
+    main,
+)
+from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS
+from olmoearth_pretrain.nn.flexi_vit import (
+    PoolingType,
+)
+from olmoearth_pretrain.nn.flexihelios import (
+    EncoderConfig,
+    PredictorConfig,
+)
+from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig
+from olmoearth_pretrain.nn.tokenization import ModalityTokenization, TokenizationConfig
+from olmoearth_pretrain.train.callbacks import (
+    DownstreamEvaluatorCallbackConfig,
+    OlmoEarthSpeedMonitorCallback,
+    OlmoEarthWandBCallback,
+)
+from olmoearth_pretrain.train.callbacks.evaluator_callback import (
+    DownstreamTaskConfig,
+    EvalMode,
+)
+from olmoearth_pretrain.train.loss import LossConfig
+from olmoearth_pretrain.train.masking import MaskingConfig
+from olmoearth_pretrain.train.train_module.contrastive_latentmim import (
+    ContrastiveLatentMIMTrainModuleConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+MAX_PATCH_SIZE = 8
+MIN_PATCH_SIZE = 1
+RANDOM_BAND_DROPOUT_MAX_RATE = 0.2
+
+S2_SINGLE_BANDSET = ModalityTokenization(
+    band_groups=[
+        [
+            "B02",
+            "B03",
+            "B04",
+            "B08",
+            "B05",
+            "B06",
+            "B07",
+            "B8A",
+            "B11",
+            "B12",
+            "B01",
+            "B09",
+        ],
+    ]
+)
+
+LANDSAT_SINGLE_BANDSET = ModalityTokenization(
+    band_groups=[
+        ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
+    ]
+)
+
+ONLY_DECODE_MODALITIES = [
+    Modality.WORLDCOVER.name,
+    Modality.SRTM.name,
+    Modality.OPENSTREETMAP_RASTER.name,
+    Modality.WRI_CANOPY_HEIGHT_MAP.name,
+    Modality.CDL.name,
+    Modality.WORLDCEREAL.name,
+]
+
+# No S1 dropout: only apply band dropout to S2 and Landsat.
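+# Sentinel-1 is excluded, so its bands are never dropped; for the modalities
+# below, random_band_dropout=True makes the encoder sample the dropout rate
+# from Uniform(0, RANDOM_BAND_DROPOUT_MAX_RATE).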
+BAND_DROPOUT_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.LANDSAT.name, +] + + +def _tokenization_config() -> TokenizationConfig: + return TokenizationConfig( + overrides={ + "sentinel2_l2a": S2_SINGLE_BANDSET, + "landsat": LANDSAT_SINGLE_BANDSET, + } + ) + + +def _masking_config( + tokenization_config: TokenizationConfig | None = None, +) -> MaskingConfig: + return MaskingConfig( + strategy_config={ + "type": "random_time_with_decode", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "random_ratio": 0.5, + "only_decode_modalities": ONLY_DECODE_MODALITIES, + }, + tokenization_config=tokenization_config, + ) + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + config.tokenization_config = _tokenization_config() + return config + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment.""" + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=64, + masking_config=_masking_config(common.tokenization_config), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_masked_negatives", + "tau": 0.1, + "same_target_threshold": 0.999, + "mask_negatives_for_modalities": ONLY_DECODE_MODALITIES, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.05, + } + ), + token_exit_cfg={modality: 0 for modality in common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config(common: CommonComponents) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment.""" + return OlmoEarthDataLoaderConfig( + num_workers=16, + global_batch_size=512, + token_budget=2250, + prefetch_factor=4, + sampled_hw_p_list=list(range(1, 13)), + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, + masking_config=_masking_config(common.tokenization_config), + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = 
"2026_02_08_masked_neg" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, + ) + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=4, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.KNN, + primary_metric=EvalMetric.ACCURACY, + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(4000), + input_modalities=[Modality.SENTINEL2_L2A.name], + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MICRO_F1, + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + primary_metric=EvalMetric.MIOU, + ), + "yemen_crop": DownstreamTaskConfig( + dataset="yemen_crop", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + eval_interval=Duration.steps(20000), + probe_lr=0.001, + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "geo_ecosystem_annual_test": DownstreamTaskConfig( + dataset="geo_ecosystem_annual_test", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.01, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + ), + "canada_wildfire_sat_eval_split": DownstreamTaskConfig( + dataset="canada_wildfire_sat_eval_split", + embedding_batch_size=32, + probe_batch_size=16, + patch_size=5, # TODO: This is changeable but we should know the valid sizes for inputs + num_workers=2, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + eval_mode=EvalMode.LINEAR_PROBE, + use_dice_loss=True, + primary_metric=EvalMetric.CLASS_F1, + primary_metric_class=1, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + 
metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, + ) + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + band_dropout_rate=RANDOM_BAND_DROPOUT_MAX_RATE, + random_band_dropout=True, + band_dropout_modalities=BAND_DROPOUT_MODALITIES, + timestamp_encoding_mode="static_temporal", + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + tokenization_config=common.tokenization_config, + timestamp_encoding_mode="static_temporal", + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + )
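
Reviewer note: below is a minimal, self-contained sketch of the property this
series is built around (the temporal encoding depends only on the calendar
date, never on the time-slot index). It assumes only the final one-argument
get_static_temporal_encoding(timestamps, encoding_dim) signature; the demo()
wrapper is illustrative and not part of the patches.

    import torch

    from olmoearth_pretrain.nn.encodings import get_static_temporal_encoding


    def demo() -> None:
        # The same date (15 July 2021, month index 6) placed at different slot
        # positions: slot 0 in the first timeline, slot 2 in the second.
        ts_a = torch.tensor([[[15, 6, 2021], [1, 0, 2020], [1, 0, 2022]]])
        ts_b = torch.tensor([[[1, 0, 2020], [1, 0, 2022], [15, 6, 2021]]])

        enc_a = get_static_temporal_encoding(ts_a, 32)  # (1, 3, 32)
        enc_b = get_static_temporal_encoding(ts_b, 32)

        # Same date -> same vector, wherever it sits in the sequence. A legacy
        # pos_embed lookup would differ here because it keys on the slot index.
        assert torch.allclose(enc_a[0, 0], enc_b[0, 2])
        # Different dates -> different vectors.
        assert not torch.allclose(enc_a[0, 0], enc_a[0, 1])


    if __name__ == "__main__":
        demo()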