@@ -597,6 +597,126 @@ def call(self, x):
597597 return x
598598
599599
class FNO(tf.keras.layers.Layer):
    """Custom layer for a Fourier neural operator (FNO) block.

    Note that this is only set up to take a channels-last input.

    References
    ----------
    1. FourCastNet: A Global Data-driven High-resolution Weather Model using
       Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
    2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
       Transformers. http://arxiv.org/abs/2111.13587
    """

    def __init__(self, filters, sparsity_threshold=0.5, activation='relu',
                 **kwargs):
        """
        Parameters
        ----------
        filters : int
            Number of dense connections in the FNO block.
        sparsity_threshold : float
            Parameter to control sparsity and shrinkage in the softshrink
            activation function following the MLP layers.
        activation : str
            Activation function used in MLP layers.
        kwargs : dict
            Optional keyword arguments forwarded to the parent
            ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``,
            ``trainable``).
        """
        super().__init__(**kwargs)
        self._filters = filters
        self._activation = activation
        self._lambd = sparsity_threshold
        # The items below depend on the input rank/shape and are therefore
        # deferred to build().
        self._fft_layer = None
        self._ifft_layer = None
        self._mlp_layers = None
        self._n_channels = None
        self._perms_in = None
        self._perms_out = None

    def _softshrink(self, x):
        """Softshrink activation function

        https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
        """
        # Values within [-lambd, lambd] map to 0 in both branches; values
        # outside are shrunk toward zero by lambd.
        values_below_lower = tf.where(x < -self._lambd, x + self._lambd, 0)
        values_above_upper = tf.where(self._lambd < x, x - self._lambd, 0)
        return values_below_lower + values_above_upper

    def _spectral_op(self, x, op):
        """Transpose channels to the front, apply ``op`` (fft or ifft) over
        the trailing spatial/temporal axes, and transpose back to
        channels-last."""
        x = tf.transpose(x, perm=self._perms_in)
        # fft/ifft in tf.signal require a complex input dtype.
        x = op(tf.cast(x, tf.complex64))
        return tf.transpose(x, perm=self._perms_out)

    def _fft(self, x):
        """Apply needed transpositions and fft operation."""
        return self._spectral_op(x, self._fft_layer)

    def _ifft(self, x):
        """Apply needed transpositions and ifft operation."""
        return self._spectral_op(x, self._ifft_layer)

    def build(self, input_shape):
        """Build the FNO layer based on an input shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input tensor

        Raises
        ------
        RuntimeError
            If the input is not 4D (image) or 5D (video) data.
        """
        self._n_channels = input_shape[-1]
        dims = list(range(len(input_shape)))
        # Move the channels axis to the front so the fft operates over the
        # trailing spatial (and temporal) axes; perms_out undoes this.
        self._perms_in = [dims[-1], *dims[:-1]]
        self._perms_out = [*dims[1:], dims[0]]

        if len(input_shape) == 4:
            self._fft_layer = tf.signal.fft2d
            self._ifft_layer = tf.signal.ifft2d
        elif len(input_shape) == 5:
            self._fft_layer = tf.signal.fft3d
            self._ifft_layer = tf.signal.ifft3d
        else:
            msg = ('FNO layer can only accept 4D or 5D data '
                   'for image or video input but received input shape: {}'
                   .format(input_shape))
            logger.error(msg)
            raise RuntimeError(msg)

        # MLP: hidden projection followed by a projection back to the
        # original channel count so the residual addition in call() works.
        self._mlp_layers = [
            tf.keras.layers.Dense(self._filters, activation=self._activation),
            tf.keras.layers.Dense(self._n_channels)]

    def _mlp_block(self, x):
        """Run mlp layers on input"""
        for layer in self._mlp_layers:
            x = layer(x)
        return x

    def call(self, x):
        """Call the custom FourierNeuralOperator layer

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.

        Returns
        -------
        x : tf.Tensor
            Output tensor, this is the FNO weights added to the original input
            tensor.
        """
        t_in = x
        x = self._fft(x)
        x = self._mlp_block(x)
        x = self._softshrink(x)
        x = self._ifft(x)
        # ifft yields a complex tensor; cast back before the residual add.
        x = tf.cast(x, dtype=t_in.dtype)

        return x + t_in
719+
600720class Sup3rAdder (tf .keras .layers .Layer ):
601721 """Layer to add high-resolution data to a sup3r model in the middle of a
602722 super resolution forward pass."""
0 commit comments