diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py
index e90ddd5c0..b0fa03d35 100644
--- a/olmoearth_pretrain/nn/flexi_vit.py
+++ b/olmoearth_pretrain/nn/flexi_vit.py
@@ -2,6 +2,7 @@
 import logging
 import math
+import warnings
 
 from dataclasses import dataclass
 from typing import Any
@@ -620,7 +621,7 @@ def __init__(
         self,
         embedding_size: int,
         supported_modalities: list[ModalitySpec],
-        max_sequence_length: int,
+        max_sequence_length: int | None = None,
         learnable_channel_embeddings: bool = True,
         random_channel_embeddings: bool = False,
         tokenization_config: TokenizationConfig | None = None,
@@ -631,12 +632,20 @@
             embedding_size: Size of token embeddings
             supported_modalities: Which modalities from Modality this model
                 instantiation supports
-            max_sequence_length: Maximum sequence length
+            max_sequence_length: Deprecated, has no effect. Temporal position
+                encodings are now computed on-the-fly.
             learnable_channel_embeddings: Whether to use learnable channel embeddings
             random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
             tokenization_config: Optional config for custom band groupings
         """
         super().__init__()
+        if max_sequence_length is not None:
+            warnings.warn(
+                "max_sequence_length is deprecated and has no effect. "
+                "Temporal position encodings are now computed on-the-fly.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
         self.embedding_size = embedding_size
         self.supported_modalities = supported_modalities
         self.supported_modality_names = [
@@ -644,22 +653,15 @@
         ]
         self.tokenization_config = tokenization_config or TokenizationConfig()
         self.embedding_size = embedding_size
-        self.max_sequence_length = (
-            max_sequence_length  # This max sequence length is a time dim thing
-        )
 
         # TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get
         # we have 4 embeddings types (pos_in_time, pos_in_space, month, channel) so each get
         # 0.25 of the dimension
         self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
-        # Position encodings for time dimension initialized to 1D sinusoidal encodings
-        self.pos_embed = nn.Parameter(
-            get_1d_sincos_pos_encoding(
-                torch.arange(max_sequence_length),
-                self.embedding_dim_per_embedding_type,
-            ),
-            requires_grad=False,
-        )
+        # Temporal position encodings are computed on-the-fly via
+        # get_1d_sincos_pos_encoding so that any number of timesteps is supported
+        # without a pre-allocated table.
+        self._register_load_state_dict_pre_hook(self._drop_pos_embed_hook)
         # Month encodings
         month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
         self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
@@ -701,6 +703,15 @@ def _init_weights(self, m: nn.Module) -> None:
             # TODO: fix the dtype here
             nn.init.constant_(m.bias, 0).to(torch.float32)
 
+    @staticmethod
+    def _drop_pos_embed_hook(
+        state_dict: dict, prefix: str, *args: object, **kwargs: object
+    ) -> None:
+        """Drop legacy pos_embed from old checkpoints so strict loading succeeds."""
+        key = prefix + "pos_embed"
+        if key in state_dict:
+            del state_dict[key]
+
     @staticmethod
     def calculate_gsd_ratio(input_res: float, patch_size: int) -> float:
         """Calculate the Ground Sample Distance ratio."""
@@ -795,9 +806,12 @@ def _apply_encodings_per_modality(
             modality_embed[..., :n] += channel_embed
 
         if modality.is_multitemporal and use_temporal_encodings:
-            # Time position encodings
-            time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict)
-            modality_embed[..., n : n * 2] += time_embed.to(device)
+            # Time position encodings (computed on-the-fly for arbitrary t)
+            pos_embed = get_1d_sincos_pos_encoding(
+                torch.arange(t, device=device), self.embedding_dim_per_embedding_type
+            )
+            time_embed = repeat(pos_embed, f"t d -> {ein_string}", **ein_dict)
+            modality_embed[..., n : n * 2] += time_embed
 
             # Month encodings
             assert timestamps is not None
diff --git a/tests/unit/nn/test_flexi_vit.py b/tests/unit/nn/test_flexi_vit.py
index ab5a042b3..94ee5bb88 100644
--- a/tests/unit/nn/test_flexi_vit.py
+++ b/tests/unit/nn/test_flexi_vit.py
@@ -7,6 +7,7 @@
 from einops import repeat
 
 from olmoearth_pretrain.data.constants import Modality, ModalitySpec
+from olmoearth_pretrain.nn.encodings import get_1d_sincos_pos_encoding
 from olmoearth_pretrain.nn.flexi_vit import (
     CompositeEncodings,
     Encoder,
@@ -120,6 +121,56 @@ def test_apply_encodings_per_modality_grad(
             is not None
         )
 
+    def test_dynamic_pos_embed_matches_static(self) -> None:
+        """On-the-fly sinusoidal encoding matches a pre-allocated table for overlapping positions."""
+        dim = 48
+        table = get_1d_sincos_pos_encoding(torch.arange(12), dim)
+        for t in [1, 5, 12, 17, 24]:
+            dynamic = get_1d_sincos_pos_encoding(torch.arange(t), dim)
+            overlap = min(t, 12)
+            assert torch.allclose(dynamic[:overlap], table[:overlap], atol=1e-6)
+
+    def test_temporal_encoding_works_beyond_max_sequence_length(
+        self,
+    ) -> None:
+        """Forward pass works when t exceeds the configured max_sequence_length."""
+        ce = CompositeEncodings(
+            embedding_size=16,
+            supported_modalities=[Modality.SENTINEL2_L2A],
+            max_sequence_length=12,
+            random_channel_embeddings=True,
+        )
+        B, H, W, T, C, D = 2, 4, 4, 17, 3, 16
+        tokens = torch.randn(B, H, W, T, C, D)
+        timestamps = torch.zeros(B, T, 3, dtype=torch.long)
+        timestamps[:, :, 1] = torch.arange(T) % 12
+        result = ce._apply_encodings_per_modality(
+            "sentinel2_l2a", tokens, timestamps, patch_size=4, input_res=10
+        )
+        assert result.shape == tokens.shape
+        assert not (result == tokens).all()
+
+    def test_temporal_encoding_values_match_expected(self) -> None:
+        """Temporal position encoding values match get_1d_sincos_pos_encoding directly."""
+        embedding_size = 16
+        n = embedding_size // 4
+        ce = CompositeEncodings(
+            embedding_size=embedding_size,
+            supported_modalities=[Modality.SENTINEL2_L2A],
+            max_sequence_length=12,
+            random_channel_embeddings=True,
+        )
+        B, H, W, T, C, D = 1, 2, 2, 5, 3, embedding_size
+        tokens = torch.zeros(B, H, W, T, C, D)
+        timestamps = torch.zeros(B, T, 3, dtype=torch.long)
+        timestamps[:, :, 1] = torch.arange(T)
+        result = ce._apply_encodings_per_modality(
+            "sentinel2_l2a", tokens, timestamps, patch_size=4, input_res=10
+        )
+        expected_time = get_1d_sincos_pos_encoding(torch.arange(T), n)
+        actual_time = result[0, 0, 0, :, 0, n : 2 * n]
+        assert torch.allclose(actual_time, expected_time, atol=1e-5)
+
 # TODO: Add tests for when the inputs are completely masked or different dims or something
 
 class TestFlexiVitBase: