From d277935588879d3eeb0b3dd4a961a0faad8fa01c Mon Sep 17 00:00:00 2001 From: Hadrien Sablon Date: Thu, 9 Apr 2026 10:15:57 -0700 Subject: [PATCH 1/4] minor readability update for create_exit_seqs and collapse_and_combine_hwtc --- olmoearth_pretrain/nn/flexi_vit.py | 51 ++++++++++--------- .../nn/pooled_modality_predictor.py | 6 +-- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index af6db94db..5836191ea 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -940,26 +940,36 @@ def grab_modality_specific_dims(modality_data: Tensor) -> tuple[int, ...]: """ return modality_data.shape[1:-2] if modality_data.ndim > 3 else () - # is naming here confusing if one of these channels can be missing? - def collapse_and_combine_hwtc(self, x: dict[str, Tensor]) -> tuple[Tensor, Tensor]: - """Collapse the tokens and masks, respectively, into two tensors.""" - tokens, masks = [], [] + def collapse_and_combine_hwtc( + self, x: dict[str, Tensor], include_masks: bool = True + ) -> tuple[Tensor, Tensor | None]: + """Collapse the tokens and masks, respectively, into two tensors. + + Args: + x: Dict of per-modality tensors (and optionally their masks). + include_masks: If True, also collapse and return the mask tensors. + When False, only token tensors are required in *x* and the + returned mask will be ``None``. + """ + tokens: list[Tensor] = [] + masks: list[Tensor] = [] available_modalities = return_modalities_from_dict(x) modalities_to_process = get_modalities_to_process( available_modalities, self.supported_modality_names ) for modality in modalities_to_process: - masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name( - modality - ) x_modality = x[modality] - x_modality_mask = x[masked_modality_name] tokens.append(rearrange(x_modality, "b ... d -> b (...) d")) - masks.append(rearrange(x_modality_mask, "b ... -> b (...)")) - tokens = torch.cat(tokens, dim=1) - masks = torch.cat(masks, dim=1) + if include_masks: + masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name( + modality + ) + x_modality_mask = x[masked_modality_name] + masks.append(rearrange(x_modality_mask, "b ... -> b (...)")) + tokens_out = torch.cat(tokens, dim=1) + masks_out = torch.cat(masks, dim=1) if include_masks else None - return tokens, masks + return tokens_out, masks_out @staticmethod def _construct_einops_pattern( @@ -1321,11 +1331,9 @@ def add_removed_tokens( def create_exit_seqs( self, tokens_only_dict: dict[str, Tensor], - mask_only_dict: dict[str, Tensor], token_exit_cfg: dict[str, int] | None, - ) -> tuple[Tensor | None]: + ) -> Tensor | None: """Create the exit sequences and tokens.""" - # Check that tokens_only_dict doesn't contain any mask keys assert all(not key.endswith("_mask") for key in tokens_only_dict), ( "tokens_only_dict should not contain mask keys" ) @@ -1333,9 +1341,9 @@ def create_exit_seqs( exit_ids_per_modality = self.create_token_exit_ids( tokens_only_dict, token_exit_cfg ) - exit_ids_per_modality.update(mask_only_dict) - # Exit ids seqs tells us which layer to exit each token - exit_ids_seq, _ = self.collapse_and_combine_hwtc(exit_ids_per_modality) + exit_ids_seq, _ = self.collapse_and_combine_hwtc( + exit_ids_per_modality, include_masks=False + ) else: exit_ids_seq = None return exit_ids_seq @@ -1469,12 +1477,9 @@ def apply_attn( tokens_only_dict, original_masks_dict, modalities_to_dims_dict = ( self.split_tokens_masks_and_dims(x) ) - # already a no-op but we could remove entirely - exit_ids_seq = self.create_exit_seqs( - tokens_only_dict, original_masks_dict, token_exit_cfg - ) + exit_ids_seq = self.create_exit_seqs(tokens_only_dict, token_exit_cfg) # exited tokens are just the linear projection - exited_tokens, _ = self.collapse_and_combine_hwtc(x) + exited_tokens, _ = self.collapse_and_combine_hwtc(x, include_masks=False) tokens_dict = self.composite_encodings.forward( tokens_only_dict, diff --git a/olmoearth_pretrain/nn/pooled_modality_predictor.py b/olmoearth_pretrain/nn/pooled_modality_predictor.py index b7b47681d..536094cb9 100644 --- a/olmoearth_pretrain/nn/pooled_modality_predictor.py +++ b/olmoearth_pretrain/nn/pooled_modality_predictor.py @@ -507,11 +507,9 @@ def apply_attn( tokens_only_dict, original_masks_dict, pre_pooled_modality_to_dims_dict = ( self.split_tokens_masks_and_dims(x) ) - exit_ids_seq = self.create_exit_seqs( - tokens_only_dict, original_masks_dict, token_exit_cfg - ) + exit_ids_seq = self.create_exit_seqs(tokens_only_dict, token_exit_cfg) # exited tokens are just the linear projection - exited_tokens, _ = self.collapse_and_combine_hwtc(x) + exited_tokens, _ = self.collapse_and_combine_hwtc(x, include_masks=False) tokens_dict = self.composite_encodings.forward( tokens_only_dict, From e8739d0fd129269914f5b8a26b93c4320acd5c9e Mon Sep 17 00:00:00 2001 From: Hadrien Sablon Date: Thu, 9 Apr 2026 10:27:27 -0700 Subject: [PATCH 2/4] minor edits --- olmoearth_pretrain/nn/flexi_vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index 5836191ea..aa43a985e 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -948,8 +948,6 @@ def collapse_and_combine_hwtc( Args: x: Dict of per-modality tensors (and optionally their masks). include_masks: If True, also collapse and return the mask tensors. - When False, only token tensors are required in *x* and the - returned mask will be ``None``. """ tokens: list[Tensor] = [] masks: list[Tensor] = [] @@ -1334,6 +1332,7 @@ def create_exit_seqs( token_exit_cfg: dict[str, int] | None, ) -> Tensor | None: """Create the exit sequences and tokens.""" + # Check that tokens_only_dict doesn't contain any mask keys assert all(not key.endswith("_mask") for key in tokens_only_dict), ( "tokens_only_dict should not contain mask keys" ) @@ -1477,6 +1476,7 @@ def apply_attn( tokens_only_dict, original_masks_dict, modalities_to_dims_dict = ( self.split_tokens_masks_and_dims(x) ) + # already a no-op but we could remove entirely exit_ids_seq = self.create_exit_seqs(tokens_only_dict, token_exit_cfg) # exited tokens are just the linear projection exited_tokens, _ = self.collapse_and_combine_hwtc(x, include_masks=False) From 03c46cb024b9666b71ad5806c5389f5360f9cdf9 Mon Sep 17 00:00:00 2001 From: Hadrien Sablon Date: Thu, 9 Apr 2026 11:27:23 -0700 Subject: [PATCH 3/4] buffer --- olmoearth_pretrain/nn/flexi_vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmoearth_pretrain/nn/flexi_vit.py b/olmoearth_pretrain/nn/flexi_vit.py index aa43a985e..a95b23a6e 100644 --- a/olmoearth_pretrain/nn/flexi_vit.py +++ b/olmoearth_pretrain/nn/flexi_vit.py @@ -649,12 +649,12 @@ def __init__( # 0.25 of the dimension self.embedding_dim_per_embedding_type = int(embedding_size * 0.25) # Position encodings for time dimension initialized to 1D sinusoidal encodings - self.pos_embed = nn.Parameter( + self.register_buffer( + "pos_embed", get_1d_sincos_pos_encoding( torch.arange(max_sequence_length), self.embedding_dim_per_embedding_type, ), - requires_grad=False, ) # Month encodings month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type) From 87b300e54e8cd88a8330998742693ad842928b33 Mon Sep 17 00:00:00 2001 From: Hadrien Sablon Date: Thu, 9 Apr 2026 11:55:35 -0700 Subject: [PATCH 4/4] lint stuff --- tests/unit/nn/test_flexi_vit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/nn/test_flexi_vit.py b/tests/unit/nn/test_flexi_vit.py index ab5a042b3..85aff9928 100644 --- a/tests/unit/nn/test_flexi_vit.py +++ b/tests/unit/nn/test_flexi_vit.py @@ -156,6 +156,7 @@ def test_collapse_and_combine_hwtc(self, flexi_helios_base: FlexiVitBase) -> Non } tokens, masks = flexi_helios_base.collapse_and_combine_hwtc(x) assert tokens.shape == (B, 5, D) + assert masks is not None assert masks.shape == (B, 5) def test_split_and_expand_per_modality(self) -> None: