Commit 46888a9

Merge pull request #24 from GeoOcean/feature/wrappers
Feature/wrappers
2 parents 83d3fbf + 85ae2ce commit 46888a9

File tree

9 files changed: +439 -52 lines changed

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import numpy as np
+import xarray as xr
+import keras.utils
+
+
+class DataGenerator(keras.utils.Sequence):
+    def __init__(
+        self,
+        msl_path,
+        tp_path,
+        num_images,
+        sequential=False,
+        batch_size=1,
+    ):
+        # load the input (msl) and output (tp) datasets into memory
+
+        # inputs
+
+        self.msl = xr.open_dataarray(msl_path).values[:, :64, :64]
+
+        # outputs
+
+        self.tp = xr.open_dataarray(tp_path).values[:, :64, :64]
+
+        # set boolean for sequential or random dataset
+        self.sequential = sequential
+        # counter for keeping track of the sequential generator
+        self.counter = 0
+        self.num_images = num_images
+        # set sequence len
+        # flag for diffusion/unet
+        self.batch_size = batch_size
+        self.num_samples = self.msl.shape[0]
+
+    @property
+    def num_batches(self):
+        return int(np.floor(self.num_images / self.batch_size))
+
+    def __len__(self):
+        return self.num_batches
+
+    def min_max_normalize(self, arr, min, max):
+        normalized = (arr - min) / (max - min)
+        result = np.where(np.isnan(arr), np.nan, normalized)
+        return result
+
+    # must be called to restart the sequential generator
+    def counter_reset(self):
+        self.counter = 0
+
+    def __getitem__(self, idx):
+
+        # prepare the resulting arrays
+        inputs = np.zeros((self.batch_size, 64, 64, 1))
+        outputs = np.zeros((self.batch_size, 64, 64, 1))
+
+        # random path
+        if self.sequential == False:
+            # compose the batch one element at a time
+            for i in range(self.batch_size):
+                # get a random number in range
+                random = np.random.randint(0, self.num_samples - 1)
+
+                # inputs
+                inputs[i, :, :, 0] = self.min_max_normalize(
+                    self.msl[random], 95680, 104256
+                )
+
+                # outputs
+                outputs[i, :, :, 0] = self.min_max_normalize(
+                    self.tp[random], 0.0, 0.02197266
+                )
+
+        # sequential path
+        if self.sequential == True:
+            # compose the batch one element at a time
+            for i in range(self.batch_size):
+                # inputs
+                inputs[i, :, :, 0] = self.min_max_normalize(
+                    self.msl[self.counter], 95680, 104256
+                )
+
+                # outputs
+                outputs[i, :, :, 0] = self.min_max_normalize(
+                    self.tp[self.counter], 0.0, 0.02197266
+                )
+
+                self.counter = self.counter + 1
+
+        return inputs, outputs
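
The new generator subclasses keras.utils.Sequence, so it can be handed directly to Keras training loops. Below is a minimal usage sketch, assuming the module is importable as generators.ncDataGenerator (the import path used in resnet.py further down); the file paths and parameter values are illustrative placeholders, not part of this commit:

from generators.ncDataGenerator import DataGenerator

# hypothetical NetCDF paths, stand-ins for the real ERA5 files
train_generator = DataGenerator(
    msl_path="msl_spain.nc",  # mean-sea-level-pressure inputs
    tp_path="tp_spain.nc",    # total-precipitation outputs
    num_images=1000,          # samples yielded per epoch
    sequential=False,         # draw random samples instead of iterating in order
    batch_size=4,
)

# each item is a (batch_size, 64, 64, 1) pair of min-max normalized arrays
inputs, outputs = train_generator[0]
print(inputs.shape, outputs.shape)  # (4, 64, 64, 1) (4, 64, 64, 1)

# a Keras model with matching input/output shapes could then be trained with,
# e.g., model.fit(train_generator, epochs=10)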

bluemath_tk/deeplearning/resnet.py

Lines changed: 10 additions & 7 deletions
@@ -1,6 +1,6 @@
 import keras
 from models import resnet_model
-from generators.mockDataGenerator import MockDataGenerator
+from generators.ncDataGenerator import DataGenerator
 
 # instantiate model class (load memory)
 model = resnet_model.get_model(
@@ -11,14 +11,17 @@
 print(model.summary())
 
 # instantiate generator class
-train_generator = MockDataGenerator(
-    num_images=5000,
-    input_height=64,
-    input_width=64,
-    output_height=64,
-    output_width=64,
+train_generator = DataGenerator(
+    msl_path="/home/tausiaj/DATA/Comparison-ERA5/msl_spain.nc",
+    tp_path="/home/tausiaj/DATA/Comparison-ERA5/tp_spain.nc",
+    num_images=8760,
+    sequential=False,
     batch_size=1,
 )
+
+a, b = train_generator.__getitem__(1)
+print(a.shape)
+print(b.shape)
 # define optimizer
 optimizer = keras.optimizers.AdamW
 model.compile(

bluemath_tk/wrappers/_base_wrappers.py

Lines changed: 134 additions & 8 deletions
@@ -1,13 +1,14 @@
 import os
 import itertools
 from typing import List
+import numpy as np
 from jinja2 import Environment, FileSystemLoader
 from ..core.models import BlueMathModel
 
 
 class BaseModelWrapper(BlueMathModel):
     """
-    Base class for model wrappers.
+    Base class for numerical model wrappers.
 
     Attributes
     ----------
@@ -21,6 +22,10 @@ class BaseModelWrapper(BlueMathModel):
         The directory where the output files will be saved.
     env : Environment
         The Jinja2 environment.
+    cases_dirs : List[str]
+        The list of case directories.
+    cases_context : List[dict]
+        The list of case contexts.
 
     Methods
     -------
@@ -32,8 +37,12 @@ class BaseModelWrapper(BlueMathModel):
        from the input dictionary.
     render_file_from_template(template_name, context, output_filename=None)
        Render a file from a template.
-    build_cases()
-       Build the cases.
+    write_array_in_file(array, filename)
+       Write an array in a file.
+    copy_files(src, dst)
+       Copy file(s) from source to destination.
+    build_cases(mode="all_combinations")
+       Create the cases folders and render the input files.
     run_cases()
        Run the cases.
     """
@@ -44,6 +53,7 @@ def __init__(
         templates_name: List[str],
         model_parameters: dict,
         output_dir: str,
+        default_parameters: dict = None,
     ):
         """
         Initialize the BaseModelWrapper.
@@ -58,14 +68,59 @@
             The parameters to be used in the templates.
         output_dir : str
             The directory where the output files will be saved.
+        default_parameters : dict, optional
+            The default parameters for the model. If None, the parameters will
+            not be checked.
+            Default is None.
         """
 
         super().__init__()
+        if default_parameters is not None:
+            self._check_parameters_type(
+                default_parameters=default_parameters, model_parameters=model_parameters
+            )
         self.templates_dir = templates_dir
         self.templates_name = templates_name
         self.model_parameters = model_parameters
         self.output_dir = output_dir
         self.env = Environment(loader=FileSystemLoader(self.templates_dir))
+        self.cases_dirs: List[str] = []
+        self.cases_context: List[dict] = []
+
+    def _check_parameters_type(self, default_parameters: dict, model_parameters: dict):
+        """
+        Check if the parameters have the correct type.
+
+        Parameters
+        ----------
+        default_parameters : dict
+            The default parameters for the model.
+        model_parameters : dict
+            The parameters to be used in the templates.
+
+        Raises
+        ------
+        ValueError
+            If a parameter has the wrong type.
+        """
+
+        for model_param, param_value in model_parameters.items():
+            if model_param not in default_parameters:
+                self.logger.warning(
+                    f"Parameter {model_param} is not in the default_parameters"
+                )
+            else:
+                if isinstance(param_value, list) and all(
+                    isinstance(item, default_parameters[model_param])
+                    for item in param_value
+                ):
+                    self.logger.info(
+                        f"Parameter {model_param} has the correct type: {type(default_parameters[model_param])}"
+                    )
+                else:
+                    raise ValueError(
+                        f"Parameter {model_param} has the wrong type: {type(default_parameters[model_param])}"
+                    )
 
     def create_cases_context_one_by_one(self):
         """
@@ -143,18 +198,89 @@ def render_file_from_template(
         with open(output_filename, "w") as f:
             f.write(rendered_content)
 
-    def write_array_in_file(self, array, filename):
+    def write_array_in_file(self, array: np.ndarray, filename: str):
         """
         Write an array in a file.
 
         Parameters
         ----------
-        array : np.array
-            The array to be written.
+        array : np.ndarray
+            The array to be written. Can be 1D or 2D.
         filename : str
             The name of the file.
         """
 
         with open(filename, "w") as f:
-            for item in array:
-                f.write(f"{item}\n")
+            if array.ndim == 1:
+                for item in array:
+                    f.write(f"{item}\n")
+            elif array.ndim == 2:
+                for row in array:
+                    f.write(" ".join(map(str, row)) + "\n")
+            else:
+                raise ValueError("Only 1D and 2D arrays are supported")
+
+    def copy_files(self, src: str, dst: str):
+        """
+        Copy file(s) from source to destination.
+
+        Parameters
+        ----------
+        src : str
+            The source file or directory.
+        dst : str
+            The destination file or directory.
+        """
+
+        if os.path.isdir(src):
+            os.makedirs(dst, exist_ok=True)
+            for file in os.listdir(src):
+                with open(os.path.join(src, file), "r") as f:
+                    content = f.read()
+                with open(os.path.join(dst, file), "w") as f:
+                    f.write(content)
+        else:
+            with open(src, "r") as f:
+                content = f.read()
+            with open(dst, "w") as f:
+                f.write(content)
+
+    def build_cases(self, mode: str = "all_combinations"):
+        """
+        Create the cases folders and render the input files.
+
+        Parameters
+        ----------
+        mode : str, optional
+            The mode to create the cases. Can be "all_combinations" or "one_by_one".
+            Default is "all_combinations".
+        """
+
+        if mode == "all_combinations":
+            self.cases_context = self.create_cases_context_all_combinations()
+        elif mode == "one_by_one":
+            self.cases_context = self.create_cases_context_one_by_one()
+        else:
+            raise ValueError(f"Invalid mode to create cases: {mode}")
+        for case_num, case_context in enumerate(self.cases_context):
+            case_dir = os.path.join(self.output_dir, f"{case_num:04}")
+            self.cases_dirs.append(case_dir)
+            os.makedirs(case_dir, exist_ok=True)
+            for template_name in self.templates_name:
+                self.render_file_from_template(
+                    template_name=template_name,
+                    context=case_context,
+                    output_filename=os.path.join(case_dir, template_name),
+                )
+        self.logger.info(
+            f"{len(self.cases_dirs)} cases created in {mode} mode and saved in {self.output_dir}"
+        )
+
+    def run_cases(self):
+        """
+        Run the cases.
+        """
+
+        if self.cases_dirs:
+            for case_dir in self.cases_dirs:
+                self.run_model(case_dir)
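
To show how the new build_cases / run_cases workflow fits together, here is a hypothetical sketch of a subclass; the MyModelWrapper name, its run_model body, and the template and parameter values are assumptions for illustration only and are not part of this commit:

class MyModelWrapper(BaseModelWrapper):
    def run_model(self, case_dir: str):
        # placeholder: a real wrapper would launch the numerical model here
        # (e.g. via subprocess) using the input files rendered in case_dir
        print(f"Running model in {case_dir}")


wrapper = MyModelWrapper(
    templates_dir="templates/",    # hypothetical directory holding the Jinja2 templates
    templates_name=["input.cfg"],  # hypothetical template rendered once per case
    model_parameters={"depth": [10.0, 20.0], "period": [5.0, 8.0]},
    output_dir="cases/",
    default_parameters={"depth": float, "period": float},  # enables the new type check
)

# create one numbered case folder per parameter combination (here 2 x 2 = 4),
# render input.cfg into each of them, then run the model case by case
wrapper.build_cases(mode="all_combinations")
wrapper.run_cases()

Passing default_parameters is optional; when it is left as None, the new _check_parameters_type validation is skipped.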

0 commit comments
