@@ -785,3 +785,46 @@ def call(self, x, hi_res_feature):
785785 Output tensor with the hi_res_feature added to x.
786786 """
787787 return tf .concat ((x , hi_res_feature ), axis = - 1 )
788+
789+
790+ class FunctionalLayer (tf .keras .layers .Layer ):
791+ """Custom layer to implement the tensorflow layer functions (e.g., add,
792+ subtract, multiply, maximum, and minimum) with a constant value. These
793+ cannot be implemented in phygnn as normal layers because they need to
794+ operate on two tensors of equal shape."""
795+
796+ def __init__ (self , name , value ):
797+ """
798+ Parameters
799+ ----------
800+ name : str
801+ Name of the tensorflow layer function to be implemented, options
802+ are (all lower-case): add, subtract, multiply, maximum, and minimum
803+ value : float
804+ Constant value to use in the function operation
805+ """
806+
807+ options = ('add' , 'subtract' , 'multiply' , 'maximum' , 'minimum' )
808+ msg = (f'FunctionalLayer input `name` must be one of "{ options } " '
809+ f'but received "{ name } "' )
810+ assert name in options , msg
811+
812+ super ().__init__ (name = name )
813+ self .value = value
814+ self .fun = getattr (tf .keras .layers , self .name )
815+
816+ def call (self , x ):
817+ """Operates on x with the specified function
818+
819+ Parameters
820+ ----------
821+ x : tf.Tensor
822+ Input tensor
823+
824+ Returns
825+ -------
826+ x : tf.Tensor
827+ Output tensor operated on by the specified function
828+ """
829+ const = tf .constant (value = self .value , shape = x .shape , dtype = x .dtype )
830+ return self .fun ((x , const ))
0 commit comments