Implementation of Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (Lee et al., 2019) in JAX and Equinox.
Uses MNIST (loaded with torch) converted into point clouds.
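As a rough illustration of the image-to-point-cloud step, the sketch below samples a fixed number of (x, y) coordinates from the bright pixels of a single image. It is not taken from this repository: the function name `image_to_point_cloud` and the parameters `num_points` and `threshold` are hypothetical, and the actual conversion may use a different threshold, sampling scheme, or coordinate scaling.

```python
import numpy as np


def image_to_point_cloud(image, num_points=100, threshold=0.5, rng=None):
    """Sample a fixed-size set of (x, y) points from the bright pixels of one image.

    Hypothetical helper for illustration; not the repository's own function.
    """
    rng = np.random.default_rng() if rng is None else rng
    image = np.asarray(image, dtype=np.float32)
    # Assumption: inputs are either already in [0, 1] or raw 0..255 intensities.
    if image.max() > 1.0:
        image = image / 255.0

    # Row/column indices of every pixel brighter than the threshold.
    rows, cols = np.nonzero(image > threshold)
    if rows.size == 0:
        raise ValueError("image has no pixels above the threshold")

    # Sample with replacement only when there are fewer foreground pixels
    # than requested points, so every set has the same cardinality.
    idx = rng.choice(rows.size, size=num_points, replace=rows.size < num_points)

    height, width = image.shape
    # Stack as (x, y) and normalise coordinates to [0, 1].
    points = np.stack(
        [cols[idx] / (width - 1), rows[idx] / (height - 1)], axis=-1
    )
    return points.astype(np.float32)


# Usage on a toy 28x28 image containing a single vertical stroke.
toy = np.zeros((28, 28), dtype=np.float32)
toy[4:24, 13:15] = 1.0
cloud = image_to_point_cloud(toy, num_points=64, rng=np.random.default_rng(0))
print(cloud.shape)  # (64, 2)
```

Because every set has the same size here, the resulting arrays can be stacked into a batch and handed to JAX with `jnp.asarray`. Sets of varying size would need padding and masking, which is what the mixed-cardinality dataloader item below refers to.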
To do:
- ISAB blocks
- AdaNorm layer normalisation
- Dataloader for mixed-cardinality sets
- Dropout
- LR schedule
