Commit 83d3fbf

Merge pull request #23 from GeoOcean/feature/nns
Feature/nns
2 parents: 2b1e449 + 1e02a00

File tree

4 files changed (+84, −15 lines)


bluemath_tk/deeplearning/models/__init__.py

Whitespace-only changes.

bluemath_tk/deeplearning/models/resnet_model.py

Lines changed: 20 additions & 14 deletions
@@ -2,12 +2,12 @@
 from keras import layers
 from typing import List, Tuple
 
+
 def ResidualBlock(width: int) -> layers.Layer:
     def apply(x: layers.Layer) -> layers.Layer:
-
         input_width = x.shape[3]
         residual = x if input_width == width else layers.Conv2D(width, kernel_size=1)(x)
-
+
         x = layers.LayerNormalization(axis=-1, center=True, scale=True)(x)
         x = layers.Conv2D(
             width, kernel_size=3, padding="same", activation=keras.activations.swish
@@ -18,9 +18,11 @@ def apply(x: layers.Layer) -> layers.Layer:
 
     return apply
 
+
 def DownBlock(width: int, block_depth: int) -> layers.Layer:
-    def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> Tuple[layers.Layer, List[layers.Layer]]:
-
+    def apply(
+        x: Tuple[layers.Layer, List[layers.Layer]],
+    ) -> Tuple[layers.Layer, List[layers.Layer]]:
         x, skips = x
         for _ in range(block_depth):
             x = ResidualBlock(width)(x)
@@ -30,26 +32,30 @@ def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> Tuple[layers.Layer, List[layers.Layer]]:
 
     return apply
 
+
 def UpBlock(width: int, block_depth: int) -> layers.Layer:
     def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> layers.Layer:
-
         x, skips = x
         x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
         for _ in range(block_depth):
             x = layers.Concatenate()([x, skips.pop()])
             x = ResidualBlock(width)(x)
         return x
+
     return apply
 
-def get_model(image_height: int,
-              image_width: int,
-              input_frames: int,
-              output_frames: int,
-              down_widths: List[int] = [64, 128, 256],
-              up_widths: List[int] = [256, 128, 64],
-              block_depth: int = 2) -> keras.Model:
+
+def get_model(
+    image_height: int,
+    image_width: int,
+    input_frames: int,
+    output_frames: int,
+    down_widths: List[int] = [64, 128, 256],
+    up_widths: List[int] = [256, 128, 64],
+    block_depth: int = 2,
+) -> keras.Model:
     """Builds the U-Net like model with residual blocks and skip connections."""
-
+
     inputs = keras.Input(shape=(image_height, image_width, input_frames))
     x = layers.Conv2D(down_widths[0], kernel_size=1)(inputs)
 
@@ -64,5 +70,5 @@ def get_model(image_height: int,
         x = UpBlock(width, block_depth)([x, skips])
 
     outputs = layers.Conv2D(output_frames, kernel_size=1, kernel_initializer="zeros")(x)
-
+
     return keras.Model(inputs, outputs, name="residual_unet")
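The reformatted builder is behaviour-preserving, so the model is constructed exactly as before. As a quick sanity check, a minimal usage sketch (the import path follows the file tree in this commit; the 64×64 grid and frame counts are illustrative, not values from the PR):

from bluemath_tk.deeplearning.models.resnet_model import get_model

# Illustrative dimensions, not taken from the commit. Spatial sizes
# divisible by a power of two are a safe choice, since each UpBlock
# doubles the resolution with a size-2 bilinear UpSampling2D.
model = get_model(
    image_height=64,
    image_width=64,
    input_frames=4,   # past frames stacked along the channel axis
    output_frames=1,  # frames produced by the final 1x1 Conv2D
)
model.summary()  # a keras.Model named "residual_unet"

Because the output convolution is built with kernel_initializer="zeros", a freshly constructed model predicts all zeros, a common choice for stabilizing the first training steps of residual-style predictors.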

environment.yml

Lines changed: 5 additions & 1 deletion
@@ -1,11 +1,13 @@
-name: bluemath
+name: bluemath-deep
 channels:
   - conda-forge
 dependencies:
   - python=3.12
   - numpy
   - pandas
   - xarray
+  - netcdf4
+  - scipy
   - scikit-learn
   - matplotlib
   - qt
@@ -14,6 +16,8 @@ dependencies:
   - cartopy
   - pytest
   - cdsapi
+  - keras
+  - tensorflow
   # PIP packages can be added below (avoid this whenever possible)
   - pip:
       - minisom
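Note that the environment is renamed from bluemath to bluemath-deep, so an existing bluemath environment will not pick up the new netcdf4, scipy, keras, and tensorflow dependencies in place; the environment has to be (re)created. With standard conda commands:

conda env create -f environment.yml
conda activate bluemath-deep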

tests/deeplearning/test_resnet.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import unittest
+from unittest.mock import Mock
+import keras
+
+
+class TestResNetTraining(unittest.TestCase):
+    def setUp(self):
+        # Mock the data generator
+        self.train_generator = Mock()
+        self.train_generator.num_images = 5000
+        self.train_generator.input_height = 64
+        self.train_generator.input_width = 64
+        self.train_generator.output_height = 64
+        self.train_generator.output_width = 64
+        self.train_generator.batch_size = 1
+
+        # Define optimizer
+        self.optimizer = keras.optimizers.AdamW
+
+        # Mock the model
+        self.model = Mock()
+        self.model.compile = Mock()
+        self.model.fit = Mock(return_value=Mock(history={"loss": [0.1, 0.05, 0.01]}))
+
+    def test_model_training(self):
+        # Compile the model
+        self.model.compile(
+            optimizer=self.optimizer(learning_rate=1e-4, weight_decay=1e-5),
+            loss=keras.losses.mean_squared_error,
+        )
+
+        # Assert that compile was called with correct parameters
+        self.model.compile.assert_called_with(
+            optimizer=self.optimizer(learning_rate=1e-4, weight_decay=1e-5),
+            loss=keras.losses.mean_squared_error,
+        )
+
+        # Start the train loop with the fit method
+        history = self.model.fit(
+            self.train_generator, initial_epoch=0, epochs=20, steps_per_epoch=500
+        )
+
+        # Assert that fit was called with correct parameters
+        self.model.fit.assert_called_with(
+            self.train_generator, initial_epoch=0, epochs=20, steps_per_epoch=500
+        )
+
+        # Assert that the training history is as expected
+        self.assertIn("loss", history.history)
+        self.assertEqual(len(history.history["loss"]), 3)
+        self.assertAlmostEqual(history.history["loss"][0], 0.1)
+        self.assertAlmostEqual(history.history["loss"][1], 0.05)
+        self.assertAlmostEqual(history.history["loss"][2], 0.01)
+
+        print("training complete")
+
+
+if __name__ == "__main__":
+    unittest.main()
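Both the data generator and the model are unittest.mock.Mock objects, so this test verifies the compile/fit call wiring and the shape of the mocked training history rather than running real training; it completes quickly even without a GPU. Since pytest is already listed in environment.yml, one way to run it (test path as added in this commit):

pytest tests/deeplearning/test_resnet.py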
