Skip to content

Commit e39b1c4

Browse files
authored
Merge pull request #46 from NREL/gb/minmax
added custom functional layer for basic tensorflow functional layers …
2 parents a7ca20d + 28eda8e commit e39b1c4

3 files changed

Lines changed: 65 additions & 1 deletion

File tree

phygnn/layers/custom_layers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,3 +785,46 @@ def call(self, x, hi_res_feature):
785785
Output tensor with the hi_res_feature added to x.
786786
"""
787787
return tf.concat((x, hi_res_feature), axis=-1)
788+
789+
790+
class FunctionalLayer(tf.keras.layers.Layer):
    """Custom layer to implement the tensorflow layer functions (e.g., add,
    subtract, multiply, maximum, and minimum) with a constant value. These
    cannot be implemented in phygnn as normal layers because they need to
    operate on two tensors of equal shape."""

    def __init__(self, name, value):
        """
        Parameters
        ----------
        name : str
            Name of the tensorflow layer function to be implemented, options
            are (all lower-case): add, subtract, multiply, maximum, and minimum
        value : float
            Constant value to use in the function operation
        """

        options = ('add', 'subtract', 'multiply', 'maximum', 'minimum')
        msg = (f'FunctionalLayer input `name` must be one of "{options}" '
               f'but received "{name}"')
        # NOTE: kept as an assert (not a ValueError) because downstream
        # tests catch AssertionError; be aware asserts are stripped under -O
        assert name in options, msg

        # `name=name` also makes tf.keras store the function name on
        # self.name, which is used to look up the functional-layer callable
        super().__init__(name=name)
        self.value = value
        self.fun = getattr(tf.keras.layers, self.name)

    def call(self, x):
        """Operates on x with the specified function

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        x : tf.Tensor
            Output tensor operated on by the specified function
        """
        # Build the constant with zeros_like instead of
        # tf.constant(shape=x.shape) so this also works on symbolic tensors
        # with unknown (None) dimensions, e.g. a variable batch size during
        # keras model construction, where x.shape is not fully defined.
        const = tf.zeros_like(x) + tf.cast(self.value, x.dtype)
        return self.fun((x, const))

phygnn/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# -*- coding: utf-8 -*-
22
"""Physics Guided Neural Network version."""
33

4-
__version__ = '0.0.25'
4+
__version__ = '0.0.26'

tests/test_layers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
SkipConnection,
1313
SpatioTemporalExpansion,
1414
TileLayer,
15+
FunctionalLayer,
1516
)
1617
from phygnn.layers.handlers import HiddenLayers, Layers
1718

@@ -423,3 +424,23 @@ def test_fno_3d():
423424
x = layer(x)
424425
with pytest.raises(tf.errors.InvalidArgumentError):
425426
tf.assert_equal(x_in, x)
427+
428+
429+
def test_functional_layer():
    """Test the generic functional layer"""

    # element-wise maximum with 1 should floor the output at 1.0 for any
    # input tensor shape
    max_layer = FunctionalLayer('maximum', 1)
    for shape in ((1, 4, 4, 6, 3), (2, 8, 8, 4, 1)):
        data = np.random.normal(0.5, 3, size=shape)
        assert max_layer(data).numpy().min() == 1.0

    # element-wise multiply by a constant should match plain numpy scaling
    mult_layer = FunctionalLayer('multiply', 1.5)
    data = np.random.normal(0.5, 3, size=(1, 4, 4, 6, 3))
    assert np.allclose(mult_layer(data).numpy(), data * 1.5)

    # an unsupported function name should be rejected at construction time
    with pytest.raises(AssertionError) as excinfo:
        FunctionalLayer('bad_arg', 0)
    assert "must be one of" in str(excinfo.value)

0 commit comments

Comments
 (0)