Cannot get only encoder's output #530

@notprime

Description

Hi,

I'm trying to get the output of the encoder only, since I want to fine-tune it on some Sentinel-1 (S1) images for downstream tasks and build my own pipeline on top of it.

This is a demo snippet I've built to get some outputs from the encoder:

import torch
from olmoearth_pretrain.model_loader import ModelID, load_model_from_id
from olmoearth_pretrain.datatypes import OlmoEarthSample, MaskedOlmoEarthSample

model = load_model_from_id(ModelID.OLMOEARTH_V1_TINY)

# shape for S1: [B, H, W, T, len(S1_bands)]
B, H, W, T, C_s1 = 1, 64, 64, 1, 2
s1 = torch.randn(B, H, W, T, C_s1)
timestamps = torch.tensor([[[15, 5, 2021]]], dtype=torch.float32)  # shape [1, 1, 3]

raw_sample = OlmoEarthSample(
    sentinel1=s1,
    timestamps=timestamps.long(),
)
masked_sample = MaskedOlmoEarthSample.from_olmoearthsample(raw_sample)

encoder_out, *rest = model(masked_sample, patch_size=4)

but I get this strange error:

RuntimeError                              Traceback (most recent call last)
Cell In[7], line 21
     15 raw_sample = OlmoEarthSample(
     16     sentinel1=s1,
     17     timestamps=timestamps.long(),
     18 )
     19 masked_sample = MaskedOlmoEarthSample.from_olmoearthsample(raw_sample)
---> 21 encoder_out, *rest = encoder(masked_sample, patch_size=4)
     23 print(encoder_out)

File c:\Users\tooRi\anaconda3\envs\gfm_env\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File c:\Users\tooRi\anaconda3\envs\gfm_env\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1610, in Encoder.forward(self, x, patch_size, input_res, token_exit_cfg, fast_pass)
   1605 patchified_tokens_and_masks = self.patch_embeddings.forward(x, patch_size)
   1607 if token_exit_cfg is None or any(
   1608     [exit_depth > 0 for exit_depth in token_exit_cfg.values()]
   1609 ):
-> 1610     patchified_tokens_and_masks, token_norm_stats = self.apply_attn(
   1611         x=patchified_tokens_and_masks,
   1612         timestamps=x.timestamps,
   1613         patch_size=patch_size,
   1614         input_res=input_res,
   1615         token_exit_cfg=token_exit_cfg,
   1616         fast_pass=fast_pass,
   1617     )
   1618 else:
   1619     token_norm_stats = {}

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1490, in Encoder.apply_attn(self, x, timestamps, patch_size, input_res, token_exit_cfg, fast_pass)
   1485 tokens_dict.update(original_masks_dict)
   1487 tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)
   1489 tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask = (
-> 1490     self._maybe_remove_masked_tokens(tokens, mask, fast_pass)
   1491 )
   1493 if exit_ids_seq is not None:
   1494     exit_ids_seq, _, _, _, _ = self.remove_masked_tokens(
   1495         exit_ids_seq, bool_mask
   1496     )

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1443, in Encoder._maybe_remove_masked_tokens(self, tokens, mask, fast_pass)
   1440 else:
   1441     bool_mask = mask == MaskValue.ONLINE_ENCODER.value
   1442     tokens, indices, new_mask, seq_lengths, max_seqlen = (
-> 1443         self.remove_masked_tokens(tokens, bool_mask)
   1444     )
   1445 return tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask

File e:\coding\olmoearth_pretrain\olmoearth_pretrain\nn\flexi_vit.py:1264, in Encoder.remove_masked_tokens(x, mask)
   1262 sorted_mask, indices = torch.sort(mask, dim=1, descending=True, stable=True)
   1263 # Now all the places where we want to keep the token are at the front of the tensor
-> 1264 x = x.gather(1, indices[:, :, None].expand_as(x))
   1265 # Now all tokens that should be kept are first in the tensor
   1266
   1267 # set masked values to 0 (not really necessary since we'll ignore them anyway)
   1268 x = x * sorted_mask.unsqueeze(-1)

RuntimeError: The expanded size of the tensor (256) must match the existing size (16) at non-singleton dimension 1.  Target sizes: [1, 256, 192].  Tensor sizes: [64, 16, 1]

I get the same error even if I pass the input through encoder = model.encoder directly. Is my snippet wrong?
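
In case it helps with debugging, here is a minimal plain-PyTorch sketch that reproduces the failing gather using only the shapes reported in the error message; the tensor names and my reading of which tensor is which are assumptions on my part, not taken from the library:

import torch

# Shapes taken from the error message:
# tokens after patch embedding: [B=1, 16*16=256 patches (64x64 image, patch_size=4), 192 channels]
# mask reaching remove_masked_tokens: [64, 16] (assumption: this is the mask)
tokens = torch.randn(1, 256, 192)
mask = torch.ones(64, 16, dtype=torch.long)

sorted_mask, indices = torch.sort(mask, dim=1, descending=True, stable=True)
try:
    # indices[:, :, None] has shape [64, 16, 1] and cannot be expanded to [1, 256, 192]
    tokens = tokens.gather(1, indices[:, :, None].expand_as(tokens))
except RuntimeError as e:
    print(e)  # "The expanded size of the tensor (256) must match the existing size (16) ..."

So it looks like the tokens are [batch, num_patches, embed_dim] while the mask that reaches remove_masked_tokens has a different layout, which is what triggers the mismatch, but I'm not sure whether that comes from how I built the MaskedOlmoEarthSample.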
