Description
Is your feature request related to a problem? Please describe.
During training, the summary network is invoked as
summary_metrics, summary_outputs = self._compute_summary_metrics(summary_variables, stage=stage)
so any **kwargs provided to fit() are not forwarded to the summary network's call. As a result, it is not possible (or at least not obvious to me) to pass an attention mask during training.
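For illustration, a minimal sketch of the forwarding I have in mind (summary_kwargs is a hypothetical name for whatever would be collected from fit(**kwargs); this is not the current implementation):

```python
# Hypothetical: forward user-supplied keyword arguments (e.g. an attention
# mask) from fit() down to the summary network's call.
summary_metrics, summary_outputs = self._compute_summary_metrics(
    summary_variables,
    stage=stage,
    **summary_kwargs,  # e.g. {"mask": attention_mask}
)
```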
Additionally, in SetTransformer, the mask is not propagated through the attention stack; it is only passed to the pooling operation:
summary = self.attention_blocks(input_set, training=training)
summary = self.pooling_by_attention(summary, training=training, **kwargs)
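A minimal sketch of the change I would expect there (assuming the attention blocks accept and forward a mask keyword; illustrative only, not the actual implementation):

```python
# Hypothetical: propagate the mask (and any other kwargs) through the
# attention stack as well as the pooling-by-attention step.
summary = self.attention_blocks(input_set, training=training, **kwargs)
summary = self.pooling_by_attention(summary, training=training, **kwargs)
```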
Describe the solution you'd like
- Add first-class support for an attention mask throughout the summary network, during both training and inference (a usage sketch is given after this list).
- Fix SetTransformer **kwargs propagation.
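To make the first point concrete, a hypothetical usage once the mask is supported end to end (summary_network stands for an already constructed SetTransformer; the mask argument name and shapes are assumptions, not the current API):

```python
import numpy as np

# Padded sets of variable size: (batch, max_set_size, num_features)
padded_sets = np.random.normal(size=(64, 20, 3)).astype("float32")

# Boolean mask marking real set elements (True) vs. padding (False)
attention_mask = np.ones((64, 20), dtype=bool)
attention_mask[:, 15:] = False  # last 5 elements of every set are padding

# Desired: the mask reaches every attention layer, both during training ...
summaries = summary_network(padded_sets, training=True, mask=attention_mask)

# ... and at inference time
summaries = summary_network(padded_sets, training=False, mask=attention_mask)
```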
Additional context
This is part of issue #625, where a minimal example is provided (the compositional part can be ignored here).