
Commit 2280f33

feat: add embedding net that uses 1D causal convolutions (#1459) (#1499)
* first wavenet inspired causal convolution embedding net
* add same padding and a kernel size sanity check
* changes requested in PR review
* pass dilation scheme as a string
* take activation outside of causalConv1d, and add sanity check for dilation size
1 parent cc6b2cd commit 2280f33

File tree

4 files changed: +325 -1 lines changed

sbi/neural_nets/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -9,7 +9,12 @@


 def __getattr__(name):
-    if name in ["CNNEmbedding", "FCEmbedding", "PermutationInvariantEmbedding"]:
+    if name in [
+        "CausalCNNEmbedding",
+        "CNNEmbedding",
+        "FCEmbedding",
+        "PermutationInvariantEmbedding",
+    ]:
         raise ImportError(
             "As of sbi v0.23.0, you have to import embedding networks from "
             "`sbi.neural_nets.embedding_nets`. For example, use: "

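With this change, `CausalCNNEmbedding` is covered by the same redirect as the existing embedding nets: importing it from the top-level `sbi.neural_nets` raises the ImportError above, and the supported path is the `embedding_nets` subpackage. A minimal sketch of the intended import (not part of the diff):

    from sbi.neural_nets.embedding_nets import CausalCNNEmbedding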
sbi/neural_nets/embedding_nets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+from sbi.neural_nets.embedding_nets.causal_cnn import CausalCNNEmbedding
 from sbi.neural_nets.embedding_nets.cnn import CNNEmbedding
 from sbi.neural_nets.embedding_nets.fully_connected import FCEmbedding
 from sbi.neural_nets.embedding_nets.permutation_invariant import (
sbi/neural_nets/embedding_nets/causal_cnn.py

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import List, Optional, Tuple, Union

from torch import Tensor, nn

from sbi.neural_nets.embedding_nets.cnn import calculate_filter_output_size


def causalConv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    dilation: int = 1,
    stride: int = 1,
) -> nn.Module:
    """Returns a causal convolution by left padding the input

    Args:
        in_channels: number of input channels
        out_channels: number of output channels wanted
        kernel_size: wanted kernel size
        dilation: dilation to use in the convolution.
        stride: stride to use in the convolution.
            Stride and dilation cannot both be > 1.

    Returns:
        An nn.Sequential object that represents a 1D causal convolution.
    """
    assert not (dilation > 1 and stride > 1), (
        "we don't allow combining stride with dilation."
    )
    padding_size = dilation * (kernel_size - 1)
    padding = nn.ZeroPad1d(padding=(padding_size, 0))
    conv_layer = nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        dilation=dilation,
        stride=stride,
        padding=0,
    )
    return nn.Sequential(padding, conv_layer)


def WaveNetSRLikeAggregator(
    in_channels: int,
    num_timepoints: int,
    output_dim: int,
    activation: nn.Module = nn.LeakyReLU(inplace=True),
    kernel_sizes: Optional[List] = None,
    intermediate_channel_sizes: Optional[List] = None,
    stride_sizes: Union[int, List] = 1,
) -> nn.Module:
    """
    Creates a non-causal 1D CNN aggregator based on the WaveNet speech recognition task.

    By default this function creates an aggregator with two CNN layers;
    after every convolution a maxpooling operation halves the number of timepoints.
    The final CNN will have as many channels as the desired output dimension.
    A global average pooling operation is applied in the end. The dimension of the
    output will thus be (batch_size, output_dim, 1) regardless of the input size.

    Args:
        in_channels: number of channels at input.
        num_timepoints: length of the input.
        output_dim: wanted number of features as output.
        activation: activation to apply after the convolution.
        kernel_sizes: (optional) alter the kernel size used and the number of CNN
            layers (through the length of the kernel size vector).
        intermediate_channel_sizes: (optional) alter the intermediate channel sizes
            used, should have length = len(kernel_sizes) - 1.
        stride_sizes: (optional) alter the stride used, either a vector of
            len = len(kernel_sizes) or a single integer, in which case the same
            stride is used in every convolution.

    Returns:
        nn.Module object that contains a sequence of CNN and max_pool layers
        and finally a global average pooling layer.
    """
    aggregator_out_shape = (
        in_channels,
        num_timepoints,
    )
    if kernel_sizes is None:
        kernel_sizes = [
            min(9, aggregator_out_shape[-1]),
            min(5, int(aggregator_out_shape[-1] / 2)),
        ]
    if intermediate_channel_sizes is None:
        intermediate_channel_sizes = [64]
    assert len(intermediate_channel_sizes) == len(kernel_sizes) - 1, (
        "Provided kernel size list should be exactly one element longer "
        "than channel size list."
    )
    intermediate_channel_sizes += [output_dim]
    if isinstance(stride_sizes, List):
        assert len(stride_sizes) == len(kernel_sizes), (
            "Provided stride size list should have the same size as "
            "the kernel size list."
        )
    else:
        stride_sizes = [stride_sizes] * len(kernel_sizes)

    non_causal_layers = []
    for ll in range(len(kernel_sizes)):
        conv_layer = nn.Conv1d(
            in_channels=in_channels if ll == 0 else intermediate_channel_sizes[ll - 1],
            out_channels=intermediate_channel_sizes[ll],
            kernel_size=kernel_sizes[ll],
            stride=stride_sizes[ll],
            padding='same',
        )
        maxpool = nn.MaxPool1d(kernel_size=2 if aggregator_out_shape[-1] > 2 else 1)
        non_causal_layers += [conv_layer, activation, maxpool]
        aggregator_out_shape = (
            intermediate_channel_sizes[ll],
            int(
                calculate_filter_output_size(
                    aggregator_out_shape[-1],
                    (kernel_sizes[ll] - 1) / 2,
                    1,
                    kernel_sizes[ll],
                    stride_sizes[ll],
                )
                / 2
            ),
        )
    aggregator = nn.Sequential(*non_causal_layers, nn.AdaptiveAvgPool1d(1))
    return aggregator


class CausalCNNEmbedding(nn.Module):
    def __init__(
        self,
        input_shape: Tuple,
        in_channels: int = 1,
        out_channels_per_layer: Optional[List] = None,
        dilation: Union[str, List] = "exponential_cyclic",
        num_conv_layers: int = 5,
        activation: nn.Module = nn.LeakyReLU(inplace=True),
        pool_kernel_size: int = 160,
        kernel_size: int = 2,
        aggregator: Optional[nn.Module] = None,
        output_dim: int = 20,
    ):
        """Embedding network that uses 1D causal convolutions.

        This is a simplified version of the architecture introduced for
        the speech recognition task in the WaveNet paper (van den Oord et al., 2016).

        After several dilated causal convolutions (that maintain the dimensionality
        of the input), an aggregator network is used to bring down the dimensionality.
        You can provide an aggregator network that you deem reasonable for your data.
        If you do not provide an aggregator network yourself, a default aggregator
        is used. This default aggregator is based on the WaveNet paper's description
        of their speech recognition task, and uses non-causal convolutions, pooling
        layers, and global average pooling to obtain a final low-dimensional embedding.

        Args:
            input_shape: Dimensionality of the input, e.g. (num_timepoints,),
                currently only 1D is supported.
            in_channels: Number of input channels, default = 1.
            out_channels_per_layer: number of out_channels for each layer, number
                of entries should correspond with num_conv_layers passed below.
                Default = 16 in every convolutional layer.
            dilation: type of dilation to use, either one of "none" (dilation = 1
                in every layer), "exponential" (increase dilation by a factor of 2
                every layer), or "exponential_cyclic" (as exponential, but reset to 1
                after dilation = 2**9), or pass a list with the dilation size per
                layer. By default the cyclic, exponential scheme from WaveNet is used.
            num_conv_layers: the number of causal convolutional layers.
            kernel_size: size of the kernels in the causal convolutional layers.
            activation: activation function to use between convolutions,
                default = LeakyReLU.
            pool_kernel_size: pool size to use for the AvgPool1d operation after
                the causal convolutional layers.
            aggregator: aggregation net that reduces the dimensionality of the data
                to a low-dimensional embedding.
            output_dim: number of output units in the final layer when using
                the default aggregation.
        """

        super(CausalCNNEmbedding, self).__init__()
        assert isinstance(input_shape, Tuple), (
            "input_shape must be a Tuple of size 1, e.g. (timepoints,)."
        )
        assert len(input_shape) == 1, "Currently only 1D causal CNNs are supported."
        self.input_shape = (in_channels, *input_shape)

        total_timepoints = input_shape[0]
        assert total_timepoints >= pool_kernel_size, (
            "Please ensure that the pool kernel size is not "
            "larger than the number of observed timepoints."
        )
        if isinstance(dilation, str):
            match dilation.lower():
                case "exponential_cyclic":
                    max_dil_exp = 10
                    # Use dilation scheme as in the WaveNet paper.
                    dilation_per_layer = [
                        2 ** (i % max_dil_exp) for i in range(num_conv_layers)
                    ]
                case "exponential":
                    dilation_per_layer = [2**i for i in range(num_conv_layers)]
                case "none":
                    dilation_per_layer = [1] * num_conv_layers
                case _:
                    raise ValueError(
                        f"{dilation} is not a valid option, please use \"none\", "
                        "\"exponential\", or \"exponential_cyclic\", or pass a list "
                        "of dilation sizes."
                    )
        else:
            assert isinstance(dilation, List), (
                "Please pass dilation sizes as a list or a string option."
            )
            dilation_per_layer = dilation

        assert max(dilation_per_layer) < total_timepoints, (
            "The maximal dilation size used is larger than the number of "
            "timepoints in your input, please provide a list with smaller dilations."
        )
        if out_channels_per_layer is None:
            out_channels_per_layer = [16] * num_conv_layers

        causal_layers = []
        for ll in range(num_conv_layers):
            causal_layers += [
                causalConv1d(
                    in_channels if ll == 0 else out_channels_per_layer[ll - 1],
                    out_channels_per_layer[ll],
                    kernel_size,
                    dilation_per_layer[ll],
                    1,
                ),
                activation,
            ]

        self.causal_cnns = nn.Sequential(*causal_layers)

        self.pooling_layer = nn.AvgPool1d(kernel_size=pool_kernel_size)

        if aggregator is None:
            aggregator_out_shape = (
                out_channels_per_layer[-1],
                int(total_timepoints / pool_kernel_size),
            )
            assert aggregator_out_shape[-1] > 1, (
                "Your dimensionality is already small, "
                "please ensure a larger input size or use a custom aggregator."
            )
            aggregator = WaveNetSRLikeAggregator(
                aggregator_out_shape[0],
                aggregator_out_shape[-1],
                output_dim=output_dim,
            )
        self.aggregation = aggregator

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.size(0)
        x = x.view(batch_size, *self.input_shape)
        x = self.causal_cnns(x)
        x = self.pooling_layer(x)
        x = self.aggregation(x)
        # ensure flattening when aggregator uses global average pooling
        x = x.view(batch_size, -1)
        return x
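
To make the left-padding scheme in causalConv1d concrete, here is a small sketch, not part of the commit, that checks the two properties the function is built for: the output keeps the input length, and position t never depends on later timepoints. The import path assumes the file above lands at sbi/neural_nets/embedding_nets/causal_cnn.py:

    import torch

    from sbi.neural_nets.embedding_nets.causal_cnn import causalConv1d

    conv = causalConv1d(in_channels=1, out_channels=1, kernel_size=2, dilation=4)
    x = torch.randn(1, 1, 100)  # (batch, channels, timepoints)
    y = conv(x)
    print(y.shape)  # torch.Size([1, 1, 100]): left padding preserves the length

    # Perturbing the last timepoint leaves all earlier outputs unchanged (causality).
    x2 = x.clone()
    x2[..., -1] += 1.0
    y2 = conv(x2)
    print(torch.allclose(y[..., :-1], y2[..., :-1]))  # True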

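The class docstring describes the full pipeline: dilated causal convolutions that preserve the input length, an AvgPool1d step, and the default WaveNet-SR-like aggregator ending in global average pooling, so the embedding always has output_dim features. A shape-check sketch with illustrative parameter values (again not part of the commit):

    import torch

    from sbi.neural_nets.embedding_nets import CausalCNNEmbedding

    embedding = CausalCNNEmbedding(
        input_shape=(320,),   # 320 timepoints
        in_channels=1,
        pool_kernel_size=32,  # AvgPool1d reduces 320 timepoints to 10 before aggregation
        output_dim=20,
    )
    x = torch.randn(8, 320)   # (batch, timepoints); forward() reshapes to (8, 1, 320)
    features = embedding(x)
    print(features.shape)     # torch.Size([8, 20])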
tests/embedding_net_test.py

Lines changed: 47 additions & 0 deletions
@@ -13,6 +13,7 @@
 from sbi.neural_nets import classifier_nn, flowmatching_nn, likelihood_nn, posterior_nn
 from sbi.neural_nets.embedding_nets import (
     CNNEmbedding,
+    CausalCNNEmbedding,
     FCEmbedding,
     PermutationInvariantEmbedding,
 )
@@ -173,6 +174,52 @@ def simulator1d(theta):
     posterior.potential(s)


+@pytest.mark.parametrize("input_shape", [(32,), (64,)])
+@pytest.mark.parametrize("num_channels", (1, 2, 3))
+def test_1d_causal_cnn_embedding_net(input_shape, num_channels):
+    estimator_provider = posterior_nn(
+        "mdn",
+        embedding_net=CausalCNNEmbedding(
+            input_shape, in_channels=num_channels, pool_kernel_size=2, output_dim=20
+        ),
+    )
+
+    num_dim = input_shape[0]
+
+    def simulator2d(theta):
+        x = MultivariateNormal(
+            loc=theta, covariance_matrix=0.5 * torch.eye(num_dim)
+        ).sample()
+        return x.unsqueeze(2).repeat(1, 1, input_shape[1])
+
+    def simulator1d(theta):
+        return torch.rand_like(theta) + theta
+
+    if len(input_shape) == 1:
+        simulator = simulator1d
+        xo = torch.ones(1, num_channels, *input_shape).squeeze(1)
+    else:
+        simulator = simulator2d
+        xo = torch.ones(1, num_channels, *input_shape).squeeze(1)
+
+    prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
+
+    num_simulations = 1000
+    theta = prior.sample(torch.Size((num_simulations,)))
+    x = simulator(theta)
+    if num_channels > 1:
+        x = x.unsqueeze(1).repeat(
+            1, num_channels, *[1 for _ in range(len(input_shape))]
+        )
+
+    trainer = NPE(prior=prior, density_estimator=estimator_provider)
+    trainer.append_simulations(theta, x).train(max_num_epochs=2)
+    posterior = trainer.build_posterior().set_default_x(xo)
+
+    s = posterior.sample((10,))
+    posterior.potential(s)
+
+
 @pytest.mark.slow
 def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50):
     """Test inference accuracy with embeddings for varying number of trials.

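For context, the new test above exercises the embedding end to end through NPE. The same pattern, reduced to the single-channel case, looks roughly like the sketch below; the NPE import path and the toy simulator are assumptions based on how the test is written, not part of this diff:

    import torch
    from torch.distributions import MultivariateNormal

    from sbi.inference import NPE
    from sbi.neural_nets import posterior_nn
    from sbi.neural_nets.embedding_nets import CausalCNNEmbedding

    num_dim = 32
    prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
    estimator_provider = posterior_nn(
        "mdn",
        embedding_net=CausalCNNEmbedding((num_dim,), pool_kernel_size=2, output_dim=20),
    )

    theta = prior.sample(torch.Size((1000,)))
    x = torch.rand_like(theta) + theta  # toy 1D simulator, as in the test

    trainer = NPE(prior=prior, density_estimator=estimator_provider)
    trainer.append_simulations(theta, x).train(max_num_epochs=2)
    posterior = trainer.build_posterior().set_default_x(torch.ones(1, num_dim))
    samples = posterior.sample((10,))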