Skip to content

Commit 7b926ff

Browse files
StefanWahljanfbmanuelgloeckler
authored
Add ResNet as embedding model (#1472)
* Add ResNet embedding model * Minor * Add reference * Update init * Bug fix for one channel input * Update init * Add test for ResNet embedding model for image like data * Update documentation and add 1D Resnet * Update tests * Apply suggestions from code review Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> * Update sbi/neural_nets/embedding_nets/resnet.py * avoid nasty merge conflict * formatting --------- Co-authored-by: Jan <janfb@users.noreply.github.com> Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Co-authored-by: manuelgloeckler <manu.gloeckler@hotmail.de>
1 parent dd9765b commit 7b926ff

File tree

4 files changed

+632
-0
lines changed

4 files changed

+632
-0
lines changed

sbi/neural_nets/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010

1111
def __getattr__(name):
1212
if name in [
13+
"CNNEmbedding",
14+
"FCEmbedding",
15+
"PermutationInvariantEmbedding",
16+
"ResNetEmbedding1D",
17+
"ResNetEmbedding2D",
1318
"CausalCNNEmbedding",
1419
"CNNEmbedding",
1520
"FCEmbedding",

sbi/neural_nets/embedding_nets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@
55
from sbi.neural_nets.embedding_nets.permutation_invariant import (
66
PermutationInvariantEmbedding,
77
)
8+
from sbi.neural_nets.embedding_nets.resnet import (
9+
ResNetEmbedding1D,
10+
ResNetEmbedding2D,
11+
)

0 commit comments

Comments
 (0)