
Add 1D time-series support and automatic input projection to TransformerEmbedding #1703

Merged
janfb merged 4 commits into sbi-dev:main from satwiksps:feat/transformerembedding-1d-support
Nov 30, 2025

Conversation

@satwiksps (Contributor)

Support scalar time-series and simplify config handling in TransformerEmbedding

This PR improves the usability of TransformerEmbedding by adding native support for scalar (1D) time-series inputs and removing the requirement to manually specify a full configuration dictionary.

Problem

TransformerEmbedding only accepted inputs of shape (batch, seq_len, feature_space_dim) and required users to assemble a full configuration dictionary by hand, so scalar (1D) time-series could not be embedded out of the box.

What this PR does

  • Adds automatic projection from scalar inputs to the required feature_space_dim via a lazily-initialized input_proj layer.
  • Extends the forward pass to handle the following input shapes (see the sketch after this list):
    • (batch, seq_len)
    • (batch, seq_len, 1)
    • (batch, seq_len, D) (existing behavior)
  • Ensures compatibility with the attention mechanism (e.g., head dimension constraints).
  • Adds new test test_transformer_embedding_scalar_timeseries to verify:
    • Correct handling of 1D inputs
    • Proper projection behavior
    • Successful integration into embedding API
  • Updates docstrings to reflect new behavior.
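
For illustration, here is a minimal sketch of the shape handling and lazy projection described above. The class name, default dimension, and the use of `nn.Linear` are assumptions for illustration, not the actual sbi implementation:

```python
import torch
import torch.nn as nn


class ShapeHandlingSketch(nn.Module):
    """Hypothetical sketch of the 1D-input handling added in this PR."""

    def __init__(self, feature_space_dim: int = 16):
        super().__init__()
        self.feature_space_dim = feature_space_dim
        self.input_proj = None  # lazily initialized on the first forward pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch, seq_len, 1)
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        # (batch, seq_len, d) with d != feature_space_dim: project up
        if x.shape[-1] != self.feature_space_dim:
            if self.input_proj is None:
                self.input_proj = nn.Linear(x.shape[-1], self.feature_space_dim)
            x = self.input_proj(x)
        # (batch, seq_len, feature_space_dim), as the attention blocks expect
        return x
```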

Why this is valuable

  • Makes the transformer embedding consistent with other embedding nets (RNN, CNN, LRU) that already accept scalar time-series.
  • Eliminates user friction by making TransformerEmbedding usable out of the box.
  • Enables simplified tutorials, workflows, and example code for time-series inference.
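
For example, something like the following should now work directly. The import path is inferred from the PR's file listing and the constructor defaults are assumed, so treat this as a sketch rather than the exact API:

```python
import torch

from sbi.neural_nets.embedding_nets import TransformerEmbedding  # path assumed

x = torch.randn(32, 100)      # (batch, seq_len): a scalar time series
net = TransformerEmbedding()  # default arguments assumed for illustration
embedding = net(x)            # no manual reshaping or config dict required
```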

Testing

  • New test added in embedding_net_test.py.
  • All transformer-specific tests pass.
  • Pre-commit (ruff, formatting, linting) passed successfully.

Checklist

  • Code changes follow project style.
  • Added targeted tests.
  • Updated docstrings.
  • Pre-commit hooks passed.
  • No breaking changes introduced.

Closes #1696

codecov bot commented Nov 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@2c216c2).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1703   +/-   ##
=======================================
  Coverage        ?   84.67%           
=======================================
  Files           ?      137           
  Lines           ?    11493           
  Branches        ?        0           
=======================================
  Hits            ?     9732           
  Misses          ?     1761           
  Partials        ?        0           
Flag Coverage Δ
unittests 84.67% <100.00%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/neural_nets/embedding_nets/transformer.py 93.40% <100.00%> (ø)

@janfb (Contributor) left a comment

Thanks @satwiksps, good first draft!

I suggest making all the config options actual arguments to __init__, to give the user full control and to have explicit type hints and defaults.

We could even think about using a dataclass TransformerConfig, but that would add overhead for the user, who would have to instantiate the config object externally; I'm not sure. What's your take here?
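
To make the trade-off concrete, a rough sketch of the two options under discussion (parameter names are purely illustrative, not the actual signature):

```python
from dataclasses import dataclass

import torch.nn as nn


# Option A: config options as explicit __init__ arguments
# with type hints and defaults. (Names illustrative.)
class TransformerEmbeddingA(nn.Module):
    def __init__(
        self,
        feature_space_dim: int = 16,
        num_heads: int = 4,
        is_causal: bool = True,
    ):
        super().__init__()
        ...


# Option B: a config dataclass the user instantiates externally.
@dataclass
class TransformerConfig:
    feature_space_dim: int = 16
    num_heads: int = 4
    is_causal: bool = True


class TransformerEmbeddingB(nn.Module):
    def __init__(self, config: TransformerConfig | None = None):
        super().__init__()
        self.config = config or TransformerConfig()
```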

@janfb (Contributor) left a comment

Looks good!

Added more suggestions for making it look even better in terms of documentation. Would be great if you could add these as well 🙏

super().__init__()
"""
Main class for constructing a transformer embedding
Basic configuration parameters:

Suggested change:
- Basic configuration parameters:
+ Args:

@@ -628,91 +628,123 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:


class TransformerEmbedding(nn.Module):

Would you be up for adding a short explanatory class docstring here?

E.g., for an sbi user working with time series or images but not so familiar with transformers, give a concise overview of how they can use this class: what "vit" means (for images), what "is_causal" means (for time series), etc. Not a tutorial, just a brief high-level explanation. Maybe even with a short code Example block.

When we add this docstring at the top class level, it will show up nicely in the Sphinx documentation, e.g., like the EnsemblePosterior here: https://sbi.readthedocs.io/en/latest/reference/_autosummary/sbi.inference.EnsemblePosterior.html#sbi.inference.EnsemblePosterior
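
For instance, something along these lines; the wording and parameter names are only illustrative of the kind of docstring being requested:

```python
import torch.nn as nn


class TransformerEmbedding(nn.Module):
    """Transformer-based embedding net for time-series or image data.

    For time series, ``is_causal=True`` uses causal attention, so each
    position only attends to earlier ones. For images, the "vit" variant
    uses a Vision Transformer that splits the input into patches.
    (Parameter names above are illustrative.)

    Example:
        >>> net = TransformerEmbedding()
        >>> x = torch.randn(32, 100)  # scalar time series
        >>> embedding = net(x)
    """
```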

@satwiksps force-pushed the feat/transformerembedding-1d-support branch 2 times, most recently from 20f7d47 to 8da9ec0 on November 20, 2025 20:18
@satwiksps force-pushed the feat/transformerembedding-1d-support branch from 8da9ec0 to 4fb79ea on November 20, 2025 20:24
@janfb (Contributor) left a comment

Looks very nice now, thanks!

Just one small comment.

Comment on lines 911 to 912
feature_space_dim)` or `(batch, num_channels,
height, width)` if using ViT

These line breaks are not needed; probably two lines are enough.

@satwiksps force-pushed the feat/transformerembedding-1d-support branch from 493d0c8 to b0e65ab on November 21, 2025 16:12
@satwiksps requested a review from janfb on November 29, 2025 15:08
@janfb (Contributor) left a comment

Awesome, thanks @satwiksps 🙏

@janfb merged commit 3316888 into sbi-dev:main on Nov 30, 2025
9 checks passed
@janfb mentioned this pull request on Jan 9, 2026