
Attention Refactor (WIP) #627

Open

stefanradev93 wants to merge 8 commits into dev from attention_refactor

Conversation

stefanradev93 (Contributor) commented Jan 29, 2026

This PR refactors the transformers module for internal consistency and directly exposes attention_mask and use_causal_mask in the relevant transformers (see the sketch below). The following changes were made:

  • Transformer building blocks were moved into a dedicated attention module.
  • An abstract base class Transformer was added to easily distinguish transformer summaries from other summary networks.
  • Files were renamed to reflect their semantics.
  • TimeSeriesTransformer can now act as a many-to-many network (e.g., for modeling time-varying targets).

It also prepares the ground for addressing #626.
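
To make the new interface concrete, here is a minimal sketch of how the exposed arguments might be used; the import path, the summary_dim argument, and the exact call signature are assumptions for illustration, not the final API:

```python
import numpy as np
from bayesflow.networks import TimeSeriesTransformer  # assumed import path

# Toy batch: 16 series of length 50 with 3 features each
x = np.random.normal(size=(16, 50, 3)).astype("float32")

# Boolean mask over time steps: True = valid, False = padding
attention_mask = np.ones((16, 50), dtype=bool)
attention_mask[:, 40:] = False  # the last 10 steps are padding

summary_net = TimeSeriesTransformer(summary_dim=8)  # constructor args assumed

# use_causal_mask restricts each position to attend only to earlier ones,
# which is what enables many-to-many use (e.g., time-varying targets);
# attention_mask excludes the padded steps from attention.
summary = summary_net(x, attention_mask=attention_mask, use_causal_mask=True)
```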

@arrjon @paul-buerkner It remains to decide how we want the attention mask to be passed. Should we search for it in the simulator outputs (as we do for other special arguments, such as sample_weights)? This would have the advantage that the mask could be constructed very flexibly.
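
For concreteness, a sketch of the route under discussion: the simulator simply returns the mask under a reserved key, so it can be built with arbitrary, simulator-specific logic (the key name and shapes below are illustrative):

```python
import numpy as np

def simulator(batch_size=16, max_len=50):
    # Draw a different valid length for each series in the batch
    lengths = np.random.randint(10, max_len + 1, size=batch_size)
    x = np.random.normal(size=(batch_size, max_len, 3)).astype("float32")

    # Boolean mask of shape (batch, time): True marks valid steps
    mask = np.arange(max_len)[None, :] < lengths[:, None]
    x[~mask] = 0.0  # zero out padded steps

    # If the approximator searches the outputs for the reserved key
    # "attention_mask" (as for "sample_weights"), this is all the user does:
    return {"x": x, "attention_mask": mask}
```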

stefanradev93 requested a review from arrjon on January 29, 2026, 23:40
codecov bot commented Jan 30, 2026

paul-buerkner (Contributor) commented Jan 30, 2026

Nice! Yes, I would love to search for attention_mask in the simulator output! In fact, I cannot really think of an alternative that would be nearly as nice. Is there any disadvantage to doing so?

stefanradev93 (Contributor, Author) commented Jan 30, 2026

None that I can think of, except for some overhead on the user's side, which may be unavoidable anyway. @arrjon, if you concur, I will proceed with adding the search for the attention mask among the simulation outputs.

arrjon (Member) commented Jan 30, 2026

Sounds good to me as it gives maximal flexibility to the user!

arrjon (Member) left a review comment

LGTM

stefanradev93 (Contributor, Author) commented Jan 31, 2026

The approximator's compute_metrics now accepts mask and attention_mask keyword arguments, which are expected to be part of the simulator output (or None, the default) and are propagated to both the inference and the summary network, depending on the signature of each network.
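
One way to read "depending on the signature of each network" is signature-based filtering; a minimal sketch, assuming a hypothetical helper _filter_kwargs (this is not the PR's actual code):

```python
import inspect

def _filter_kwargs(fn, **candidates):
    """Keep only the keyword arguments that fn's signature declares."""
    params = inspect.signature(fn).parameters
    return {k: v for k, v in candidates.items() if k in params and v is not None}

# Sketch of use inside compute_metrics: each network receives a mask
# only if its call signature declares the corresponding parameter.
# summary_kwargs = _filter_kwargs(self.summary_network.call, attention_mask=attention_mask)
# inference_kwargs = _filter_kwargs(self.inference_network.call, mask=mask)
```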

The case for sampling is tricky, though, as the user now has to provide the attention_mask or mask arguments as **kwargs instead of passing them as part of the conditions dict. This may cause some confusion, but expecting the masks to be part of conditions would invalidate our internal logic...
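
Under this design, a sampling call would look roughly like the following sketch (approximator, x, and the exact sample signature are illustrative):

```python
# Masks travel as explicit keyword arguments rather than inside `conditions`,
# because entries of `conditions` are treated as data to condition on.
samples = approximator.sample(
    num_samples=100,
    conditions={"x": x},            # conditioning data only
    attention_mask=attention_mask,  # forwarded via **kwargs to mask-aware nets
)
```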

paul-buerkner (Contributor) commented:

Can you elaborate on why we cannot use conditions for this purpose?

stefanradev93 (Contributor, Author) commented:

I can solve it with a bunch of checks for now.

paul-buerkner (Contributor) commented Feb 3, 2026 via email

stefanradev93 (Contributor, Author) commented:

@arrjon Can you please check if my latest commit enables the functionality you needed?

