# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

# This code is based on the following three papers:

# Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
# (https://proceedings.neurips.cc/paper_files/paper/2024/file/266c0f191b04cbbbe529016d0edc847e-Paper-Conference.pdf)
#
# Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based Neural Operators
# on Arbitrary Domains
# (https://arxiv.org/pdf/2305.19663)

# Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential Equations
# (https://openreview.net/pdf?id=c8P9NQVtmnO)

# and partially adapted from the following repository:
# https://github.com/camlab-ethz/FUSE

from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class VFT:
    """Class for performing Fourier transforms on non-equally
    and equally spaced 1d grids.

    It provides a function for creating the grid-dependent operator V that
    computes the forward Fourier transform X of data x via X = V*x.
    The inverse Fourier transform can then be computed by x = V_inv*X with
    V_inv = transpose(conjugate(V)).

    Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
    Neural Operators on Arbitrary Domains

    Args:
        batch_size: Training batch size.
        n_points: Number of 1d grid points.
        modes: Number of Fourier modes to use
            (at most floor(n_points/2) + 1).
        point_positions: Grid point positions of shape (batch_size, n_points).
            If not provided, equispaced points are used. Positions have to be
            normalized by the domain length.
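
    Example:
        A minimal doctest-style sketch (illustrative only): on an equispaced
        grid, the forward transform agrees with ``torch.fft.fft`` restricted
        to the first ``modes`` frequencies.

        >>> vft = VFT(batch_size=2, n_points=16, modes=5)
        >>> x = torch.rand(2, 16, 3, dtype=torch.cfloat)
        >>> X = vft.forward(x, norm='forward')  # (2, 5, 3)
        >>> X_ref = torch.fft.fft(x, dim=1, norm='forward')[:, :5, :]
        >>> torch.allclose(X, X_ref, atol=1e-5)
        True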
    """

    def __init__(
        self,
        batch_size: int,
        n_points: int,
        modes: int,
        point_positions: Optional[Tensor] = None,
    ):
        self.number_points = n_points
        self.batch_size = batch_size
        self.modes = modes

        if point_positions is not None:
            new_times = point_positions[:, None, :]
        else:
            new_times = (
                (torch.arange(self.number_points) / self.number_points).repeat(
                    self.batch_size, 1
                )
            )[:, None, :]

        self.new_times = new_times * 2 * np.pi

        self.X_ = torch.arange(modes).repeat(self.batch_size, 1)[:, :, None].float()
        # V_fwd: (batch, modes, points), V_inv: (batch, points, modes)
        self.V_fwd, self.V_inv = self.make_matrix()

    def make_matrix(self) -> Tuple[Tensor, Tensor]:
        """Create matrix operators V and V_inv for the forward and backward
        Fourier transform on arbitrary grids.
        """

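        # Outer product of mode indices and point positions:
        # X_mat[b, k, j] = k * t_{b,j}, so V_fwd[b, k, j] = exp(-i * k * t_{b,j}).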
        X_mat = torch.bmm(self.X_, self.new_times)
        forward_mat = torch.exp(-1j * (X_mat))

        inverse_mat = torch.conj(forward_mat.clone()).permute(0, 2, 1)

        return forward_mat, inverse_mat

    def forward(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform the forward Fourier transform.

        Args:
            data: Input data with shape (batch_size, n_points, conv_channel).
            norm: Normalization convention ('forward', 'ortho', or 'backward'),
                following the ``torch.fft`` convention.
        """
        if norm == 'forward':
            data_fwd = torch.bmm(self.V_fwd, data) / self.number_points
        elif norm == 'ortho':
            data_fwd = torch.bmm(self.V_fwd, data) / np.sqrt(self.number_points)
        elif norm == 'backward':
            data_fwd = torch.bmm(self.V_fwd, data)
        else:
            raise ValueError(f"Unknown norm '{norm}'.")

        return data_fwd  # (batch, modes, conv_channels)

    def inverse(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform the inverse Fourier transform.

        Args:
            data: Input data with shape (batch_size, modes, conv_channel).
            norm: Normalization convention, paired with the one used in
                ``forward`` (e.g., norm='forward' applies the 1/n_points
                factor in the forward transform and none here).
        """
        if norm == 'backward':
            data_inv = torch.bmm(self.V_inv, data) / self.number_points
        elif norm == 'ortho':
            data_inv = torch.bmm(self.V_inv, data) / np.sqrt(self.number_points)
        elif norm == 'forward':
            data_inv = torch.bmm(self.V_inv, data)
        else:
            raise ValueError(f"Unknown norm '{norm}'.")

        return data_inv  # (batch, n_points, conv_channels)


class SpectralConv1d_SMM(nn.Module):
    """
    A 1D spectral convolution layer using the Fourier transform.
    This layer applies a learned complex multiplication in the frequency domain.

    Adapted from:
    - Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
    - Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential
      Equations

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        modes: Number of Fourier modes to multiply,
            at most floor(n_points/2) + 1.
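
    Example:
        A minimal sketch (illustrative only) of applying the layer together
        with a ``VFT`` transform on an equispaced grid:

        >>> layer = SpectralConv1d_SMM(in_channels=5, out_channels=5, modes=4)
        >>> vft = VFT(batch_size=2, n_points=32, modes=4)
        >>> y = layer(torch.rand(2, 32, 5), vft)
        >>> tuple(y.shape)  # (batch, n_points, out_channels)
        (2, 32, 5)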
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(in_channels, out_channels, self.modes, dtype=torch.cfloat)
        )

    def compl_mul1d(self, input: Tensor, weights: Tensor) -> Tensor:
        """
        Performs complex multiplication in the Fourier domain, i.e., an
        independent matrix multiplication over channels for each Fourier mode:
        out[b, o, k] = sum_i input[b, i, k] * weights[i, o, k].

        Args:
            input: Input tensor of shape (batch, in_channels, modes).
            weights: Weight tensor of shape (in_channels, out_channels, modes).

        Returns:
            Output tensor of shape (batch, out_channels, modes).
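
        Example (illustrative shape check):

            >>> layer = SpectralConv1d_SMM(in_channels=2, out_channels=3, modes=4)
            >>> inp = torch.rand(5, 2, 4, dtype=torch.cfloat)
            >>> out = layer.compl_mul1d(inp, layer.weights1)
            >>> tuple(out.shape)
            (5, 3, 4)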
        """

        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Forward pass of the spectral convolution layer.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            The real part of the transformed output tensor
            with shape (batch, n_points, out_channels).
        """
        # Compute Fourier coefficients
        x_ft = transform.forward(x.to(torch.complex64), norm='forward')
        x_ft = x_ft.permute(0, 2, 1)  # (batch, in_channels, modes)
        out_ft = self.compl_mul1d(x_ft, self.weights1)
        x_ft = out_ft.permute(0, 2, 1)  # (batch, modes, out_channels)

        # Return to physical space
        x = transform.inverse(x_ft, norm='forward')

        return x.real

    def last_layer(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Last convolution layer, returning Fourier coefficients to be used as
        an embedding.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            Transformed output tensor of shape (batch, 2*modes, out_channels).
        """

        # Compute Fourier coefficients
        x_ft = transform.forward(x.to(torch.complex64), norm='forward')
        x_ft = x_ft.permute(0, 2, 1)
        x_ft = self.compl_mul1d(x_ft, self.weights1)  # (batch, conv_channels, modes)
        x_ft = x_ft.permute(0, 2, 1)  # (batch, modes, conv_channels)
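        # Split the coefficients into real and imaginary parts and interleave
        # them along the mode axis: rows are ordered
        # [Re(mode 0), Im(mode 0), Re(mode 1), Im(mode 1), ...].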
        x_ft = torch.view_as_real(x_ft)  # (batch, modes, conv_channels, 2)
        x_ft = x_ft.permute(0, 1, 3, 2)  # (batch, modes, 2, conv_channels)
        x_ft = x_ft.reshape(x.shape[0], 2 * self.modes, self.out_channels)

        return x_ft


class SpectralConvEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        modes: int = 10,
        out_channels: int = 1,
        conv_channels: int = 5,
        num_layers: int = 3,
    ):
        """SpectralConvEmbedding is a neural network module that performs
        convolution in Fourier space for 1D input data (which can have multiple
        channels). It uses a series of spectral convolution layers and pointwise
        convolution layers to transform the input tensor.

        Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
        Neural Operators on Arbitrary Domains

        Args:
            in_channels: Number of channels in the input data.
            modes: Number of modes considered in the spectral convolution,
                at most floor(n_points/2) + 1.
            out_channels: Number of channels of the final output.
            conv_channels: Number of channels going into and out of each
                convolution layer.
            num_layers: Number of convolution layers.
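
        Example:
            A minimal sketch (illustrative only) of the embedding output shape:

            >>> net = SpectralConvEmbedding(
            ...     in_channels=3, modes=4, out_channels=1, conv_channels=5,
            ...     num_layers=2,
            ... )
            >>> out = net(torch.rand(8, 3, 32))
            >>> tuple(out.shape)  # (batch, out_channels * 2 * modes)
            (8, 8)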
        """
        super().__init__()

        self.modes = modes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_channels = conv_channels
        self.num_layers = num_layers

        # Initialize fully connected layer to raise the number of
        # input channels to the number of convolution channels
        self.fc0 = nn.Linear(self.in_channels, self.conv_channels)

        # Initialize layers performing convolution in Fourier space
        self.conv_layers = nn.ModuleList([
            SpectralConv1d_SMM(self.conv_channels, self.conv_channels, self.modes)
            for _ in range(self.num_layers)
        ])

        # Initialize layers performing pointwise convolution
        self.w_layers = nn.ModuleList([
            nn.Conv1d(self.conv_channels, self.conv_channels, 1)
            for _ in range(self.num_layers)
        ])

        # Initialize last convolution layer with output in Fourier space
        self.conv_last = SpectralConv1d_SMM(
            self.conv_channels, self.conv_channels, self.modes
        )

        # Initialize fully connected layer to reduce the number of output channels
        self.fc_last = nn.Linear(self.conv_channels, self.out_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: 3D input tensor (batch_size, in_channels, n_points) for
                equispaced data, or 4D tensor (batch_size, 2, in_channels,
                n_points) for non-equispaced data, where the point positions
                are additionally passed in the second dimension, repeating the
                same point positions for each channel. For non-equispaced data,
                the positions have to be normalized by the physical domain
                length.

        Example (``sbi_prior`` and ``theta`` are assumed to be defined by the
        user):

            from sbi.inference import SNPE
            from sbi.neural_nets import posterior_nn

            # Example for equispaced grid data with a batch size of 256,
            # 3 channels, and a sequence length of 500
            data_equispaced = torch.rand(256, 3, 500)
            embedding_net = SpectralConvEmbedding(
                in_channels=3, modes=15, out_channels=1, conv_channels=5,
                num_layers=4,
            )
            neural_posterior = posterior_nn(model="nsf", embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior, density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_equispaced)

            # Example for non-equispaced data with a batch size of 256,
            # 3 channels, and a sequence length of 500.
            # Non-equally spaced positions in [0, 1], sorted and repeated
            # for each channel:
            irregular_positions = torch.rand(500)
            irregular_positions, indices = torch.sort(irregular_positions, 0)
            irregular_positions = irregular_positions.repeat(256, 3, 1)

            random_data = torch.rand(256, 3, 500)

            data_nonequispaced = torch.zeros(256, 2, 3, 500)
            data_nonequispaced[:, 0, :, :] = random_data
            data_nonequispaced[:, 1, :, :] = irregular_positions

            embedding_net = SpectralConvEmbedding(
                in_channels=3, modes=15, out_channels=1, conv_channels=5,
                num_layers=4,
            )
            neural_posterior = posterior_nn(model="nsf", embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior, density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_nonequispaced)

        Returns:
            Network output (batch_size, out_channels * 2 * modes).
        """
        batch_size = x.shape[0]

        # Check dimension of input data and reshape it
        if x.ndim == 3:
            x = x.permute(0, 2, 1)  # (batch, n_points, in_channels)
            point_positions = None

        elif x.ndim == 4:
            # Positions are identical across channels, so take channel 0.
            point_positions = x[:, 1, 0, :]
            x = x[:, 0, :, :].permute(0, 2, 1)  # (batch, n_points, in_channels)

        else:
            raise ValueError(
                'Input tensor should be 3D (batch_size, channels, n_points) '
                'or 4D (batch_size, 2, channels, n_points). '
                f'The tensor that was passed has shape {x.shape}.'
            )

        n_points = x.shape[1]

        assert self.modes <= n_points // 2 + 1, (
            "Modes should be at most floor(n_points/2) + 1"
        )

        x = self.fc0(x)  # (batch_size, n_points, conv_channels)

        # Initialize Fourier transform for arbitrarily spaced points
        fourier_transform = VFT(batch_size, n_points, self.modes, point_positions)

        # Send the data through the Fourier layers; output is in physical space
        for conv, w in zip(self.conv_layers, self.w_layers, strict=True):
            x1 = conv(x, fourier_transform)
            x2 = w(x.permute(0, 2, 1))
            x = x1 + x2.permute(0, 2, 1)
            x = F.gelu(x)

        # Send data through the last convolution layer, which returns data in
        # Fourier space
        x_spec = self.conv_last.last_layer(
            x, fourier_transform
        )  # (batch, 2*modes, conv_channels)

        # Reduce the number of channels with the last layer
        x_spec = self.fc_last(x_spec)  # (batch, 2*modes, out_channels)

        return x_spec.reshape(batch_size, -1)