|
| 1 | +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed |
| 2 | +# under the Apache License Version 2.0, see <https://www.apache.org/licenses/> |
| 3 | + |
| 4 | +from typing import List, Optional, Tuple, Union |
| 5 | + |
| 6 | +from torch import Tensor, nn |
| 7 | + |
| 8 | +from sbi.neural_nets.embedding_nets.cnn import calculate_filter_output_size |
| 9 | + |
| 10 | + |
| 11 | +def causalConv1d( |
| 12 | + in_channels: int, |
| 13 | + out_channels: int, |
| 14 | + kernel_size: int, |
| 15 | + dilation: int = 1, |
| 16 | + stride: int = 1, |
| 17 | +) -> nn.Module: |
| 18 | + """Returns a causal convolution by left padding the input |
| 19 | +
|
| 20 | + Args: |
| 21 | + in_channels: number of input channels |
| 22 | + out_channels: number of output channels wanted |
| 23 | + kernel_size: wanted kernel size |
| 24 | + dilation: dilation to use in the convolution. |
| 25 | + stride: stride to use in the convolution. |
| 26 | + Stride and dilation cannot both be > 1. |
| 27 | +
|
| 28 | + Returns: |
| 29 | + An nn.Sequential object that represents a 1D causal convolution. |
| 30 | + """ |
| 31 | + assert not (dilation > 1 and stride > 1), ( |
| 32 | + "we don't allow combining stride with dilation." |
| 33 | + ) |
| 34 | + padding_size = dilation * (kernel_size - 1) |
| 35 | + padding = nn.ZeroPad1d(padding=(padding_size, 0)) |
| 36 | + conv_layer = nn.Conv1d( |
| 37 | + in_channels=in_channels, |
| 38 | + out_channels=out_channels, |
| 39 | + kernel_size=kernel_size, |
| 40 | + dilation=dilation, |
| 41 | + stride=stride, |
| 42 | + padding=0, |
| 43 | + ) |
| 44 | + return nn.Sequential(padding, conv_layer) |
| 45 | + |
| 46 | + |
def WaveNetSRLikeAggregator(
    in_channels: int,
    num_timepoints: int,
    output_dim: int,
    activation: nn.Module = nn.LeakyReLU(inplace=True),
    kernel_sizes: Optional[List] = None,
    intermediate_channel_sizes: Optional[List] = None,
    stride_sizes: Union[int, List] = 1,
) -> nn.Module:
    """
    Creates a non-causal 1D CNN aggregator based on the WaveNet speech recognition
    task.

    By default this function creates an aggregator with two CNN layers,
    after every convolution a maxpooling operation halves the number of timepoints.
    The final CNN will have as many channels as the desired output dimension.
    A global average pooling operation is applied in the end. The dimension of the
    output will thus be (batch_size, output_dim, 1) regardless of the input size.

    Args:
        in_channels: number of channels at input.
        num_timepoints: length of the input.
        output_dim: wanted number of features as output.
        activation: activation to apply after the convolution.
        kernel_sizes: (optional) alter the kernel size used and the number of CNN
            layers (through the length of the kernel size vector).
        intermediate_channel_sizes: (optional) alter the intermediate channel sizes
            used, should have length = len(kernel_sizes) - 1. The provided list is
            not modified.
        stride_sizes: (optional) alter the stride used, either a vector of
            len = len(kernel_sizes) or a single integer, in which case the same
            stride is used in every convolution.

    Returns:
        nn.Module object that contains a sequence of CNN and max_pool layer
        and finally a global average pooling layer.
    """
    # Track the (channels, timepoints) shape as layers are stacked.
    aggregator_out_shape = (in_channels, num_timepoints)
    if kernel_sizes is None:
        kernel_sizes = [
            min(9, aggregator_out_shape[-1]),
            min(5, int(aggregator_out_shape[-1] / 2)),
        ]
    if intermediate_channel_sizes is None:
        intermediate_channel_sizes = [64]
    assert len(intermediate_channel_sizes) == len(kernel_sizes) - 1, (
        "Provided kernel size list should be exactly one element longer "
        "than channel size list."
    )
    # Copy before appending so a caller-provided list is never mutated.
    channel_sizes = list(intermediate_channel_sizes) + [output_dim]
    if isinstance(stride_sizes, list):
        assert len(stride_sizes) == len(kernel_sizes), (
            "Provided stride size list should have the same size as "
            "the kernel size list."
        )
    else:
        stride_sizes = [stride_sizes] * len(kernel_sizes)

    non_causal_layers = []
    for ll in range(len(kernel_sizes)):
        conv_layer = nn.Conv1d(
            in_channels=in_channels if ll == 0 else channel_sizes[ll - 1],
            out_channels=channel_sizes[ll],
            kernel_size=kernel_sizes[ll],
            stride=stride_sizes[ll],
            # NOTE(review): 'same' padding is only supported by nn.Conv1d for
            # stride 1; a stride > 1 here will raise at construction time.
            padding='same',
        )
        # Only halve the sequence while there is something left to halve.
        maxpool = nn.MaxPool1d(kernel_size=2 if aggregator_out_shape[-1] > 2 else 1)
        non_causal_layers += [conv_layer, activation, maxpool]
        aggregator_out_shape = (
            channel_sizes[ll],
            int(
                calculate_filter_output_size(
                    aggregator_out_shape[-1],
                    (kernel_sizes[ll] - 1) / 2,
                    1,
                    kernel_sizes[ll],
                    stride_sizes[ll],
                )
                / 2  # account for the maxpool halving
            ),
        )
    # Global average pooling yields a fixed (batch, output_dim, 1) output.
    return nn.Sequential(*non_causal_layers, nn.AdaptiveAvgPool1d(1))
| 134 | + |
| 135 | + |
| 136 | +class CausalCNNEmbedding(nn.Module): |
| 137 | + def __init__( |
| 138 | + self, |
| 139 | + input_shape: Tuple, |
| 140 | + in_channels: int = 1, |
| 141 | + out_channels_per_layer: Optional[List] = None, |
| 142 | + dilation: Union[str, List] = "exponential_cyclic", |
| 143 | + num_conv_layers: int = 5, |
| 144 | + activation: nn.Module = nn.LeakyReLU(inplace=True), |
| 145 | + pool_kernel_size: int = 160, |
| 146 | + kernel_size: int = 2, |
| 147 | + aggregator: Optional[nn.Module] = None, |
| 148 | + output_dim: int = 20, |
| 149 | + ): |
| 150 | + """Embedding network that uses 1D causal convolutions |
| 151 | +
|
| 152 | + This is a simplified version of the architecture introduced for |
| 153 | + the speech recognition task in the WaveNet paper (van den Oord, et al. (2016)) |
| 154 | +
|
| 155 | + After several dilated causal convolutions (that maintain the dimensionality |
| 156 | + of the input), an aggregator network is used to bring down the dimensionality. |
| 157 | + You can provide an aggregator network that you deem reasonable for your data. |
| 158 | + If you do not provide an aggregator network yourself, a default aggregator |
| 159 | + is used. This default aggregator is based on the WaveNet paper's description |
| 160 | + of their Speech Recognition Task, and uses non-causal convolutions and pooling |
| 161 | + layers, and global average poolingg to obtain a final low dimensional embedding. |
| 162 | +
|
| 163 | + Args: |
| 164 | + input_shape: Dimensionality of the input e.g. (num_timepoints,), |
| 165 | + currently only 1D is supported. |
| 166 | + in_channels: Number of input channels, default = 1. |
| 167 | + out_channels_per_layer: number of out_channels for each layer, number |
| 168 | + of entries should correspond with num_conv_layers passed below. |
| 169 | + Default = 16 in every convolutional layer. |
| 170 | + dilation: type of dilation to use either one of "none" (dilation = 1 |
| 171 | + in every layer), "exponential" (increase dilation by a factor of 2 |
| 172 | + every layer), "exponential_cyclic" (as exponential, but reset to 1 |
| 173 | + after dilation = 2**9) or pass a list with dilation size per layer. |
| 174 | + By default the cyclic, exponential scheme from WaveNet is used. |
| 175 | + num_conv_layers: the number of causal convolutional layers |
| 176 | + kernel_size: size of the kernels in the causal convolutional layers. |
| 177 | + activation: activation function to use between convolutions, |
| 178 | + default = LeakyReLU. |
| 179 | + pool_kernel_size: pool size to use for the AvgPool1d operation after |
| 180 | + the causal convolutional layers. |
| 181 | + aggregator: aggregation net that reduces the dimensionality of the data |
| 182 | + to a low-dimensional embedding. |
| 183 | + output_dim: number of output units in the final layer when using |
| 184 | + the default aggregation |
| 185 | + """ |
| 186 | + |
| 187 | + super(CausalCNNEmbedding, self).__init__() |
| 188 | + assert isinstance(input_shape, Tuple), ( |
| 189 | + "input_shape must be a Tuple of size 1, e.g. (timepoints,)." |
| 190 | + ) |
| 191 | + assert len(input_shape) == 1, "Currently only 1D causal CNNs are supported." |
| 192 | + self.input_shape = (in_channels, *input_shape) |
| 193 | + |
| 194 | + total_timepoints = input_shape[0] |
| 195 | + assert total_timepoints >= pool_kernel_size, ( |
| 196 | + "Please ensure that the pool kernel size is not " |
| 197 | + "larger than the number of observed timepoints." |
| 198 | + ) |
| 199 | + if isinstance(dilation, str): |
| 200 | + match dilation.lower(): |
| 201 | + case "exponential_cyclic": |
| 202 | + max_dil_exp = 10 |
| 203 | + ## Use dilation scheme as in WaveNet paper |
| 204 | + dilation_per_layer = [ |
| 205 | + 2 ** (i % max_dil_exp) for i in range(num_conv_layers) |
| 206 | + ] |
| 207 | + case "exponential": |
| 208 | + dilation_per_layer = [2**i for i in range(num_conv_layers)] |
| 209 | + case "none": |
| 210 | + dilation_per_layer = [1] * num_conv_layers |
| 211 | + case _: |
| 212 | + raise ValueError( |
| 213 | + f"{dilation} is not a valid option, please use \"none\"," |
| 214 | + "\"exponential\",or \"exponential_cyclic\", or pass a list " |
| 215 | + "of dilation sizes." |
| 216 | + ) |
| 217 | + else: |
| 218 | + assert isinstance(dilation, List), ( |
| 219 | + "Please pass dilation size as list or a string option." |
| 220 | + ) |
| 221 | + dilation_per_layer = dilation |
| 222 | + |
| 223 | + assert max(dilation_per_layer) < total_timepoints, ( |
| 224 | + "Your maximal dilations size used is larger than the number of " |
| 225 | + "timepoints in your input, please provide a list with smaller dilations." |
| 226 | + ) |
| 227 | + if out_channels_per_layer is None: |
| 228 | + out_channels_per_layer = [16] * num_conv_layers |
| 229 | + |
| 230 | + causal_layers = [] |
| 231 | + for ll in range(num_conv_layers): |
| 232 | + causal_layers += [ |
| 233 | + causalConv1d( |
| 234 | + in_channels if ll == 0 else out_channels_per_layer[ll - 1], |
| 235 | + out_channels_per_layer[ll], |
| 236 | + kernel_size, |
| 237 | + dilation_per_layer[ll], |
| 238 | + 1, |
| 239 | + ), |
| 240 | + activation, |
| 241 | + ] |
| 242 | + |
| 243 | + self.causal_cnns = nn.Sequential(*causal_layers) |
| 244 | + |
| 245 | + self.pooling_layer = nn.AvgPool1d(kernel_size=pool_kernel_size) |
| 246 | + |
| 247 | + if aggregator is None: |
| 248 | + aggregator_out_shape = ( |
| 249 | + out_channels_per_layer[-1], |
| 250 | + int(total_timepoints / pool_kernel_size), |
| 251 | + ) |
| 252 | + assert aggregator_out_shape[-1] > 1, ( |
| 253 | + "Your dimensionality is already small," |
| 254 | + "Please ensure a larger input size or use a custom aggregator." |
| 255 | + ) |
| 256 | + aggregator = WaveNetSRLikeAggregator( |
| 257 | + aggregator_out_shape[0], |
| 258 | + aggregator_out_shape[-1], |
| 259 | + output_dim=output_dim, |
| 260 | + ) |
| 261 | + self.aggregation = aggregator |
| 262 | + |
| 263 | + def forward(self, x: Tensor) -> Tensor: |
| 264 | + batch_size = x.size(0) |
| 265 | + x = x.view(batch_size, *self.input_shape) |
| 266 | + x = self.causal_cnns(x) |
| 267 | + x = self.pooling_layer(x) |
| 268 | + x = self.aggregation(x) |
| 269 | + # ensure flattening when aggregator uses global average pooling |
| 270 | + x = x.view(batch_size, -1) |
| 271 | + return x |
0 commit comments