Add 1D time-series support and automatic input projection to TransformerEmbedding #1703
Conversation
…dd corresponding tests
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

@@           Coverage Diff            @@
##             main    #1703    +/-  ##
========================================
  Coverage        ?   84.67%
========================================
  Files           ?      137
  Lines           ?    11493
  Branches        ?        0
========================================
  Hits            ?     9732
  Misses          ?     1761
  Partials        ?        0

Flags with carried forward coverage won't be shown.
janfb left a comment
Thanks @satwiksps, good first draft!
I suggest making all the config options actual arguments to `__init__`, so the user has full control and we get explicit type hints and defaults.
We could even think about using a dataclass `TransformerConfig`, but that would create additional overhead for the user, who would have to instantiate the config object externally. Not sure; what's your take here?
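For concreteness, a minimal sketch of the two API styles discussed here; all argument names and defaults below are placeholders for illustration, not the actual sbi signature:

```python
from dataclasses import dataclass

import torch.nn as nn


# Option A: every config option becomes an explicit __init__ argument with a
# type hint and a default, giving users full control and IDE discoverability.
class TransformerEmbedding(nn.Module):
    def __init__(
        self,
        feature_space_dim: int = 32,
        num_heads: int = 4,
        num_layers: int = 2,
        is_causal: bool = False,
    ) -> None:
        super().__init__()
        self.feature_space_dim = feature_space_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.is_causal = is_causal


# Option B: bundle the same options in a dataclass; the user then has to
# create and pass the config object explicitly, which adds a small overhead.
@dataclass
class TransformerConfig:
    feature_space_dim: int = 32
    num_heads: int = 4
    num_layers: int = 2
    is_causal: bool = False
```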
janfb left a comment
Looks good!
I added more suggestions for making it look even better in terms of documentation. Would be great if you could add these as well 🙏
super().__init__()
"""
Main class for constructing a transformer embedding
Basic configuration parameters:
Suggested change:
- Basic configuration parameters:
+ Args:
@@ -628,91 +628,123 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:

class TransformerEmbedding(nn.Module):
Would you be up for adding a short explanatory class docstring here?
e.g., for an SBI user working with time series or images but not so familiar with transformers, give a concise overview of how they can use this class: what "vit" means (for images), what "is_causal" means (for time series), etc. Not a tutorial, just a brief high-level explanation, maybe even with a short code Example block.
When we add this docstring at the top class level, it will show up nicely in the Sphinx documentation, e.g., like the EnsemblePosterior here: https://sbi.readthedocs.io/en/latest/reference/_autosummary/sbi.inference.EnsemblePosterior.html#sbi.inference.EnsemblePosterior
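Not part of the PR itself, but as a rough illustration of the kind of class docstring meant here (wording, parameter names, and shapes are placeholders, not the merged version):

```python
import torch
import torch.nn as nn


class TransformerEmbedding(nn.Module):
    """Transformer-based embedding network for SBI.

    Use this class to compress high-dimensional simulator outputs into a
    fixed-size embedding before passing them to the density estimator.

    - Time series: pass inputs of shape `(batch, seq_len, feature_space_dim)`;
      set `is_causal=True` so that each time step only attends to earlier
      time steps.
    - Images: use the ViT ("vision transformer") variant, which splits the
      image into patches and expects `(batch, num_channels, height, width)`.

    Example:
        >>> embedding_net = TransformerEmbedding(feature_space_dim=16)
        >>> x = torch.randn(10, 100, 16)  # (batch, seq_len, features)
        >>> e = embedding_net(x)          # fixed-size embedding per batch item
    """
```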
20f7d47 to 8da9ec0 (Compare)
8da9ec0 to 4fb79ea (Compare)
janfb left a comment
Looks very nice now, thanks!
Just one small comment.
feature_space_dim)` or `(batch, num_channels,
height, width)` if using ViT
These line breaks are not needed; probably two lines are enough.
493d0c8 to b0e65ab (Compare)
janfb left a comment
Awesome, thanks @satwiksps 🙏
Support scalar time-series and simplify config handling in TransformerEmbedding

This PR improves the usability of TransformerEmbedding by adding native support for scalar (1D) time-series inputs and removing the requirement to manually specify a full configuration dictionary.

Problem

- TransformerEmbedding did not support scalar time-series shaped `(batch, seq_len)` or `(batch, seq_len, 1)`.
- Users had to manually specify a full `config` dictionary, which was unintuitive and prone to errors.

What this PR does

- Projects scalar inputs to `feature_space_dim` via a lazily-initialized `input_proj` layer.
- Accepts inputs shaped `(batch, seq_len)`, `(batch, seq_len, 1)`, and `(batch, seq_len, D)` (existing behavior); see the usage sketch at the end of this description.
- Adds `test_transformer_embedding_scalar_timeseries` to verify the new behavior.

Why this is valuable

- Makes TransformerEmbedding usable out of the box.

Testing

- Added tests in `embedding_net_test.py`.
- All checks (`ruff`, formatting, linting) passed successfully.

Checklist
Closes #1696
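For orientation, a hedged sketch of how the new scalar time-series support would be used; the import path and constructor arguments are assumptions based on the description above, not verified against the merged code:

```python
import torch

# Assumed import path; the actual location within sbi may differ.
from sbi.neural_nets.embedding_nets import TransformerEmbedding

# 20 simulated scalar time series with 100 time steps each, no feature axis.
x_scalar = torch.randn(20, 100)  # (batch, seq_len)

embedding_net = TransformerEmbedding(feature_space_dim=16)

# With this PR, the scalar input is projected to feature_space_dim by a
# lazily-initialized input_proj layer, so no manual reshaping is needed.
emb = embedding_net(x_scalar)

# An explicit trailing feature axis of size 1 is handled the same way.
emb_1 = embedding_net(torch.randn(20, 100, 1))  # (batch, seq_len, 1)
```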