|
5 | 5 | import pytest |
6 | 6 | import tensorflow as tf |
7 | 7 |
|
8 | | -from phygnn.layers.custom_layers import (SkipConnection, |
9 | | - SpatioTemporalExpansion, |
10 | | - FlattenAxis, |
11 | | - ExpandDims, |
12 | | - TileLayer, |
13 | | - GaussianNoiseAxis) |
14 | | -from phygnn.layers.handlers import Layers, HiddenLayers |
| 8 | +from phygnn.layers.custom_layers import ( |
| 9 | + ExpandDims, |
| 10 | + FlattenAxis, |
| 11 | + GaussianNoiseAxis, |
| 12 | + SkipConnection, |
| 13 | + SpatioTemporalExpansion, |
| 14 | + TileLayer, |
| 15 | +) |
| 16 | +from phygnn.layers.handlers import HiddenLayers, Layers |
15 | 17 |
|
16 | 18 |
|
17 | 19 | @pytest.mark.parametrize( |
@@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll): |
208 | 210 | n_filters = 2 * s_mult**2 * t_mult |
209 | 211 | shape = (1, 4, 4, 3, n_filters) |
210 | 212 | n = np.product(shape) |
211 | | - x = np.arange(n).reshape((shape)) |
| 213 | + x = np.arange(n).reshape(shape) |
212 | 214 | y = layer(x) |
213 | 215 | assert y.shape[0] == x.shape[0] |
214 | 216 | assert y.shape[1] == s_mult * x.shape[1] |
@@ -387,3 +389,37 @@ def test_squeeze_excite_3d(): |
387 | 389 | x = layer(x) |
388 | 390 | with pytest.raises(tf.errors.InvalidArgumentError): |
389 | 391 | tf.assert_equal(x_in, x) |
| 392 | + |
| 393 | + |
| 394 | +def test_fno_2d(): |
| 395 | + """Test the FNO layer with 2D data (4D tensor input)""" |
| 396 | + hidden_layers = [ |
| 397 | + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, |
| 398 | + 'activation': 'relu'}] |
| 399 | + layers = HiddenLayers(hidden_layers) |
| 400 | + assert len(layers.layers) == 1 |
| 401 | + |
| 402 | + x = np.random.normal(0, 1, size=(1, 4, 4, 3)) |
| 403 | + |
| 404 | + for layer in layers: |
| 405 | + x_in = x |
| 406 | + x = layer(x) |
| 407 | + with pytest.raises(tf.errors.InvalidArgumentError): |
| 408 | + tf.assert_equal(x_in, x) |
| 409 | + |
| 410 | + |
| 411 | +def test_fno_3d(): |
| 412 | + """Test the FNO layer with 3D data (5D tensor input)""" |
| 413 | + hidden_layers = [ |
| 414 | + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, |
| 415 | + 'activation': 'relu'}] |
| 416 | + layers = HiddenLayers(hidden_layers) |
| 417 | + assert len(layers.layers) == 1 |
| 418 | + |
| 419 | + x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3)) |
| 420 | + |
| 421 | + for layer in layers: |
| 422 | + x_in = x |
| 423 | + x = layer(x) |
| 424 | + with pytest.raises(tf.errors.InvalidArgumentError): |
| 425 | + tf.assert_equal(x_in, x) |
0 commit comments