51 changes: 28 additions & 23 deletions olmoearth_pretrain/nn/flexi_vit.py
@@ -649,12 +649,12 @@ def __init__(
         # 0.25 of the dimension
         self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
         # Position encodings for time dimension initialized to 1D sinusoidal encodings
-        self.pos_embed = nn.Parameter(
+        self.register_buffer(
+            "pos_embed",
             get_1d_sincos_pos_encoding(
                 torch.arange(max_sequence_length),
                 self.embedding_dim_per_embedding_type,
             ),
-            requires_grad=False,
         )
         # Month encodings
         month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)

Collaborator (on the `register_buffer` line): I think this won't affect old checkpoints but did you run a check to confirm (i.e. we can still load existing checkpoints with this change)?
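On the checkpoint question: `register_buffer` stores the tensor in the module's `state_dict` under the same `pos_embed` key the old `nn.Parameter` used, so loading existing checkpoints should be unaffected. A minimal standalone sketch of that check (toy module and shapes, not the repo's actual constructor):

```python
import torch
import torch.nn as nn

class OldModule(nn.Module):
    """Old behavior: frozen nn.Parameter named pos_embed."""
    def __init__(self) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(16, 8), requires_grad=False)

class NewModule(nn.Module):
    """New behavior: a buffer registered under the same name."""
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("pos_embed", torch.zeros(16, 8))

old_state = OldModule().state_dict()  # contains the key "pos_embed"
new_model = NewModule()
# Same key, same shape: the old checkpoint loads into the buffer cleanly.
new_model.load_state_dict(old_state)
```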
@@ -940,26 +940,34 @@ def grab_modality_specific_dims(modality_data: Tensor) -> tuple[int, ...]:
         """
         return modality_data.shape[1:-2] if modality_data.ndim > 3 else ()
 
-    # is naming here confusing if one of these channels can be missing?
-    def collapse_and_combine_hwtc(self, x: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
-        """Collapse the tokens and masks, respectively, into two tensors."""
-        tokens, masks = [], []
+    def collapse_and_combine_hwtc(
+        self, x: dict[str, Tensor], include_masks: bool = True
+    ) -> tuple[Tensor, Tensor | None]:
+        """Collapse the tokens and masks, respectively, into two tensors.
+
+        Args:
+            x: Dict of per-modality tensors (and optionally their masks).
+            include_masks: If True, also collapse and return the mask tensors.
+        """
+        tokens: list[Tensor] = []
+        masks: list[Tensor] = []
         available_modalities = return_modalities_from_dict(x)
         modalities_to_process = get_modalities_to_process(
             available_modalities, self.supported_modality_names
         )
         for modality in modalities_to_process:
-            masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
-                modality
-            )
             x_modality = x[modality]
-            x_modality_mask = x[masked_modality_name]
             tokens.append(rearrange(x_modality, "b ... d -> b (...) d"))
-            masks.append(rearrange(x_modality_mask, "b ... -> b (...)"))
-        tokens = torch.cat(tokens, dim=1)
-        masks = torch.cat(masks, dim=1)
+            if include_masks:
+                masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
+                    modality
+                )
+                x_modality_mask = x[masked_modality_name]
+                masks.append(rearrange(x_modality_mask, "b ... -> b (...)"))
+        tokens_out = torch.cat(tokens, dim=1)
+        masks_out = torch.cat(masks, dim=1) if include_masks else None
 
-        return tokens, masks
+        return tokens_out, masks_out
 
     @staticmethod
     def _construct_einops_pattern(
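For readers skimming the diff: the collapse is an einops flatten per modality followed by concatenation along the token axis. A self-contained sketch of the pattern with toy shapes (modality names invented for illustration):

```python
import torch
from einops import rearrange

# Two toy modalities with different grid shapes but a shared embedding dim;
# "b ... d -> b (...) d" flattens everything between the batch and embedding
# axes into a single token axis.
x = {
    "mod_a": torch.randn(2, 4, 4, 3, 8),  # b=2, h=4, w=4, t=3, d=8
    "mod_b": torch.randn(2, 6, 8),        # b=2, t=6, d=8
}
tokens = torch.cat(
    [rearrange(v, "b ... d -> b (...) d") for v in x.values()], dim=1
)
print(tokens.shape)  # torch.Size([2, 54, 8]); 4*4*3 + 6 = 54 tokens
```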
@@ -1321,9 +1329,8 @@ def add_removed_tokens(
     def create_exit_seqs(
         self,
         tokens_only_dict: dict[str, Tensor],
-        mask_only_dict: dict[str, Tensor],
         token_exit_cfg: dict[str, int] | None,
-    ) -> tuple[Tensor | None]:
+    ) -> Tensor | None:
         """Create the exit sequences and tokens."""
         # Check that tokens_only_dict doesn't contain any mask keys
         assert all(not key.endswith("_mask") for key in tokens_only_dict), (
@@ -1333,9 +1340,9 @@ def create_exit_seqs(
             exit_ids_per_modality = self.create_token_exit_ids(
                 tokens_only_dict, token_exit_cfg
             )
-            exit_ids_per_modality.update(mask_only_dict)
             # Exit ids seqs tells us which layer to exit each token
-            exit_ids_seq, _ = self.collapse_and_combine_hwtc(exit_ids_per_modality)
+            exit_ids_seq, _ = self.collapse_and_combine_hwtc(
+                exit_ids_per_modality, include_masks=False
+            )
         else:
             exit_ids_seq = None
         return exit_ids_seq
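Note on why `mask_only_dict` can be dropped: in the old code it was merged into `exit_ids_per_modality` only so that `collapse_and_combine_hwtc` could find a `*_mask` key for each modality. With `include_masks=False` the collapse never touches mask keys, so both the merge and the parameter were dead weight.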
@@ -1470,11 +1477,9 @@ def apply_attn(
             self.split_tokens_masks_and_dims(x)
         )
         # already a no-op but we could remove entirely
-        exit_ids_seq = self.create_exit_seqs(
-            tokens_only_dict, original_masks_dict, token_exit_cfg
-        )
+        exit_ids_seq = self.create_exit_seqs(tokens_only_dict, token_exit_cfg)
         # exited tokens are just the linear projection
-        exited_tokens, _ = self.collapse_and_combine_hwtc(x)
+        exited_tokens, _ = self.collapse_and_combine_hwtc(x, include_masks=False)
 
         tokens_dict = self.composite_encodings.forward(
             tokens_only_dict,
6 changes: 2 additions & 4 deletions olmoearth_pretrain/nn/pooled_modality_predictor.py
@@ -507,11 +507,9 @@ def apply_attn(
         tokens_only_dict, original_masks_dict, pre_pooled_modality_to_dims_dict = (
             self.split_tokens_masks_and_dims(x)
         )
-        exit_ids_seq = self.create_exit_seqs(
-            tokens_only_dict, original_masks_dict, token_exit_cfg
-        )
+        exit_ids_seq = self.create_exit_seqs(tokens_only_dict, token_exit_cfg)
         # exited tokens are just the linear projection
-        exited_tokens, _ = self.collapse_and_combine_hwtc(x)
+        exited_tokens, _ = self.collapse_and_combine_hwtc(x, include_masks=False)
 
         tokens_dict = self.composite_encodings.forward(
             tokens_only_dict,
1 change: 1 addition & 0 deletions tests/unit/nn/test_flexi_vit.py
@@ -156,6 +156,7 @@ def test_collapse_and_combine_hwtc(self, flexi_helios_base: FlexiVitBase) -> None:
         }
         tokens, masks = flexi_helios_base.collapse_and_combine_hwtc(x)
         assert tokens.shape == (B, 5, D)
+        assert masks is not None
         assert masks.shape == (B, 5)
 
     def test_split_and_expand_per_modality(self) -> None:
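The added assert likely does double duty: since `collapse_and_combine_hwtc` now returns `tuple[Tensor, Tensor | None]`, it both documents the expectation and narrows `masks` to `Tensor` for static type checkers before `.shape` is accessed. A generic illustration of the narrowing (not repo code):

```python
from torch import Tensor

def check(masks: Tensor | None) -> None:
    assert masks is not None  # narrows Tensor | None to Tensor for mypy/pyright
    print(masks.shape)        # safe: masks is known to be a Tensor here
```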