
[FEATURE] Support AttentionMask in SetTransformer #626

@arrjon

Description

Is your feature request related to a problem? Please describe.
During training, the summary network is invoked as
summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)
so any **kwargs provided to fit() are not forwarded to the summary network's call(). As a result, it is not possible (or at least not obvious to me) to pass an attention mask during training.

Additionally, in SetTransformer, the mask is not propagated through the attention stack; it is only passed to the pooling operation:

summary = self.attention_blocks(input_set, training=training)
summary = self.pooling_by_attention(summary, training=training, **kwargs)
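A possible fix for this second point, sketched against the two lines above, would be to forward the received keyword arguments (and with them any mask) to the attention stack as well, assuming the attention blocks accept them:

summary = self.attention_blocks(input_set, training=training, **kwargs)
summary = self.pooling_by_attention(summary, training=training, **kwargs)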

Describe the solution you'd like

  1. Add first-class support for an attention mask throughout the summary network, during both training and inference (see the sketch after this list).
  2. Fix SetTransformer kwargs propagation.
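
For illustration only, the following is a minimal, self-contained Keras sketch (not BayesFlow code) of what first-class mask support could look like from the user's side; the class name MaskedSetBlock and the attention_mask parameter are hypothetical and only demonstrate the requested behaviour.

import keras
import numpy as np

class MaskedSetBlock(keras.layers.Layer):
    """Hypothetical self-attention block over a set that respects a per-element validity mask."""

    def __init__(self, embed_dim=32, num_heads=4, **kwargs):
        super().__init__(**kwargs)
        self.projection = keras.layers.Dense(embed_dim)
        self.attention = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

    def call(self, input_set, training=None, attention_mask=None):
        x = self.projection(input_set)
        # Forward the mask into the attention computation so that padded
        # set elements neither attend nor are attended to.
        return self.attention(x, x, attention_mask=attention_mask, training=training)

# Usage: a batch of 2 sets with 5 elements each; the last two elements of the first set are padding.
sets = np.random.normal(size=(2, 5, 3)).astype("float32")
valid = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=bool)
pairwise_mask = valid[:, None, :] & valid[:, :, None]  # shape (2, 5, 5)

summary = MaskedSetBlock()(sets, attention_mask=pairwise_mask)
print(summary.shape)  # (2, 5, 32)

With such a keyword in place, the same mask could then also be forwarded by fit() down to the summary network during training.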

Additional context

This is part of issue #625, where a minimal example is provided (the compositional part of that example can be ignored here).

Labels

feature (New feature or request)
