Skip to content

Commit 2b1e449

Browse files
authored
Merge pull request #22 from GeoOcean/feature/wrappers
Feature/wrappers
2 parents 8e6f9f8 + 60db0ee commit 2b1e449

File tree

11 files changed

+495
-97
lines changed

11 files changed

+495
-97
lines changed

bluemath_tk/deeplearning/generators/mockDataGenerator.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import numpy as np
22
import keras.utils
33

4+
45
class MockDataGenerator(keras.utils.Sequence):
5-
def __init__(self,
6-
num_images: int,
7-
input_frames: int = 1,
8-
output_frames: int = 1,
9-
batch_size: int = 8,
10-
input_height: int = 256,
11-
input_width: int = 256,
12-
output_height: int = 256,
13-
output_width: int = 256):
14-
6+
def __init__(
7+
self,
8+
num_images: int,
9+
input_frames: int = 1,
10+
output_frames: int = 1,
11+
batch_size: int = 8,
12+
input_height: int = 256,
13+
input_width: int = 256,
14+
output_height: int = 256,
15+
output_width: int = 256,
16+
):
1517
self.input_height = input_height
1618
self.input_width = input_width
1719
self.output_height = output_height
@@ -32,9 +34,11 @@ def __len__(self) -> int:
3234
def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
3335
"""Generates one batch of random data"""
3436
# Generate random input and output data
35-
inputs = np.random.rand(self.batch_size, self.input_height, self.input_width, self.input_frames)
36-
outputs = np.random.rand(self.batch_size, self.output_height, self.output_width, self.output_frames)
37+
inputs = np.random.rand(
38+
self.batch_size, self.input_height, self.input_width, self.input_frames
39+
)
40+
outputs = np.random.rand(
41+
self.batch_size, self.output_height, self.output_width, self.output_frames
42+
)
3743

3844
return inputs, outputs
39-
40-

bluemath_tk/deeplearning/resnet.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import keras
2+
from models import resnet_model
3+
from generators.mockDataGenerator import MockDataGenerator
4+
5+
# instantiate model class (load memory)
6+
model = resnet_model.get_model(
7+
image_height=64, image_width=64, input_frames=1, output_frames=1
8+
)
9+
10+
# print summary of the model
11+
print(model.summary())
12+
13+
# instantiate generator class
14+
train_generator = MockDataGenerator(
15+
num_images=5000,
16+
input_height=64,
17+
input_width=64,
18+
output_height=64,
19+
output_width=64,
20+
batch_size=1,
21+
)
22+
# define oprimizer
23+
optimizer = keras.optimizers.AdamW
24+
model.compile(
25+
optimizer=optimizer(learning_rate=1e-4, weight_decay=1e-5),
26+
loss=keras.losses.mean_squared_error,
27+
)
28+
29+
# start the train loop with the fit method
30+
history = model.fit(train_generator, initial_epoch=0, epochs=20, steps_per_epoch=500)
31+
32+
33+
print("training complete")

bluemath_tk/deeplearning/test_resnet.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

bluemath_tk/downloaders/copernicus/copernicus_downloader.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ class CopernicusDownloader(BlueMathDownloader):
3838
products_configs = {
3939
"ERA5": json.load(
4040
# open(os.path.join(os.path.dirname(__file__), "ERA5", "ERA5_config.json"))
41-
open("/home/grupos/geocean/tausiaj/BlueMath_tk/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json")
41+
open(
42+
"/home/grupos/geocean/tausiaj/BlueMath_tk/bluemath_tk/downloaders/copernicus/ERA5/ERA5_config.json"
43+
)
4244
)
4345
}
4446

@@ -308,15 +310,16 @@ def download_data_era5(
308310
""")
309311

310312
try:
311-
312313
if self.check or not force:
313314
if os.path.exists(output_nc_file):
314315
self.logger.debug(
315316
f"Checking {output_nc_file} file is complete"
316317
)
317318
try:
318319
nc = xr.open_dataset(output_nc_file)
319-
_, last_day = calendar.monthrange(int(year), int(month))
320+
_, last_day = calendar.monthrange(
321+
int(year), int(month)
322+
)
320323
last_hour = f"{year}-{int(month):02d}-{last_day}T23"
321324
last_hour_nc = str(nc.valid_time[-1].values)
322325
nc.close()
@@ -325,7 +328,9 @@ def download_data_era5(
325328
f"{output_nc_file} ends at {last_hour_nc} instead of {last_hour}"
326329
)
327330
if self.check:
328-
NOT_fullly_downloaded_files.append(output_nc_file)
331+
NOT_fullly_downloaded_files.append(
332+
output_nc_file
333+
)
329334
else:
330335
self.logger.debug(
331336
f"Downloading: {variable} to {output_nc_file} because it is not complete"
@@ -335,15 +340,19 @@ def download_data_era5(
335340
request=template_for_variable,
336341
target=output_nc_file,
337342
)
338-
fully_downloaded_files.append(output_nc_file)
343+
fully_downloaded_files.append(
344+
output_nc_file
345+
)
339346
else:
340347
fully_downloaded_files.append(output_nc_file)
341348
except Exception as e:
342349
self.logger.error(
343350
f"Error was raised opening {output_nc_file}, re-downloading..."
344351
)
345352
if self.check:
346-
NOT_fullly_downloaded_files.append(output_nc_file)
353+
NOT_fullly_downloaded_files.append(
354+
output_nc_file
355+
)
347356
else:
348357
self.logger.debug(
349358
f"Downloading: {variable} to {output_nc_file} because it is not complete"
@@ -378,17 +387,16 @@ def download_data_era5(
378387
fully_downloaded_files.append(output_nc_file)
379388

380389
except Exception as e:
381-
382390
self.logger.error(f"""
383391
384392
Skippping {output_nc_file} for {e}
385393
386394
""")
387395
error_files.append(output_nc_file)
388396

389-
fully_downloaded_files_str = '\n'.join(fully_downloaded_files)
390-
NOT_fullly_downloaded_files_str = '\n'.join(NOT_fullly_downloaded_files)
391-
error_files = '\n'.join(error_files)
397+
fully_downloaded_files_str = "\n".join(fully_downloaded_files)
398+
NOT_fullly_downloaded_files_str = "\n".join(NOT_fullly_downloaded_files)
399+
error_files = "\n".join(error_files)
392400

393401
return f"""
394402
Fully downloaded files:

0 commit comments

Comments
 (0)