46 changes: 30 additions & 16 deletions olmoearth_pretrain/nn/flexi_vit.py
@@ -2,6 +2,7 @@

import logging
import math
import warnings
from dataclasses import dataclass
from typing import Any

@@ -620,7 +621,7 @@ def __init__(
self,
embedding_size: int,
supported_modalities: list[ModalitySpec],
max_sequence_length: int,
max_sequence_length: int | None = None,
learnable_channel_embeddings: bool = True,
random_channel_embeddings: bool = False,
tokenization_config: TokenizationConfig | None = None,
@@ -631,35 +632,36 @@
embedding_size: Size of token embeddings
supported_modalities: Which modalities from Modality this model
instantiation supports
max_sequence_length: Maximum sequence length
max_sequence_length: Deprecated, has no effect. Temporal position
encodings are now computed on-the-fly.
learnable_channel_embeddings: Whether to use learnable channel embeddings
random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
tokenization_config: Optional config for custom band groupings
"""
super().__init__()
if max_sequence_length is not None:
warnings.warn(
"max_sequence_length is deprecated and has no effect. "
"Temporal position encodings are now computed on-the-fly.",
DeprecationWarning,
stacklevel=2,
)
self.embedding_size = embedding_size
self.supported_modalities = supported_modalities
self.supported_modality_names = [
modality.name for modality in supported_modalities
]
self.tokenization_config = tokenization_config or TokenizationConfig()
self.embedding_size = embedding_size
self.max_sequence_length = (
max_sequence_length # This max sequence length is a time dim thing
)
# TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get

# we have 4 embedding types (pos_in_time, pos_in_space, month, channel) so each gets
# 0.25 of the dimension
self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
# Position encodings for time dimension initialized to 1D sinusoidal encodings
self.pos_embed = nn.Parameter(
get_1d_sincos_pos_encoding(
torch.arange(max_sequence_length),
self.embedding_dim_per_embedding_type,
),
requires_grad=False,
)
# Temporal position encodings are computed on-the-fly via
# get_1d_sincos_pos_encoding so that any number of timesteps is supported
# without a pre-allocated table.
self._register_load_state_dict_pre_hook(self._drop_pos_embed_hook)
# Month encodings
month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
@@ -701,6 +703,15 @@ def _init_weights(self, m: nn.Module) -> None:
# TODO: fix the dtype here
nn.init.constant_(m.bias, 0).to(torch.float32)

@staticmethod
def _drop_pos_embed_hook(
state_dict: dict, prefix: str, *args: object, **kwargs: object
) -> None:
"""Drop legacy pos_embed from old checkpoints so strict loading succeeds."""
key = prefix + "pos_embed"
if key in state_dict:
del state_dict[key]

@staticmethod
def calculate_gsd_ratio(input_res: float, patch_size: int) -> float:
"""Calculate the Ground Sample Distance ratio."""
@@ -795,9 +806,12 @@ def _apply_encodings_per_modality(
modality_embed[..., :n] += channel_embed

if modality.is_multitemporal and use_temporal_encodings:
# Time position encodings
time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict)
modality_embed[..., n : n * 2] += time_embed.to(device)
# Time position encodings (computed on-the-fly for arbitrary t)
pos_embed = get_1d_sincos_pos_encoding(
torch.arange(t, device=device), self.embedding_dim_per_embedding_type
)
time_embed = repeat(pos_embed, f"t d -> {ein_string}", **ein_dict)
modality_embed[..., n : n * 2] += time_embed

# Month encodings
assert timestamps is not None
51 changes: 51 additions & 0 deletions tests/unit/nn/test_flexi_vit.py
@@ -7,6 +7,7 @@
from einops import repeat

from olmoearth_pretrain.data.constants import Modality, ModalitySpec
from olmoearth_pretrain.nn.encodings import get_1d_sincos_pos_encoding
from olmoearth_pretrain.nn.flexi_vit import (
CompositeEncodings,
Encoder,
@@ -120,6 +121,56 @@ def test_apply_encodings_per_modality_grad(
is not None
)

def test_dynamic_pos_embed_matches_static(self) -> None:
"""On-the-fly sinusoidal encoding matches a pre-allocated table for overlapping positions."""
dim = 48
table = get_1d_sincos_pos_encoding(torch.arange(12), dim)
for t in [1, 5, 12, 17, 24]:
dynamic = get_1d_sincos_pos_encoding(torch.arange(t), dim)
overlap = min(t, 12)
assert torch.allclose(dynamic[:overlap], table[:overlap], atol=1e-6)

def test_temporal_encoding_works_beyond_max_sequence_length(
self,
) -> None:
"""Forward pass works when t exceeds the configured max_sequence_length."""
ce = CompositeEncodings(
embedding_size=16,
supported_modalities=[Modality.SENTINEL2_L2A],
max_sequence_length=12,
random_channel_embeddings=True,
)
B, H, W, T, C, D = 2, 4, 4, 17, 3, 16
tokens = torch.randn(B, H, W, T, C, D)
timestamps = torch.zeros(B, T, 3, dtype=torch.long)
timestamps[:, :, 1] = torch.arange(T) % 12
result = ce._apply_encodings_per_modality(
"sentinel2_l2a", tokens, timestamps, patch_size=4, input_res=10
)
assert result.shape == tokens.shape
assert not (result == tokens).all()

def test_temporal_encoding_values_match_expected(self) -> None:
"""Temporal position encoding values match get_1d_sincos_pos_encoding directly."""
embedding_size = 16
n = embedding_size // 4
ce = CompositeEncodings(
embedding_size=embedding_size,
supported_modalities=[Modality.SENTINEL2_L2A],
max_sequence_length=12,
random_channel_embeddings=True,
)
B, H, W, T, C, D = 1, 2, 2, 5, 3, embedding_size
tokens = torch.zeros(B, H, W, T, C, D)
timestamps = torch.zeros(B, T, 3, dtype=torch.long)
timestamps[:, :, 1] = torch.arange(T)
result = ce._apply_encodings_per_modality(
"sentinel2_l2a", tokens, timestamps, patch_size=4, input_res=10
)
expected_time = get_1d_sincos_pos_encoding(torch.arange(T), n)
actual_time = result[0, 0, 0, :, 0, n : 2 * n]
assert torch.allclose(actual_time, expected_time, atol=1e-5)


# TODO: Add tests for when the inputs are completely masked or different dims or something
class TestFlexiVitBase:
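Net effect of the flexi_vit.py changes, for reference: passing max_sequence_length still works but only emits a DeprecationWarning, and temporal position encodings are generated for any number of timesteps instead of being read from a fixed-size table. A rough sketch of that behavior (not part of the diff), assuming the constructor arguments and import paths used in the test file above:

import warnings

import torch

from olmoearth_pretrain.data.constants import Modality
from olmoearth_pretrain.nn.encodings import get_1d_sincos_pos_encoding
from olmoearth_pretrain.nn.flexi_vit import CompositeEncodings

# The deprecated argument is accepted but only triggers a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ce = CompositeEncodings(
        embedding_size=16,
        supported_modalities=[Modality.SENTINEL2_L2A],
        max_sequence_length=12,  # deprecated, has no effect
        random_channel_embeddings=True,
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Temporal encodings are no longer bounded by a pre-allocated table:
# any number of timesteps can be encoded on demand.
t = 40
pos = get_1d_sincos_pos_encoding(torch.arange(t), ce.embedding_dim_per_embedding_type)
assert pos.shape == (t, ce.embedding_dim_per_embedding_type)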