
[WIP] Added support for temporal segmentation data in encoder decoder factory #355

Draft
wants to merge 1 commit into
base: main

Conversation


@rijuld rijuld commented Jan 9, 2025

No description provided.

Member

Joao-L-S-Almeida commented Jan 9, 2025

Hi @rijuld, maybe you would prefer to convert your PR to a draft and add the prefix [WIP] (work in progress) while you work on it.
When you feel it's ready to be merged, you can convert it back to a PR and add me as a reviewer.

@rijuld rijuld changed the title Added support for temporal segmentation data in encoder decoder factory [DO NOT MERGE] [WIP] Added support for temporal segmentation data in encoder decoder factory Jan 9, 2025
@rijuld rijuld marked this pull request as draft January 9, 2025 16:41
Author

rijuld commented Jan 9, 2025

> Hi @rijuld, maybe you would prefer to convert your PR to a draft and add the prefix [WIP] (work in progress) while you work on it. When you feel it's ready to be merged, you can convert it back to a PR and add me as a reviewer.

@Joao-L-S-Almeida Thanks, Done.

Member

Thank you for your contribution.


def forward(self, x):
# x is a list of tensors, each corresponding to a different timestamp
features = [self.encoder(t) for t in x]
Collaborator


I like your approach, thanks for starting this draft!

Other models in terratorch process data in the format [B, C, T, H, W]. That is also the format in which data is provided by the generic data modules. It might be good to follow this pattern and iterate over dim=2 instead of expecting a list.
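That suggestion could look roughly like the sketch below. The wrapper name matches the PR title, but the body, the mean pooling, and the trivial per-frame encoder are illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.nn as nn

class TemporalWrapper(nn.Module):
    """Sketch: apply a per-frame encoder to [B, C, T, H, W] input by
    iterating over the temporal dimension (dim=2), then mean-pool the
    per-timestamp features."""
    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T, H, W] -> slice each timestamp as [B, C, H, W]
        features = [self.encoder(x[:, :, t]) for t in range(x.shape[2])]
        # stack along a new time dim and average over it
        return torch.stack(features, dim=2).mean(dim=2)

# usage sketch with a trivial stand-in encoder
enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
out = TemporalWrapper(enc)(torch.randn(2, 3, 4, 16, 16))
print(out.shape)  # torch.Size([2, 8, 16, 16])
```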

self.pooling = torch.mean
elif pooling == "max":
self.pooling = torch.max
else:
Collaborator


It would be good to have a concat method which merges the embeddings of all timestamps along the embedding dim, e.g. for testing how much accuracy we lose if the timestamps are averaged before the decoder.
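A minimal sketch of such a merge, assuming per-timestamp feature maps of shape [B, C, H, W] (the function name is made up for illustration):

```python
import torch

def concat_pool(features: list[torch.Tensor]) -> torch.Tensor:
    """Sketch of a 'concat' merge: join per-timestamp feature maps
    [B, C, H, W] along the channel (embedding) dim -> [B, C*T, H, W]."""
    return torch.cat(features, dim=1)

# four timestamps, 8 channels each -> 32 channels after concat
feats = [torch.randn(2, 8, 16, 16) for _ in range(4)]
merged = concat_pool(feats)
print(merged.shape)  # torch.Size([2, 32, 16, 16])
```

Unlike mean or max pooling, this keeps per-timestamp information but multiplies the channel count by T, which the decoder must be built for.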

Collaborator


This requires a fixed number of timestamps defined by the user so that the decoder gets the correct out_channels.

@@ -136,6 +164,10 @@ def build_model(
decoder, channel_list, decoder_kwargs, head_kwargs, num_classes=num_classes
)

# Add temporal wrapper if enabled
if use_temporal:
backbone = TemporalWrapper(backbone, pooling=temporal_pooling)
Collaborator


I would apply the wrapper when building the backbone, i.e. backbone_use_temporal is passed as use_temporal to _get_backbone. The only important thing is that you also save backbone.out_channels in your wrapper as self.out_channels (for concat you have to modify it accordingly).
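The out_channels bookkeeping could be sketched like this. The pooling and n_timestamps parameters and the dummy backbone are assumptions for illustration, not terratorch's actual API:

```python
import torch.nn as nn

class TemporalWrapper(nn.Module):
    """Sketch: preserve the wrapped backbone's out_channels so the decoder
    builder still sees a correct channel list. With 'concat' pooling the
    channels grow with the (fixed) number of timestamps."""
    def __init__(self, backbone: nn.Module, pooling: str = "mean", n_timestamps: int = 1):
        super().__init__()
        self.backbone = backbone
        self.pooling = pooling
        if pooling == "concat":
            # concat merges along the embedding dim, so out_channels scales with T
            self.out_channels = [c * n_timestamps for c in backbone.out_channels]
        else:
            # mean/max pooling keeps the backbone's channel counts unchanged
            self.out_channels = list(backbone.out_channels)

class DummyBackbone(nn.Module):
    out_channels = [64, 128]

print(TemporalWrapper(DummyBackbone(), "mean").out_channels)                   # [64, 128]
print(TemporalWrapper(DummyBackbone(), "concat", n_timestamps=4).out_channels) # [256, 512]
```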

@blumenstiel
Collaborator

@rijuld Thanks for working on this feature, that will be very helpful!
To pass the checks, please remember to sign off your next commits (either use git commit -s, or VS Code/PyCharm have a checkbox to do it automatically).
See https://github.com/IBM/terratorch/pull/355/checks?check_run_id=35365530364
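For reference, a self-contained demo of the sign-off flag in a throwaway repo (the commit message and config values are made up for the demo):

```shell
set -e
# throwaway repo so the demo is self-contained
tmp=$(mktemp -d)
cd "$tmp"
git init -q
git config user.name "Dev"
git config user.email "dev@example.com"
echo "demo" > file.txt
git add file.txt
# -s appends a "Signed-off-by:" trailer taken from your git config
git commit -q -s -m "Add temporal wrapper"
git log -1 --format=%B
```

To add a missing sign-off to an already-made last commit, `git commit --amend -s --no-edit` works as well.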

3 participants