
Commit f1312e8

lin-da, lsmuhle, and manuelgloeckler authored
Spectral convolution embedding net (#1503)
* spectral convolution embedding net
* documentation
* fixes spectral embedding - permutation improved - little mistakes corrected
* tests
* fix n_positions format
* linear output layer
* cleaning code
* documentation minor changes
* delete notebook
* requested changes
* updating tests
* add exemplary code snippet to embedding forward method
* changing shape of input x
* update tests to new x shape
* formatting
* reworking based on Github comments
* small fix in fourier transform

---------

Co-authored-by: lsmuhle <36638264+lsmuhle@users.noreply.github.com>
Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com>
Co-authored-by: manuelgloeckler <manu.gloeckler@hotmail.de>
1 parent 0d5a71a commit f1312e8

File tree

3 files changed: +399 −0 lines

sbi/neural_nets/embedding_nets/SC_embedding.py

Lines changed: 352 additions & 0 deletions

@@ -0,0 +1,352 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

# This code is based on the following three papers:

# Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
# (https://proceedings.neurips.cc/paper_files/paper/2024/file/266c0f191b04cbbbe529016d0edc847e-Paper-Conference.pdf)
#
# Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based Neural Operators
# on Arbitrary Domains
# (https://arxiv.org/pdf/2305.19663)

# Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential Equations
# (https://openreview.net/pdf?id=c8P9NQVtmnO)

# and partially adapted from the following repository:
# https://github.com/camlab-ethz/FUSE

from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class VFT:
    """Class for performing Fourier transformations for non-equally
    and equally spaced 1d grids.

    It provides a function for creating the grid-dependent operator V to compute
    the forward Fourier transform X of data x with X = V*x.
    The inverse Fourier transform can then be computed by x = V_inv*X with
    V_inv = transpose(conjugate(V)).

    Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
    Neural Operators on Arbitrary Domains

    Args:
        batch_size: Training batch size.
        n_points: Number of 1d grid points.
        modes: Number of Fourier modes that should be used
            (at most floor(n_points/2) + 1).
        point_positions: Grid point positions of shape (batch_size, n_points).
            If not provided, equispaced points are used. Positions have to be
            normalized with the domain length.
    """

    def __init__(
        self,
        batch_size: int,
        n_points: int,
        modes: int,
        point_positions: Optional[Tensor] = None,
    ):
        self.number_points = n_points
        self.batch_size = batch_size
        self.modes = modes

        if point_positions is not None:
            new_times = point_positions[:, None, :]
        else:
            new_times = (
                (torch.arange(self.number_points) / self.number_points).repeat(
                    self.batch_size, 1
                )
            )[:, None, :]

        self.new_times = new_times * 2 * np.pi

        self.X_ = torch.arange(modes).repeat(self.batch_size, 1)[:, :, None].float()
        # V_fwd: (batch, modes, points), V_inv: (batch, points, modes)
        self.V_fwd, self.V_inv = self.make_matrix()

    def make_matrix(self) -> Tuple[Tensor, Tensor]:
        """Create matrix operators V and V_inv for forward and backward
        Fourier transformation on arbitrary grids.
        """

        X_mat = torch.bmm(self.X_, self.new_times)
        forward_mat = torch.exp(-1j * (X_mat))

        inverse_mat = torch.conj(forward_mat.clone()).permute(0, 2, 1)

        return forward_mat, inverse_mat

    def forward(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform forward Fourier transformation.

        Args:
            data: Input data with shape (batch_size, n_points, conv_channels).
        """
        if norm == 'forward':
            data_fwd = torch.bmm(self.V_fwd, data) / self.number_points
        elif norm == 'ortho':
            data_fwd = torch.bmm(self.V_fwd, data) / np.sqrt(self.number_points)
        elif norm == 'backward':
            data_fwd = torch.bmm(self.V_fwd, data)
        else:
            raise ValueError(f"Unknown norm '{norm}'.")

        return data_fwd  # (batch, modes, conv_channels)

    def inverse(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform inverse Fourier transformation.

        Args:
            data: Input data with shape (batch_size, modes, conv_channels).
        """
        if norm == 'backward':
            data_inv = torch.bmm(self.V_inv, data) / self.number_points
        elif norm == 'ortho':
            data_inv = torch.bmm(self.V_inv, data) / np.sqrt(self.number_points)
        elif norm == 'forward':
            data_inv = torch.bmm(self.V_inv, data)
        else:
            raise ValueError(f"Unknown norm '{norm}'.")

        return data_inv  # (batch, n_points, conv_channels)


class SpectralConv1d_SMM(nn.Module):
    """
    A 1D spectral convolutional layer using the Fourier transform.
    This layer applies a learned complex multiplication in the frequency domain.

    Adapted from:
    - Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
    - Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential
      Equations

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        modes: Number of Fourier modes to multiply,
            at most floor(n_points/2) + 1.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super(SpectralConv1d_SMM, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(in_channels, out_channels, self.modes, dtype=torch.cfloat)
        )

    def compl_mul1d(self, input: Tensor, weights: Tensor) -> Tensor:
        """
        Performs complex multiplication in the Fourier domain.

        Args:
            input: Input tensor of shape (batch, in_channels, modes).
            weights: Weight tensor of shape (in_channels, out_channels, modes).

        Returns:
            torch.Tensor: Output tensor of shape (batch, out_channels, modes).
        """

        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Forward pass of the spectral convolution layer.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            The real part of the transformed output tensor
            with shape (batch, n_points, out_channels).
        """
        # Compute Fourier coefficients
        x_ft = transform.forward(x.to(torch.complex64), norm='forward')
        x_ft = x_ft.permute(0, 2, 1)
        out_ft = self.compl_mul1d(x_ft, self.weights1)
        x_ft = out_ft.permute(0, 2, 1)

        # Return to physical space
        x = transform.inverse(x_ft, norm='forward')

        return x.real

    def last_layer(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Last convolutional layer returning Fourier coefficients to be used as embedding.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            Transformed output tensor of shape (batch, 2*modes, out_channels).
        """

        # Compute Fourier coefficients
        x_ft = transform.forward(x.to(torch.complex64), norm='forward')
        x_ft = x_ft.permute(0, 2, 1)
        x_ft = self.compl_mul1d(x_ft, self.weights1)  # (batch, out_channels, modes)
        x_ft = x_ft.permute(0, 2, 1)  # (batch, modes, out_channels)
        x_ft = torch.view_as_real(x_ft)  # (batch, modes, out_channels, 2)
        x_ft = x_ft.permute(0, 1, 3, 2)
        x_ft = x_ft.reshape(x.shape[0], 2 * self.modes, self.out_channels)

        return x_ft


class SpectralConvEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        modes: int = 10,
        out_channels: int = 1,
        conv_channels: int = 5,
        num_layers: int = 3,
    ):
        """SpectralConvEmbedding is a neural network module that performs convolution
        in Fourier space for 1D input data (which can have multiple channels).
        It uses a series of spectral convolution layers and pointwise
        convolution layers to transform the input tensor.

        Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
        Neural Operators on Arbitrary Domains

        Args:
            in_channels: Number of channels in the input data.
            modes: Number of modes considered in the spectral convolution,
                at most floor(n_points/2) + 1.
            out_channels: Number of channels of the final output.
            conv_channels: Number of channels going into and out of each
                convolutional layer.
            num_layers: Number of convolution layers.

        """
        super().__init__()

        self.modes = modes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_channels = conv_channels
        self.num_layers = num_layers

        # Initialize fully connected layer to raise the number of
        # input channels to the number of convolutional channels
        self.fc0 = nn.Linear(self.in_channels, self.conv_channels)

        # Initialize layers performing convolution in Fourier space
        self.conv_layers = nn.ModuleList([
            SpectralConv1d_SMM(self.conv_channels, self.conv_channels, self.modes)
            for _ in range(self.num_layers)
        ])

        # Initialize layers performing pointwise convolution
        self.w_layers = nn.ModuleList([
            nn.Conv1d(self.conv_channels, self.conv_channels, 1)
            for _ in range(self.num_layers)
        ])

        # Initialize last convolutional layer with output in Fourier space
        self.conv_last = SpectralConv1d_SMM(
            self.conv_channels, self.conv_channels, self.modes
        )

        # Initialize fully connected layer to reduce the number of output channels
        self.fc_last = nn.Linear(self.conv_channels, self.out_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: 3D input tensor (batch_size, in_channels, n_points) for equispaced
                data or 4D tensor (batch_size, 2, in_channels, n_points) for
                non-equispaced data, where the point positions are additionally
                passed in the second dimension, repeating the same point positions
                for each channel. For non-equispaced data, the positions have to be
                normalized with the physical domain length.

        Example:

            # Example for equispaced grid data with batch size of 256, 3 channels
            # and sequence length of 500
            data_equispaced = torch.rand(256, 3, 500)
            embedding_net = SpectralConvEmbedding(modes=15, in_channels=3,
                out_channels=1, conv_channels=5, num_layers=4)
            neural_posterior = posterior_nn(model="nsf", embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior, density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_equispaced)

            # Example for non-equispaced data with batch size of 256, 3 channels
            # and sequence length of 500
            irregular_positions = torch.rand(500)  # non-equally spaced positions in [0, 1]
            irregular_positions, indices = torch.sort(irregular_positions, 0)
            irregular_positions = irregular_positions.repeat(256, 3, 1)

            random_data = torch.rand(256, 3, 500)

            data_nonequispaced = torch.zeros(256, 2, 3, 500)
            data_nonequispaced[:, 0, :, :] = random_data
            data_nonequispaced[:, 1, :, :] = irregular_positions

            embedding_net = SpectralConvEmbedding(modes=15, in_channels=3,
                out_channels=1, conv_channels=5, num_layers=4)
            neural_posterior = posterior_nn(model="nsf", embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior, density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_nonequispaced)

        Returns:
            Network output (batch_size, out_channels * 2 * modes).
        """
        batch_size = x.shape[0]

        # Check dimension of input data and reshape it
        if x.ndim == 3:
            x = x.permute(0, 2, 1)  # (batch, n_points, in_channels)
            point_positions = None

        elif x.ndim == 4:
            point_positions = x[:, 1, 0, :]
            x = x[:, 0, :, :].permute(0, 2, 1)

        else:
            raise ValueError(
                'Input tensor should be 3D (batch_size, channels, n_points) '
                'or 4D (batch_size, 2, channels, n_points). '
                f'The tensor that was passed has shape {x.shape}.'
            )

        n_points = x.shape[1]

        assert self.modes <= n_points // 2 + 1, (
            "Modes should be at most floor(n_points/2) + 1"
        )

        x = self.fc0(x)  # (batch_size, n_points, conv_channels)

        # Initialize Fourier transform for arbitrarily spaced points
        fourier_transform = VFT(batch_size, n_points, self.modes, point_positions)

        # Send the data through Fourier layers, output in original space
        for conv, w in zip(self.conv_layers, self.w_layers, strict=False):
            x1 = conv(x, fourier_transform)
            x2 = w(x.permute(0, 2, 1))
            x = x1 + x2.permute(0, 2, 1)
            x = F.gelu(x)

        # Send data through the last convolutional layer, which returns data
        # in Fourier space
        x_spec = self.conv_last.last_layer(
            x, fourier_transform
        )  # (batch, 2*modes, conv_channels)

        # Reduce the number of channels with the last layer
        x_spec = self.fc_last(x_spec)  # (batch, 2*modes, out_channels)

        return x_spec.reshape(batch_size, -1)
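
A minimal sanity-check sketch (not part of the diff) for the VFT operator defined above. It assumes the file lands at sbi/neural_nets/embedding_nets/SC_embedding.py, matching the __init__.py import below. On an equispaced grid, VFT.forward should reproduce the first `modes` coefficients of torch.fft.rfft, and VFT.inverse maps the coefficients back to a (batch, n_points, channels) tensor.

import torch

from sbi.neural_nets.embedding_nets.SC_embedding import VFT

batch_size, n_points, channels, modes = 4, 100, 3, 20
x = torch.rand(batch_size, n_points, channels)

vft = VFT(batch_size=batch_size, n_points=n_points, modes=modes)
x_hat = vft.forward(x.to(torch.complex64), norm='forward')  # (batch, modes, channels)
x_rec = vft.inverse(x_hat, norm='forward').real  # (batch, n_points, channels)

# On an equispaced grid, the matrix-based transform matches the FFT coefficients.
reference = torch.fft.rfft(x, dim=1, norm='forward')[:, :modes, :]
print(torch.allclose(x_hat, reference, atol=1e-4))  # True
print(x_rec.shape)  # torch.Size([4, 100, 3])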

sbi/neural_nets/embedding_nets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
from sbi.neural_nets.embedding_nets.SC_embedding import SpectralConvEmbedding
from sbi.neural_nets.embedding_nets.causal_cnn import CausalCNNEmbedding
from sbi.neural_nets.embedding_nets.cnn import CNNEmbedding
from sbi.neural_nets.embedding_nets.fully_connected import FCEmbedding
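
A short usage sketch (not part of the diff), assuming the package is installed with this commit: the new export makes SpectralConvEmbedding importable from the embedding_nets subpackage, and the embedding flattens the Fourier coefficients to out_channels * 2 * modes features per batch element, as documented in SpectralConvEmbedding.forward above.

import torch

from sbi.neural_nets.embedding_nets import SpectralConvEmbedding

embedding_net = SpectralConvEmbedding(
    in_channels=3, modes=15, out_channels=1, conv_channels=5, num_layers=4
)
x = torch.rand(256, 3, 500)  # (batch_size, in_channels, n_points), equispaced
embedding = embedding_net(x)
print(embedding.shape)  # torch.Size([256, 30]), i.e. out_channels * 2 * modes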
