Hi,

I'm trying to get the output of the encoder only, as I want to fine-tune it on some S1 images for downstream tasks, and I would like to build my own pipeline on top of it.

This is a demo snippet I've built to get some outputs from the encoder:
```python
import torch

from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, OlmoEarthSample
from olmoearth_pretrain.model_loader import ModelID, load_model_from_id

model = load_model_from_id(ModelID.OLMOEARTH_V1_TINY)

# Expected shape for S1 inputs: [B, H, W, T, len(S1_bands)]
B, H, W, T, C_s1 = 1, 64, 64, 1, 2
s1 = torch.randn(B, H, W, T, C_s1)

# One timestamp per timestep, presumably (day, month, year): shape [B, T, 3]
timestamps = torch.tensor([[[15, 5, 2021]]], dtype=torch.float32)  # shape [1, 1, 3]

raw_sample = OlmoEarthSample(
    sentinel1=s1,
    timestamps=timestamps.long(),
)
masked_sample = MaskedOlmoEarthSample.from_olmoearthsample(raw_sample)

encoder_out, *rest = model(masked_sample, patch_size=4)
```
but I get this strange error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 21
     15 raw_sample = OlmoEarthSample(
     16     sentinel1=s1,
     17     timestamps=timestamps.long(),
     18 )
     19 masked_sample = MaskedOlmoEarthSample.from_olmoearthsample(raw_sample)
---> 21 encoder_out, *rest = encoder(masked_sample, patch_size=4)
     23 print(encoder_out)

File c:\Users\tooRi\anaconda3\envs\gfm_env\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File c:\Users\tooRi\anaconda3\envs\gfm_env\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1610, in Encoder.forward(self, x, patch_size, input_res, token_exit_cfg, fast_pass)
   1605 patchified_tokens_and_masks = self.patch_embeddings.forward(x, patch_size)
   1607 if token_exit_cfg is None or any(
   1608     [exit_depth > 0 for exit_depth in token_exit_cfg.values()]
   1609 ):
-> 1610     patchified_tokens_and_masks, token_norm_stats = self.apply_attn(
   1611         x=patchified_tokens_and_masks,
   1612         timestamps=x.timestamps,
   1613         patch_size=patch_size,
   1614         input_res=input_res,
   1615         token_exit_cfg=token_exit_cfg,
   1616         fast_pass=fast_pass,
   1617     )
   1618 else:
   1619     token_norm_stats = {}

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1490, in Encoder.apply_attn(self, x, timestamps, patch_size, input_res, token_exit_cfg, fast_pass)
   1485 tokens_dict.update(original_masks_dict)
   1487 tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)
   1489 tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask = (
-> 1490     self._maybe_remove_masked_tokens(tokens, mask, fast_pass)
   1491 )
   1493 if exit_ids_seq is not None:
   1494     exit_ids_seq, _, _, _, _ = self.remove_masked_tokens(
   1495         exit_ids_seq, bool_mask
   1496     )

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1443, in Encoder._maybe_remove_masked_tokens(self, tokens, mask, fast_pass)
   1440 else:
   1441     bool_mask = mask == MaskValue.ONLINE_ENCODER.value
   1442 tokens, indices, new_mask, seq_lengths, max_seqlen = (
-> 1443     self.remove_masked_tokens(tokens, bool_mask)
   1444 )
   1445 return tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1264, in Encoder.remove_masked_tokens(x, mask)
   1262 sorted_mask, indices = torch.sort(mask, dim=1, descending=True, stable=True)
   1263 # Now all the places where we want to keep the token are at the front of the tensor
-> 1264 x = x.gather(1, indices[:, :, None].expand_as(x))
   1265 # Now all tokens that should be kept are first in the tensor
   1267 # set masked values to 0 (not really necessary since we'll ignore them anyway)
   1268 x = x * sorted_mask.unsqueeze(-1)

RuntimeError: The expanded size of the tensor (256) must match the existing size (16) at non-singleton dimension 1. Target sizes: [1, 256, 192]. Tensor sizes: [64, 16, 1]
```
I get the same error even if I only pass the input through the encoder directly, via `encoder = model.encoder`. For what it's worth, with `patch_size=4` on a 64x64 input I'd expect (64/4)² = 256 tokens, which matches the 256 in the target size `[1, 256, 192]`, so it's the mask-derived tensor of shape `[64, 16, 1]` that doesn't seem to line up. Is my snippet wrong?
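For completeness, here is the encoder-only variant that produces the identical traceback above (this assumes the encoder module is exposed as `model.encoder`, which is how I accessed it):

```python
# Encoder-only variant: call the Encoder module directly instead of the full
# model. This is the call visible in the traceback and fails the same way.
encoder = model.encoder
encoder_out, *rest = encoder(masked_sample, patch_size=4)
print(encoder_out)
```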