11import os
22import itertools
33from typing import List
4+ import numpy as np
45from jinja2 import Environment , FileSystemLoader
56from ..core .models import BlueMathModel
67
78
89class BaseModelWrapper (BlueMathModel ):
910 """
10- Base class for model wrappers.
11+ Base class for numerical models wrappers.
1112
1213 Attributes
1314 ----------
@@ -21,6 +22,10 @@ class BaseModelWrapper(BlueMathModel):
2122 The directory where the output files will be saved.
2223 env : Environment
2324 The Jinja2 environment.
25+ cases_dirs : List[str]
26+ The list with cases directories.
27+ cases_context : List[dict]
28+ The list with cases context.
2429
2530 Methods
2631 -------
@@ -32,8 +37,12 @@ class BaseModelWrapper(BlueMathModel):
3237 from the input dictionary.
3338 render_file_from_template(template_name, context, output_filename=None)
3439 Render a file from a template.
35- build_cases()
36- Build the cases.
40+ write_array_in_file(array, filename)
41+ Write an array in a file.
42+ copy_files(src, dst)
43+ Copy file(s) from source to destination.
44+ build_cases(mode="all_combinations")
45+ Create the cases folders and render the input files.
3746 run_cases()
3847 Run the cases.
3948 """
@@ -44,6 +53,7 @@ def __init__(
4453 templates_name : List [str ],
4554 model_parameters : dict ,
4655 output_dir : str ,
56+ default_parameters : dict = None ,
4757 ):
4858 """
4959 Initialize the BaseModelWrapper.
@@ -58,14 +68,59 @@ def __init__(
5868 The parameters to be used in the templates.
5969 output_dir : str
6070 The directory where the output files will be saved.
71+ default_parameters : dict, optional
72+ The default parameters for the model. If None, the parameters will
73+ not be checked.
74+ Default is None.
6175 """
6276
6377 super ().__init__ ()
78+ if default_parameters is not None :
79+ self ._check_parameters_type (
80+ default_parameters = default_parameters , model_parameters = model_parameters
81+ )
6482 self .templates_dir = templates_dir
6583 self .templates_name = templates_name
6684 self .model_parameters = model_parameters
6785 self .output_dir = output_dir
6886 self .env = Environment (loader = FileSystemLoader (self .templates_dir ))
87+ self .cases_dirs : List [str ] = []
88+ self .cases_context : List [dict ] = []
89+
90+ def _check_parameters_type (self , default_parameters : dict , model_parameters : dict ):
91+ """
92+ Check if the parameters have the correct type.
93+
94+ Parameters
95+ ----------
96+ default_parameters : dict
97+ The default parameters for the model.
98+ model_parameters : dict
99+ The parameters to be used in the templates.
100+
101+ Raises
102+ ------
103+ ValueError
104+ If a parameter has the wrong type.
105+ """
106+
107+ for model_param , param_value in model_parameters .items ():
108+ if model_param not in default_parameters :
109+ self .logger .warning (
110+ f"Parameter { model_param } is not in the default_parameters"
111+ )
112+ else :
113+ if isinstance (param_value , list ) and all (
114+ isinstance (item , default_parameters [model_param ])
115+ for item in param_value
116+ ):
117+ self .logger .info (
118+ f"Parameter { model_param } has the correct type: { type (default_parameters [model_param ])} "
119+ )
120+ else :
121+ raise ValueError (
122+ f"Parameter { model_param } has the wrong type: { type (default_parameters [model_param ])} "
123+ )
69124
70125 def create_cases_context_one_by_one (self ):
71126 """
@@ -143,18 +198,89 @@ def render_file_from_template(
143198 with open (output_filename , "w" ) as f :
144199 f .write (rendered_content )
145200
146- def write_array_in_file (self , array , filename ):
201+ def write_array_in_file (self , array : np . ndarray , filename : str ):
147202 """
148203 Write an array in a file.
149204
150205 Parameters
151206 ----------
152- array : np.array
153- The array to be written.
207+ array : np.ndarray
208+ The array to be written. Can be 1D or 2D.
154209 filename : str
155210 The name of the file.
156211 """
157212
158213 with open (filename , "w" ) as f :
159- for item in array :
160- f .write (f"{ item } \n " )
214+ if array .ndim == 1 :
215+ for item in array :
216+ f .write (f"{ item } \n " )
217+ elif array .ndim == 2 :
218+ for row in array :
219+ f .write (" " .join (map (str , row )) + "\n " )
220+ else :
221+ raise ValueError ("Only 1D and 2D arrays are supported" )
222+
223+ def copy_files (self , src : str , dst : str ):
224+ """
225+ Copy file(s) from source to destination.
226+
227+ Parameters
228+ ----------
229+ src : str
230+ The source file.
231+ dst : str
232+ The destination file.
233+ """
234+
235+ if os .path .isdir (src ):
236+ os .makedirs (dst , exist_ok = True )
237+ for file in os .listdir (src ):
238+ with open (file , "r" ) as f :
239+ content = f .read ()
240+ with open (os .path .join (dst , file ), "w" ) as f :
241+ f .write (content )
242+ else :
243+ with open (src , "r" ) as f :
244+ content = f .read ()
245+ with open (dst , "w" ) as f :
246+ f .write (content )
247+
248+ def build_cases (self , mode : str = "all_combinations" ):
249+ """
250+ Create the cases folders and render the input files.
251+
252+ Parameters
253+ ----------
254+ mode : str, optional
255+ The mode to create the cases. Can be "all_combinations" or "one_by_one".
256+ Default is "all_combinations".
257+ """
258+
259+ if mode == "all_combinations" :
260+ self .cases_context = self .create_cases_context_all_combinations ()
261+ elif mode == "one_by_one" :
262+ self .cases_context = self .create_cases_context_one_by_one ()
263+ else :
264+ raise ValueError (f"Invalid mode to create cases: { mode } " )
265+ for case_num , case_context in enumerate (self .cases_context ):
266+ case_dir = os .path .join (self .output_dir , f"{ case_num :04} " )
267+ self .cases_dirs .append (case_dir )
268+ os .makedirs (case_dir , exist_ok = True )
269+ for template_name in self .templates_name :
270+ self .render_file_from_template (
271+ template_name = template_name ,
272+ context = case_context ,
273+ output_filename = os .path .join (case_dir , template_name ),
274+ )
275+ self .logger .info (
276+ f"{ len (self .cases_dirs )} cases created in { mode } mode and saved in { self .output_dir } "
277+ )
278+
279+ def run_cases (self ):
280+ """
281+ Run the cases.
282+ """
283+
284+ if self .cases_dirs :
285+ for case_dir in self .cases_dirs :
286+ self .run_model (case_dir )
0 commit comments