Skip to content

Commit c422acf

Browse files
committed
cant have codecov going down
1 parent ef61e16 commit c422acf

1 file changed

Lines changed: 44 additions & 8 deletions

File tree

tests/test_layers.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import pytest
66
import tensorflow as tf
77

8-
from phygnn.layers.custom_layers import (SkipConnection,
9-
SpatioTemporalExpansion,
10-
FlattenAxis,
11-
ExpandDims,
12-
TileLayer,
13-
GaussianNoiseAxis)
14-
from phygnn.layers.handlers import Layers, HiddenLayers
8+
from phygnn.layers.custom_layers import (
9+
ExpandDims,
10+
FlattenAxis,
11+
GaussianNoiseAxis,
12+
SkipConnection,
13+
SpatioTemporalExpansion,
14+
TileLayer,
15+
)
16+
from phygnn.layers.handlers import HiddenLayers, Layers
1517

1618

1719
@pytest.mark.parametrize(
@@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
208210
n_filters = 2 * s_mult**2 * t_mult
209211
shape = (1, 4, 4, 3, n_filters)
210212
n = np.product(shape)
211-
x = np.arange(n).reshape((shape))
213+
x = np.arange(n).reshape(shape)
212214
y = layer(x)
213215
assert y.shape[0] == x.shape[0]
214216
assert y.shape[1] == s_mult * x.shape[1]
@@ -387,3 +389,37 @@ def test_squeeze_excite_3d():
387389
x = layer(x)
388390
with pytest.raises(tf.errors.InvalidArgumentError):
389391
tf.assert_equal(x_in, x)
392+
393+
394+
def test_fno_2d():
395+
"""Test the FNO layer with 2D data (4D tensor input)"""
396+
hidden_layers = [
397+
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
398+
'activation': 'relu'}]
399+
layers = HiddenLayers(hidden_layers)
400+
assert len(layers.layers) == 1
401+
402+
x = np.random.normal(0, 1, size=(1, 4, 4, 3))
403+
404+
for layer in layers:
405+
x_in = x
406+
x = layer(x)
407+
with pytest.raises(tf.errors.InvalidArgumentError):
408+
tf.assert_equal(x_in, x)
409+
410+
411+
def test_fno_3d():
412+
"""Test the FNO layer with 3D data (5D tensor input)"""
413+
hidden_layers = [
414+
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
415+
'activation': 'relu'}]
416+
layers = HiddenLayers(hidden_layers)
417+
assert len(layers.layers) == 1
418+
419+
x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3))
420+
421+
for layer in layers:
422+
x_in = x
423+
x = layer(x)
424+
with pytest.raises(tf.errors.InvalidArgumentError):
425+
tf.assert_equal(x_in, x)

0 commit comments

Comments
 (0)