Skip to content

Commit a7ca20d

Browse files
authored
Merge pull request #45 from NREL/bnb/fno
Bnb/fno
2 parents 42c914a + c422acf commit a7ca20d

2 files changed

Lines changed: 164 additions & 8 deletions

File tree

phygnn/layers/custom_layers.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,126 @@ def call(self, x):
597597
return x
598598

599599

600+
class FNO(tf.keras.layers.Layer):
    """Custom layer for a fourier neural operator block.

    The input is Fourier transformed over its spatiotemporal axes, mixed
    by a small MLP with a softshrink sparsification, inverse transformed,
    and added back to the original input as a residual.

    Note that this is only set up to take a channels-last input.

    References
    ----------
    1. FourCastNet: A Global Data-driven High-resolution Weather Model using
       Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
    2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
       Transformers. http://arxiv.org/abs/2111.13587
    """

    def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
        """
        Parameters
        ----------
        filters : int
            Number of dense connections in the FNO block.
        sparsity_threshold : float
            Parameter to control sparsity and shrinkage in the softshrink
            activation function following the MLP layers.
        activation : str
            Activation function used in MLP layers.
        """
        super().__init__()
        self._filters = filters
        self._lambd = sparsity_threshold
        self._activation = activation
        # everything below is populated in build() once the input rank
        # and channel count are known
        self._fft_layer = None
        self._ifft_layer = None
        self._mlp_layers = None
        self._n_channels = None
        self._perms_in = None
        self._perms_out = None

    def _softshrink(self, x):
        """Softshrink activation function

        https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
        """
        # values inside [-lambd, lambd] are zeroed; the rest are shifted
        # toward zero by lambd
        below = tf.where(x < -self._lambd, x + self._lambd, 0)
        above = tf.where(x > self._lambd, x - self._lambd, 0)
        return below + above

    def _fft(self, x):
        """Transpose to channels-first, run the FFT over the trailing
        spatiotemporal axes, and transpose back to channels-last."""
        channels_first = tf.transpose(x, perm=self._perms_in)
        freq = self._fft_layer(tf.cast(channels_first, tf.complex64))
        return tf.transpose(freq, perm=self._perms_out)

    def _ifft(self, x):
        """Transpose to channels-first, run the inverse FFT over the
        trailing spatiotemporal axes, and transpose back."""
        channels_first = tf.transpose(x, perm=self._perms_in)
        spatial = self._ifft_layer(tf.cast(channels_first, tf.complex64))
        return tf.transpose(spatial, perm=self._perms_out)

    def build(self, input_shape):
        """Build the FNO layer based on an input shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input tensor
        """
        self._n_channels = input_shape[-1]

        # move the channel axis to the front so fft2d/fft3d (which act on
        # the innermost axes) see only the spatiotemporal dimensions,
        # then restore the original channels-last ordering afterwards
        axes = list(range(len(input_shape)))
        self._perms_in = [axes[-1], *axes[:-1]]
        self._perms_out = [*axes[1:], axes[0]]

        rank = len(input_shape)
        if rank == 4:
            self._fft_layer = tf.signal.fft2d
            self._ifft_layer = tf.signal.ifft2d
        elif rank == 5:
            self._fft_layer = tf.signal.fft3d
            self._ifft_layer = tf.signal.ifft3d
        else:
            msg = ('FNO layer can only accept 4D or 5D data '
                   'for image or video input but received input shape: {}'
                   .format(input_shape))
            logger.error(msg)
            raise RuntimeError(msg)

        self._mlp_layers = [
            tf.keras.layers.Dense(self._filters, activation=self._activation),
            tf.keras.layers.Dense(self._n_channels)]

    def _mlp_block(self, x):
        """Run mlp layers on input"""
        for dense in self._mlp_layers:
            x = dense(x)
        return x

    def call(self, x):
        """Call the custom FourierNeuralOperator layer

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.

        Returns
        -------
        x : tf.Tensor
            Output tensor, this is the FNO weights added to the original input
            tensor.
        """
        residual = x
        out = self._fft(x)
        out = self._mlp_block(out)
        out = self._softshrink(out)
        out = self._ifft(out)
        # cast back to the input dtype so the residual add is well-typed
        out = tf.cast(out, dtype=residual.dtype)
        return out + residual
600720
class Sup3rAdder(tf.keras.layers.Layer):
601721
"""Layer to add high-resolution data to a sup3r model in the middle of a
602722
super resolution forward pass."""

tests/test_layers.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import pytest
66
import tensorflow as tf
77

8-
from phygnn.layers.custom_layers import (SkipConnection,
9-
SpatioTemporalExpansion,
10-
FlattenAxis,
11-
ExpandDims,
12-
TileLayer,
13-
GaussianNoiseAxis)
14-
from phygnn.layers.handlers import Layers, HiddenLayers
8+
from phygnn.layers.custom_layers import (
9+
ExpandDims,
10+
FlattenAxis,
11+
GaussianNoiseAxis,
12+
SkipConnection,
13+
SpatioTemporalExpansion,
14+
TileLayer,
15+
)
16+
from phygnn.layers.handlers import HiddenLayers, Layers
1517

1618

1719
@pytest.mark.parametrize(
@@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
208210
n_filters = 2 * s_mult**2 * t_mult
209211
shape = (1, 4, 4, 3, n_filters)
210212
n = np.product(shape)
211-
x = np.arange(n).reshape((shape))
213+
x = np.arange(n).reshape(shape)
212214
y = layer(x)
213215
assert y.shape[0] == x.shape[0]
214216
assert y.shape[1] == s_mult * x.shape[1]
@@ -387,3 +389,37 @@ def test_squeeze_excite_3d():
387389
x = layer(x)
388390
with pytest.raises(tf.errors.InvalidArgumentError):
389391
tf.assert_equal(x_in, x)
392+
393+
394+
def test_fno_2d():
    """Test the FNO layer with 2D data (4D tensor input)"""
    config = [
        {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
         'activation': 'relu'}]
    layers = HiddenLayers(config)
    assert len(layers.layers) == 1

    x = np.random.normal(0, 1, size=(1, 4, 4, 3))

    for layer in layers:
        x_in = x
        x = layer(x)
        # the residual FNO output should differ from its input, so an
        # exact-equality assertion must fail
        with pytest.raises(tf.errors.InvalidArgumentError):
            tf.assert_equal(x_in, x)
409+
410+
411+
def test_fno_3d():
    """Test the FNO layer with 3D data (5D tensor input)"""
    config = [
        {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
         'activation': 'relu'}]
    layers = HiddenLayers(config)
    assert len(layers.layers) == 1

    x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3))

    for layer in layers:
        x_in = x
        x = layer(x)
        # the residual FNO output should differ from its input, so an
        # exact-equality assertion must fail
        with pytest.raises(tf.errors.InvalidArgumentError):
            tf.assert_equal(x_in, x)

0 commit comments

Comments
 (0)