Skip to content

Commit 6907e58

Browse files
authored
Merge pull request #19 from GeoOcean/feature/nns
Feature/nns
2 parents e4d7c15 + d64ed05 commit 6907e58

File tree

7 files changed

+353
-0
lines changed

7 files changed

+353
-0
lines changed

bluemath_tk/deeplearning/__init__.py

Whitespace-only changes.
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
name: deep
2+
channels:
3+
- defaults
4+
dependencies:
5+
- _libgcc_mutex=0.1=main
6+
- _openmp_mutex=5.1=1_gnu
7+
- anyio=4.2.0=py311h06a4308_0
8+
- argon2-cffi=21.3.0=pyhd3eb1b0_0
9+
- argon2-cffi-bindings=21.2.0=py311h5eee18b_0
10+
- asttokens=2.0.5=pyhd3eb1b0_0
11+
- async-lru=2.0.4=py311h06a4308_0
12+
- attrs=23.1.0=py311h06a4308_0
13+
- babel=2.11.0=py311h06a4308_0
14+
- beautifulsoup4=4.12.2=py311h06a4308_0
15+
- bleach=4.1.0=pyhd3eb1b0_0
16+
- brotli-python=1.0.9=py311h6a678d5_7
17+
- bzip2=1.0.8=h5eee18b_5
18+
- ca-certificates=2024.3.11=h06a4308_0
19+
- certifi=2024.2.2=py311h06a4308_0
20+
- cffi=1.16.0=py311h5eee18b_0
21+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
22+
- comm=0.2.1=py311h06a4308_0
23+
- debugpy=1.6.7=py311h6a678d5_0
24+
- decorator=5.1.1=pyhd3eb1b0_0
25+
- defusedxml=0.7.1=pyhd3eb1b0_0
26+
- executing=0.8.3=pyhd3eb1b0_0
27+
- idna=3.4=py311h06a4308_0
28+
- ipykernel=6.28.0=py311h06a4308_0
29+
- ipython=8.20.0=py311h06a4308_0
30+
- jedi=0.18.1=py311h06a4308_1
31+
- jinja2=3.1.3=py311h06a4308_0
32+
- json5=0.9.6=pyhd3eb1b0_0
33+
- jsonschema=4.19.2=py311h06a4308_0
34+
- jsonschema-specifications=2023.7.1=py311h06a4308_0
35+
- jupyter-lsp=2.2.0=py311h06a4308_0
36+
- jupyter_client=8.6.0=py311h06a4308_0
37+
- jupyter_core=5.5.0=py311h06a4308_0
38+
- jupyter_events=0.8.0=py311h06a4308_0
39+
- jupyter_server=2.10.0=py311h06a4308_0
40+
- jupyter_server_terminals=0.4.4=py311h06a4308_1
41+
- jupyterlab=4.0.11=py311h06a4308_0
42+
- jupyterlab_pygments=0.1.2=py_0
43+
- jupyterlab_server=2.25.1=py311h06a4308_0
44+
- ld_impl_linux-64=2.38=h1181459_1
45+
- libffi=3.4.4=h6a678d5_0
46+
- libgcc-ng=11.2.0=h1234567_1
47+
- libgomp=11.2.0=h1234567_1
48+
- libsodium=1.0.18=h7b6447c_0
49+
- libstdcxx-ng=11.2.0=h1234567_1
50+
- libuuid=1.41.5=h5eee18b_0
51+
- markupsafe=2.1.3=py311h5eee18b_0
52+
- matplotlib-inline=0.1.6=py311h06a4308_0
53+
- mistune=2.0.4=py311h06a4308_0
54+
- nbclient=0.8.0=py311h06a4308_0
55+
- nbconvert=7.10.0=py311h06a4308_0
56+
- nbformat=5.9.2=py311h06a4308_0
57+
- ncurses=6.4=h6a678d5_0
58+
- nest-asyncio=1.6.0=py311h06a4308_0
59+
- notebook-shim=0.2.3=py311h06a4308_0
60+
- openssl=3.0.13=h7f8727e_0
61+
- overrides=7.4.0=py311h06a4308_0
62+
- packaging=23.2=py311h06a4308_0
63+
- pandocfilters=1.5.0=pyhd3eb1b0_0
64+
- parso=0.8.3=pyhd3eb1b0_0
65+
- pexpect=4.8.0=pyhd3eb1b0_3
66+
- platformdirs=3.10.0=py311h06a4308_0
67+
- prometheus_client=0.14.1=py311h06a4308_0
68+
- prompt-toolkit=3.0.43=py311h06a4308_0
69+
- prompt_toolkit=3.0.43=hd3eb1b0_0
70+
- psutil=5.9.0=py311h5eee18b_0
71+
- ptyprocess=0.7.0=pyhd3eb1b0_2
72+
- pure_eval=0.2.2=pyhd3eb1b0_0
73+
- pycparser=2.21=pyhd3eb1b0_0
74+
- pygments=2.15.1=py311h06a4308_1
75+
- pysocks=1.7.1=py311h06a4308_0
76+
- python=3.11.9=h955ad1f_0
77+
- python-dateutil=2.8.2=pyhd3eb1b0_0
78+
- python-fastjsonschema=2.16.2=py311h06a4308_0
79+
- python-json-logger=2.0.7=py311h06a4308_0
80+
- pytz=2024.1=py311h06a4308_0
81+
- pyyaml=6.0.1=py311h5eee18b_0
82+
- pyzmq=25.1.2=py311h6a678d5_0
83+
- readline=8.2=h5eee18b_0
84+
- referencing=0.30.2=py311h06a4308_0
85+
- requests=2.31.0=py311h06a4308_1
86+
- rfc3339-validator=0.1.4=py311h06a4308_0
87+
- rfc3986-validator=0.1.1=py311h06a4308_0
88+
- rpds-py=0.10.6=py311hb02cf49_0
89+
- send2trash=1.8.2=py311h06a4308_0
90+
- setuptools=68.2.2=py311h06a4308_0
91+
- six=1.16.0=pyhd3eb1b0_1
92+
- sniffio=1.3.0=py311h06a4308_0
93+
- soupsieve=2.5=py311h06a4308_0
94+
- sqlite=3.41.2=h5eee18b_0
95+
- stack_data=0.2.0=pyhd3eb1b0_0
96+
- terminado=0.17.1=py311h06a4308_0
97+
- tinycss2=1.2.1=py311h06a4308_0
98+
- tk=8.6.12=h1ccaba5_0
99+
- tornado=6.3.3=py311h5eee18b_0
100+
- traitlets=5.7.1=py311h06a4308_0
101+
- typing-extensions=4.9.0=py311h06a4308_1
102+
- typing_extensions=4.9.0=py311h06a4308_1
103+
- urllib3=2.1.0=py311h06a4308_1
104+
- wcwidth=0.2.5=pyhd3eb1b0_0
105+
- webencodings=0.5.1=py311h06a4308_1
106+
- websocket-client=0.58.0=py311h06a4308_4
107+
- wheel=0.41.2=py311h06a4308_0
108+
- xz=5.4.6=h5eee18b_0
109+
- yaml=0.2.5=h7b6447c_0
110+
- zeromq=4.3.5=h6a678d5_0
111+
- zlib=1.2.13=h5eee18b_0
112+
- pip:
113+
- absl-py==2.1.0
114+
- array-record==0.5.1
115+
- astunparse==1.6.3
116+
- cachetools==5.3.3
117+
- cartopy==0.24.1
118+
- cf-xarray==0.10.0
119+
- cftime==1.6.4
120+
- click==8.1.7
121+
- contourpy==1.2.1
122+
- cycler==0.12.1
123+
- dm-tree==0.1.8
124+
- etils==1.8.0
125+
- filelock==3.13.4
126+
- flatbuffers==24.3.25
127+
- fonttools==4.51.0
128+
- fsspec==2024.3.1
129+
- gast==0.5.4
130+
- gdown==5.1.0
131+
- google-auth==2.29.0
132+
- google-auth-oauthlib==1.2.0
133+
- google-pasta==0.2.0
134+
- googleapis-common-protos==1.63.0
135+
- grpcio==1.62.2
136+
- h5py==3.11.0
137+
- importlib-resources==6.4.0
138+
- joblib==1.4.2
139+
- keras==3.3.3
140+
- kiwisolver==1.4.5
141+
- libclang==18.1.1
142+
- llvmlite==0.43.0
143+
- markdown==3.6
144+
- markdown-it-py==3.0.0
145+
- matplotlib==3.8.4
146+
- mdurl==0.1.2
147+
- ml-dtypes==0.2.0
148+
- namex==0.0.8
149+
- netcdf4==1.7.1.post2
150+
- nltk==3.8.1
151+
- numba==0.60.0
152+
- numpy==1.26.4
153+
- nvidia-cublas-cu12==12.3.4.1
154+
- nvidia-cuda-cupti-cu12==12.3.101
155+
- nvidia-cuda-nvcc-cu12==12.3.107
156+
- nvidia-cuda-nvrtc-cu12==12.3.107
157+
- nvidia-cuda-runtime-cu12==12.3.101
158+
- nvidia-cudnn-cu12==8.9.7.29
159+
- nvidia-cufft-cu12==11.0.12.1
160+
- nvidia-curand-cu12==10.3.4.107
161+
- nvidia-cusolver-cu12==11.5.4.101
162+
- nvidia-cusparse-cu12==12.2.0.103
163+
- nvidia-nccl-cu12==2.19.3
164+
- nvidia-nvjitlink-cu12==12.3.101
165+
- oauthlib==3.2.2
166+
- opencv-python==4.10.0.84
167+
- opt-einsum==3.3.0
168+
- optree==0.11.0
169+
- pandas==2.2.3
170+
- pillow==10.3.0
171+
- pip==24.0
172+
- plotly==5.24.1
173+
- promise==2.3
174+
- protobuf==4.25.3
175+
- pyasn1==0.6.0
176+
- pyasn1-modules==0.4.0
177+
- pyparsing==3.1.2
178+
- pyproj==3.7.0
179+
- pyshp==2.3.1
180+
- regex==2024.5.10
181+
- requests-oauthlib==2.0.0
182+
- rich==13.7.1
183+
- rsa==4.9
184+
- scikit-learn==1.4.2
185+
- scipy==1.13.0
186+
- seaborn==0.13.2
187+
- shapely==2.0.6
188+
- sparse==0.15.4
189+
- tenacity==9.0.0
190+
- tensorboard==2.15.2
191+
- tensorboard-data-server==0.7.2
192+
- tensorflow==2.15.0.post1
193+
- tensorflow-datasets==4.9.4
194+
- tensorflow-estimator==2.15.0
195+
- tensorflow-io-gcs-filesystem==0.36.0
196+
- tensorflow-metadata==1.15.0
197+
- termcolor==2.4.0
198+
- threadpoolctl==3.5.0
199+
- toml==0.10.2
200+
- tqdm==4.66.2
201+
- tzdata==2024.2
202+
- werkzeug==3.0.2
203+
- wrapt==1.14.1
204+
- xarray==2024.9.0
205+
- xesmf==0.8.8
206+
- zipp==3.18.1

bluemath_tk/deeplearning/generators/__init__.py

Whitespace-only changes.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import keras.utils
3+
4+
class MockDataGenerator(keras.utils.Sequence):
5+
def __init__(self,
6+
num_images: int,
7+
input_frames: int = 1,
8+
output_frames: int = 1,
9+
batch_size: int = 8,
10+
input_height: int = 256,
11+
input_width: int = 256,
12+
output_height: int = 256,
13+
output_width: int = 256):
14+
15+
self.input_height = input_height
16+
self.input_width = input_width
17+
self.output_height = output_height
18+
self.output_width = output_width
19+
self.input_frames = input_frames
20+
self.output_frames = output_frames
21+
self.batch_size = batch_size
22+
self.num_images = num_images
23+
24+
@property
25+
def num_batches(self):
26+
return int(np.floor(self.num_images / self.batch_size))
27+
28+
def __len__(self) -> int:
29+
"""Returns the total number of batches."""
30+
return self.num_batches
31+
32+
def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
33+
"""Generates one batch of random data"""
34+
# Generate random input and output data
35+
inputs = np.random.rand(self.batch_size, self.input_height, self.input_width, self.input_frames)
36+
outputs = np.random.rand(self.batch_size, self.output_height, self.output_width, self.output_frames)
37+
38+
return inputs, outputs
39+
40+
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import keras
2+
from keras import layers
3+
from typing import List, Tuple
4+
5+
def ResidualBlock(width: int) -> layers.Layer:
6+
def apply(x: layers.Layer) -> layers.Layer:
7+
8+
input_width = x.shape[3]
9+
residual = x if input_width == width else layers.Conv2D(width, kernel_size=1)(x)
10+
11+
x = layers.LayerNormalization(axis=-1, center=True, scale=True)(x)
12+
x = layers.Conv2D(
13+
width, kernel_size=3, padding="same", activation=keras.activations.swish
14+
)(x)
15+
x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
16+
x = layers.Add()([x, residual])
17+
return x
18+
19+
return apply
20+
21+
def DownBlock(width: int, block_depth: int) -> layers.Layer:
22+
def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> Tuple[layers.Layer, List[layers.Layer]]:
23+
24+
x, skips = x
25+
for _ in range(block_depth):
26+
x = ResidualBlock(width)(x)
27+
skips.append(x)
28+
x = layers.AveragePooling2D(pool_size=2)(x)
29+
return x, skips
30+
31+
return apply
32+
33+
def UpBlock(width: int, block_depth: int) -> layers.Layer:
34+
def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> layers.Layer:
35+
36+
x, skips = x
37+
x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
38+
for _ in range(block_depth):
39+
x = layers.Concatenate()([x, skips.pop()])
40+
x = ResidualBlock(width)(x)
41+
return x
42+
return apply
43+
44+
def get_model(image_height: int,
45+
image_width: int,
46+
input_frames: int,
47+
output_frames: int,
48+
down_widths: List[int] = [64, 128, 256],
49+
up_widths: List[int] = [256, 128, 64],
50+
block_depth: int = 2) -> keras.Model:
51+
"""Builds the U-Net like model with residual blocks and skip connections."""
52+
53+
inputs = keras.Input(shape=(image_height, image_width, input_frames))
54+
x = layers.Conv2D(down_widths[0], kernel_size=1)(inputs)
55+
56+
skips = []
57+
for width in down_widths[:-1]:
58+
x, skips = DownBlock(width, block_depth)([x, skips])
59+
60+
for _ in range(block_depth):
61+
x = ResidualBlock(down_widths[-1])(x)
62+
63+
for width in up_widths[1:]:
64+
x = UpBlock(width, block_depth)([x, skips])
65+
66+
outputs = layers.Conv2D(output_frames, kernel_size=1, kernel_initializer="zeros")(x)
67+
68+
return keras.Model(inputs, outputs, name="residual_unet")
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import keras
2+
from models import resnet_model
3+
from generators.mockDataGenerator import MockDataGenerator
4+
5+
# instantiate model class (load memory)
6+
model = resnet_model.get_model(
7+
image_height = 64,
8+
image_width = 64,
9+
input_frames = 1,
10+
output_frames = 1)
11+
12+
# print summary of the model
13+
print(model.summary())
14+
15+
# instantiate generator class
16+
train_generator = MockDataGenerator(num_images=5000,
17+
input_height = 64,
18+
input_width = 64,
19+
output_height = 64,
20+
output_width = 64,
21+
batch_size=1)
22+
# define oprimizer
23+
optimizer=keras.optimizers.AdamW
24+
model.compile(
25+
optimizer=optimizer(
26+
learning_rate=1e-4, weight_decay=1e-5
27+
),
28+
loss=keras.losses.mean_squared_error,
29+
)
30+
31+
# start the train loop with the fit method
32+
history = model.fit(
33+
train_generator,
34+
initial_epoch = 0,
35+
epochs=20,
36+
steps_per_epoch=500)
37+
38+
39+
print("training complete")

tests/deeplearning/test_cnn.py

Whitespace-only changes.

0 commit comments

Comments
 (0)