Skip to content

Commit d9ac633

Browse files
famuraMatthijsMatthijspalsmanuelgloeckler
authored
Add LRU-backed embedding networks (#1512)
* Before discussing the new way-to-go * init LRU implementation based on Orvieto et al * state dim fix * rough version * for loop running * add aggregate_func * fix aggregate * start asociative scan * Oscillator simulation for the test * add scan * add scan * bidirectional running but different results loop and scan * Doc improvements * more efficient bidirectional * fix bidirectional (outputs should also be flippedcd sbi * Alarm in the building * Isolated tests improved and running * LRU API changes * add whole pipeline test * Latest doc updates * fix m * Refactoring; By default bidriectional; Added test cases * start scan loop check test * Undo unintended changes * Final touch-up * Update sbi/neural_nets/embedding_nets/lru.py Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> * Fixed type hints * Use Tuple instead of tuple * add warning that scan does not support backwards pass yet * Use lambda_abs property; Add test case * fix test * omit some linear weights of the GLU as in Smith, Warrington & Linderman, 2023 * accidentally swapped nonlinearity and dropout -fix --------- Co-authored-by: Matthijs <matthijs@example.com> Co-authored-by: Matthijspals <matthijs-pals@hotmail.com> Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com> Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com>
1 parent 2280f33 commit d9ac633

File tree

3 files changed

+752
-2
lines changed

3 files changed

+752
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from sbi.neural_nets.embedding_nets.causal_cnn import CausalCNNEmbedding
22
from sbi.neural_nets.embedding_nets.cnn import CNNEmbedding
33
from sbi.neural_nets.embedding_nets.fully_connected import FCEmbedding
4+
from sbi.neural_nets.embedding_nets.lru import LRUEmbedding
45
from sbi.neural_nets.embedding_nets.permutation_invariant import (
56
PermutationInvariantEmbedding,
67
)

0 commit comments

Comments
 (0)