45 changes: 45 additions & 0 deletions olmoearth_pretrain/nn/encodings.py
@@ -7,12 +7,23 @@
- 2D sinusoidal position encoding (for spatial data)
- 1D sinusoidal position encoding (for temporal data)
- Month encoding (for temporal data)
- Static multi-frequency temporal encoding
"""

import math
from enum import StrEnum

import numpy as np
import torch


class TimestampEncodingMode(StrEnum):
"""Mode for encoding temporal information."""

LEGACY = "legacy"
STATIC_TEMPORAL = "static_temporal"


def get_1d_sincos_pos_encoding(pos: torch.Tensor, encoding_dim: int) -> torch.Tensor:
"""Get 1D sin cos position encoding for a given set of positions.

@@ -119,3 +130,37 @@ def get_month_encoding_table(encoding_dim: int) -> torch.Tensor:
month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)

return month_table # (M, D)


def get_static_temporal_encoding(
timestamps: torch.Tensor, encoding_dim: int
) -> torch.Tensor:
"""Static multi-frequency sinusoidal temporal encoding.

Converts timestamps to a fractional year and applies geometric-spaced
sinusoidal frequencies ranging from ~128-year periods to daily resolution.
The 1-cycle/year frequency naturally produces identical values for the
same day-of-year across different years.

Args:
timestamps: Tensor of shape (B, T, 3) where [..., 0] is day (1-31),
[..., 1] is month (0-indexed, 0-11), [..., 2] is year.
encoding_dim: Output encoding dimension (must be even).

Returns:
Tensor of shape (B, T, encoding_dim).
"""
assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}"
day = timestamps[..., 0].float()
month = timestamps[..., 1].float()
year = timestamps[..., 2].float()

day_of_year = month * 30.4375 + day
frac_year = year + day_of_year / 365.25 - 2020.0

num_freqs = encoding_dim // 2
exponents = torch.linspace(-7.0, 8.5, num_freqs, device=timestamps.device)
freqs = 2.0 * math.pi * (2.0**exponents) # (num_freqs,)

angles = frac_year.unsqueeze(-1) * freqs # (B, T, num_freqs)
return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
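
A quick usage sketch for the new helper (not part of the diff; shapes and values here are illustrative only):

import torch
from olmoearth_pretrain.nn.encodings import get_static_temporal_encoding

# (B=2, T=3, 3) integer timestamps as (day 1-31, month 0-11, year)
timestamps = torch.tensor(
    [
        [[15, 0, 2020], [15, 5, 2020], [15, 0, 2021]],
        [[1, 2, 2019], [28, 1, 2024], [31, 11, 2022]],
    ]
)
enc = get_static_temporal_encoding(timestamps, encoding_dim=64)
print(enc.shape)  # torch.Size([2, 3, 64])
# The encoding is a deterministic function of the calendar date only; there are
# no learned parameters, so it can be recomputed identically at inference time.
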
94 changes: 68 additions & 26 deletions olmoearth_pretrain/nn/flexi_vit.py
@@ -24,9 +24,11 @@
)
from olmoearth_pretrain.nn.attention import Block
from olmoearth_pretrain.nn.encodings import (
TimestampEncodingMode,
get_1d_sincos_pos_encoding,
get_2d_sincos_pos_encoding_with_resolution,
get_month_encoding_table,
get_static_temporal_encoding,
)
from olmoearth_pretrain.nn.flexi_patch_embed import (
FlexiPatchEmbed,
@@ -624,6 +626,7 @@ def __init__(
learnable_channel_embeddings: bool = True,
random_channel_embeddings: bool = False,
tokenization_config: TokenizationConfig | None = None,
timestamp_encoding_mode: str = "legacy",
):
"""Initialize the composite encodings.

@@ -635,34 +638,34 @@ def __init__(
learnable_channel_embeddings: Whether to use learnable channel embeddings
random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
tokenization_config: Optional config for custom band groupings
timestamp_encoding_mode: "legacy" for time-index + month embeddings,
"static_temporal" for multi-frequency sinusoidal encoding
"""
super().__init__()
self.timestamp_encoding_mode = TimestampEncodingMode(timestamp_encoding_mode)
self.embedding_size = embedding_size
self.supported_modalities = supported_modalities
self.supported_modality_names = [
modality.name for modality in supported_modalities
]
self.tokenization_config = tokenization_config or TokenizationConfig()
self.embedding_size = embedding_size
self.max_sequence_length = (
max_sequence_length  # the max sequence length applies to the time dimension
)
# TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get

# we have 4 embedding types (pos_in_time, pos_in_space, month, channel), so each gets
# 0.25 of the dimension
self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
# Position encodings for time dimension initialized to 1D sinusoidal encodings
self.pos_embed = nn.Parameter(
get_1d_sincos_pos_encoding(
torch.arange(max_sequence_length),
self.embedding_dim_per_embedding_type,
),
requires_grad=False,
)
# Month encodings
month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL:
self.pos_embed = None
self.month_embed = None
else:
self.pos_embed = nn.Parameter(
get_1d_sincos_pos_encoding(
torch.arange(max_sequence_length),
self.embedding_dim_per_embedding_type,
),
requires_grad=False,
)
month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
if not learnable_channel_embeddings and not random_channel_embeddings:
self.per_modality_channel_embeddings = nn.ParameterDict()
for modality in self.supported_modalities:
@@ -795,16 +798,29 @@ def _apply_encodings_per_modality(
modality_embed[..., :n] += channel_embed

if modality.is_multitemporal and use_temporal_encodings:
# Time position encodings
time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict)
modality_embed[..., n : n * 2] += time_embed.to(device)

# Month encodings
assert timestamps is not None
months = timestamps[:, :, 1]
month_embed = self.month_embed(months)
month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict)
modality_embed[..., n * 2 : n * 3] += month_embed.to(device)
if self.timestamp_encoding_mode == TimestampEncodingMode.STATIC_TEMPORAL:
assert timestamps is not None
ts_embed = get_static_temporal_encoding(
timestamps, 2 * self.embedding_dim_per_embedding_type
)
ts_view = repeat(ts_embed, f"b t d -> {ein_string}", **ein_dict).to(
device
)
modality_embed[..., n : n * 3] += ts_view
else:
# Legacy: time-index position + month embeddings
assert self.pos_embed is not None
time_embed = repeat(
self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict
)
modality_embed[..., n : n * 2] += time_embed.to(device)

assert timestamps is not None
assert self.month_embed is not None
months = timestamps[:, :, 1]
month_embed = self.month_embed(months)
month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict)
modality_embed[..., n * 2 : n * 3] += month_embed.to(device)
if modality.is_spatial:
# Spatial encodings
assert input_res is not None
@@ -876,6 +892,7 @@ def __init__(
use_flash_attn: bool = False,
qk_norm: bool = False,
tokenization_config: TokenizationConfig | None = None,
timestamp_encoding_mode: str = "legacy",
) -> None:
"""Initialize the FlexiVitBase class."""
super().__init__()
@@ -915,6 +932,7 @@ def __init__(
learnable_channel_embeddings,
random_channel_embeddings,
tokenization_config=self._base_tokenization_config,
timestamp_encoding_mode=timestamp_encoding_mode,
)
self.apply(self._init_weights)

@@ -1112,6 +1130,7 @@ def __init__(
band_dropout_rate: float = 0.0,
random_band_dropout: bool = False,
band_dropout_modalities: list[str] | None = None,
timestamp_encoding_mode: str = "legacy",
):
"""Initialize the encoder.

@@ -1145,6 +1164,7 @@ def __init__(
random_band_dropout: If True, sample dropout rate from Uniform(0, band_dropout_rate).
band_dropout_modalities: If provided, only apply band dropout to these
modalities. If None, apply to all modalities. Default: None.
timestamp_encoding_mode: "legacy" or "static_temporal"
"""
self.tokenization_config = tokenization_config or TokenizationConfig()
super().__init__(
@@ -1160,6 +1180,7 @@ def __init__(
random_channel_embeddings=random_channel_embeddings,
qk_norm=qk_norm,
tokenization_config=self.tokenization_config,
timestamp_encoding_mode=timestamp_encoding_mode,
)
self.num_register_tokens = num_register_tokens
self.has_register_tokens = num_register_tokens > 0
@@ -1680,6 +1701,7 @@ def __init__(
use_flash_attn: bool = False,
qk_norm: bool = False,
tokenization_config: TokenizationConfig | None = None,
timestamp_encoding_mode: str = "legacy",
):
"""Initialize the predictor.

@@ -1698,6 +1720,7 @@ def __init__(
use_flash_attn: Whether to use flash attention
qk_norm: Whether to apply normalization to Q and K in attention
tokenization_config: Optional config for custom band groupings
timestamp_encoding_mode: "legacy" or "static_temporal"
"""
self.tokenization_config = tokenization_config or TokenizationConfig()
super().__init__(
@@ -1713,6 +1736,7 @@ def __init__(
use_flash_attn=use_flash_attn,
qk_norm=qk_norm,
tokenization_config=self.tokenization_config,
timestamp_encoding_mode=timestamp_encoding_mode,
)
self.learnable_channel_embeddings = learnable_channel_embeddings
self.random_channel_embeddings = random_channel_embeddings
@@ -2079,6 +2103,7 @@ class EncoderConfig(Config):
band_dropout_rate: float = 0.0
random_band_dropout: bool = False
band_dropout_modalities: list[str] | None = None
timestamp_encoding_mode: str = "legacy"

def __post_init__(self) -> None:
"""Coerce raw dicts to TokenizationConfig for old checkpoint compatibility."""
@@ -2104,6 +2129,14 @@ def validate(self) -> None:
)
if self.tokenization_config is not None:
self.tokenization_config.validate()
try:
TimestampEncodingMode(self.timestamp_encoding_mode)
except ValueError:
valid = [m.value for m in TimestampEncodingMode]
raise ValueError(
f"timestamp_encoding_mode must be one of {valid}, "
f"got '{self.timestamp_encoding_mode}'"
)

@property
def supported_modalities(self) -> list[ModalitySpec]:
@@ -2139,6 +2172,7 @@ class PredictorConfig(Config):
use_flash_attn: bool = False
qk_norm: bool = False
tokenization_config: TokenizationConfig | None = None
timestamp_encoding_mode: str = "legacy"

def __post_init__(self) -> None:
"""Coerce raw dicts to TokenizationConfig for old checkpoint compatibility."""
@@ -2155,6 +2189,14 @@ def validate(self) -> None:
raise ValueError(f"Modality {modality} is not supported")
if self.tokenization_config is not None:
self.tokenization_config.validate()
try:
TimestampEncodingMode(self.timestamp_encoding_mode)
except ValueError:
valid = [m.value for m in TimestampEncodingMode]
raise ValueError(
f"timestamp_encoding_mode must be one of {valid}, "
f"got '{self.timestamp_encoding_mode}'"
)

@property
def supported_modalities(self) -> list[ModalitySpec]:
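
Two notes on how the new mode is wired through flexi_vit.py, sketched from the visible hunks rather than stated definitively. First, the static_temporal path requests twice the per-type width from get_static_temporal_encoding and adds it over the slice that the legacy path split between time-index and month encodings (assuming n below is embedding_dim_per_embedding_type, as the surrounding slices suggest). Second, the configs round-trip the raw string through TimestampEncodingMode, so an unknown value fails at validate() time rather than deep inside the model.

import torch
from olmoearth_pretrain.nn.encodings import (
    TimestampEncodingMode,
    get_static_temporal_encoding,
)

embedding_size = 256
n = int(embedding_size * 0.25)  # per-encoding-type width, here 64
timestamps = torch.tensor([[[15, 0, 2020], [15, 6, 2021]]])  # (B=1, T=2, 3)
ts_embed = get_static_temporal_encoding(timestamps, 2 * n)  # (1, 2, 128)
# In _apply_encodings_per_modality this is broadcast over the token grid and
# added to modality_embed[..., n : n * 3], i.e. the former time + month slices.

TimestampEncodingMode("static_temporal")  # ok
# TimestampEncodingMode("typo") raises ValueError; EncoderConfig.validate() and
# PredictorConfig.validate() re-raise it with the list of valid values.
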
4 changes: 3 additions & 1 deletion olmoearth_pretrain/train/masking.py
@@ -1917,7 +1917,9 @@ def apply_mask(
else:
use_random_masking = False
not_missing_t = torch.argwhere(missing_per_time)[:, 0]
not_missing_t = not_missing_t[torch.randperm(len(not_missing_t))]
not_missing_t = not_missing_t[
torch.randperm(len(not_missing_t), device=not_missing_t.device)
]
num_encode = math.ceil(len(not_missing_t) * self.encode_ratio)
encode_timestamps = not_missing_t[:num_encode]
decode_timestamps = not_missing_t[num_encode:]
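
The masking.py change simply pins torch.randperm to the device of the tensor it permutes; a minimal sketch of the pattern (hypothetical tensors, runs on CPU or GPU):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
not_missing_t = torch.tensor([0, 2, 5, 7], device=device)
# Creating the permutation on the same device avoids an implicit CPU index
# tensor being used on a GPU tensor (an extra copy/sync, or an error on some setups).
perm = torch.randperm(len(not_missing_t), device=not_missing_t.device)
shuffled = not_missing_t[perm]
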