33from ._default_dicts import (
44 default_forcing_units ,
55 default_forcing_vars ,
6+ default_models ,
67)
78
89from .templating import render_settings
@@ -59,10 +60,11 @@ def __init__(
5960 forcing_files : Sequence [PathLike ] | PathLike = None , # type: ignore
6061 forcing_units : Dict [str , str ] = {},
6162 pet_method : str = "hamon" ,
62- model_number : Sequence [int ] | int = [ 7 , 37 ] , # HBV-96 and GR4J as default models
63+ model_number : Sequence [int ] | int = default_models , # HBV-96 and GR4J as default models
6364 forcing_time_zone : str = None ,
6465 model_time_zone : str = None ,
6566 streamflow : xr .DataArray | PathLike = None , # type: ignore
67+ elev_bands : PathLike | str = None , # type: ignore
6668 settings : Dict = {},
6769 ) -> 'FUSEWorkflow' :
6870 """
@@ -171,6 +173,9 @@ def __init__(
171173 for key in mandatory_settings :
172174 if key not in self .settings :
173175 raise ValueError (f"Missing mandatory setting: { key } " )
176+
177+ # assign elevation bands
178+ self .elev_bands = elev_bands
174179
175180 # assign an output object
176181 self .output_mat = None # Placeholder for output matrix
@@ -260,6 +265,7 @@ def run(self):
260265 self .init_forcing_files () # defines self.df
261266 self .init_pet () # defines self.pet
262267 self .init_streamflow () # defines self.forcing['q_obs']
268+ self .init_elev_bands () # defines self.elev_bands
263269
264270 # print a message about the timezones
265271 print (f"Using forcing time zone: { self .forcing_time_zone } " )
@@ -271,11 +277,7 @@ def save(self, save_path: PathLike): # type: ignore
271277 """Save the workflow output to a specified path."""
272278 if not hasattr (self , 'forcing' ) or self .forcing is None :
273279 raise ValueError ("No output matrix to save. Run the workflow first." )
274-
275- # Create .mat file using the scipy.io.savemat function
276- # the dataframe must be a cobination of self.df and self.pet
277- self .init_model_file (base_path = save_path )
278-
280+
279281 # check if the save_path exists, if not, create it
280282 if not os .path .exists (save_path ):
281283 os .makedirs (save_path , exist_ok = True )
@@ -284,17 +286,27 @@ def save(self, save_path: PathLike): # type: ignore
284286 os .makedirs (os .path .join (save_path , 'input' ), exist_ok = True )
285287 os .makedirs (os .path .join (save_path , 'output' ), exist_ok = True )
286288
289+ # Create .mat file using the scipy.io.savemat function
290+ # the dataframe must be a cobination of self.df and self.pet
291+ for model in self .model_number :
292+ content = self .init_model_file (base_path = save_path , model_n = model )
293+ # save the `fm_catch` content to a text file
294+ with open (os .path .join (save_path , f'{ self .name } _{ model } .txt' ), 'w' ) as f :
295+ f .write (content )
296+
287297 # save the forcing data to a NetCDF file
288298 self .forcing .to_netcdf (os .path .join (save_path , 'input' , f'{ self .name } _input.nc' ))
299+ # save the elevation bands to a NetCDF file
300+ self .elev_bands .to_netcdf (os .path .join (save_path , 'input' , f'{ self .name } _elev_bands.nc' ))
289301
290- # save the `fm_catch` content to a text file
291- with open (os .path .join (save_path , 'fm_catch.txt' ), 'w' ) as f :
292- f .write (self .fm_catch )
293-
294- # copy the defaults files to the settings directory
302+ # copy the defaults files and folders to the settings directory
295303 for f in glob .glob (os .path .join (setting_path , '*' )):
304+ # if a file, copy it to the settings directory
296305 if os .path .isfile (f ):
297306 shutil .copy2 (f , os .path .join (save_path , 'settings' ))
307+ # if a directory, copy it to the settings directory
308+ elif os .path .isdir (f ):
309+ shutil .copytree (f , os .path .join (save_path , 'settings' , os .path .basename (f )))
298310
299311 return f"Outputs saved to { save_path } "
300312
@@ -515,9 +527,80 @@ def init_streamflow(self):
515527
516528 return
517529
530+ def init_elev_bands (self ):
531+ """Initialize elevation bands."""
532+ # if extra information is provided in the input files
533+ if self .elev_bands is not None :
534+ # read the elevation bands from the provided path
535+ if isinstance (self .elev_bands , (PathLike , str )):
536+ elev_bands_file = pd .read_csv (self .elev_bands , index_col = 0 , header = 0 )
537+ elev_bands_value = elev_bands_file .iloc [0 , 0 ].values ()[0 ]
538+ else :
539+ raise TypeError ("elev_bands must be a PathLike or a string." )
540+
541+ else :
542+ elev_bands_value = 1000
543+
544+ # create a numpy.array of the elevation bands
545+ data = np .array ([[[elev_bands_value ]]]) # Shape: (elevation_band, latitude, longitude)
546+
547+ # default prec_frac and area_frac values
548+ prec_frac = np .array ([[[1.0 ]]]) # Shape: (elevation_band, latitude, longitude)
549+ area_frac = np .array ([[[1.0 ]]]) # Shape: (elevation_band, latitude, longitude)
550+
551+ ds = xr .Dataset (
552+ {
553+ 'mean_elev' : (['elevation_band' , 'latitude' , 'longitude' ], data ),
554+ 'area_frac' : (['elevation_band' , 'latitude' , 'longitude' ], area_frac ),
555+ 'prec_frac' : (['elevation_band' , 'latitude' , 'longitude' ], prec_frac ),
556+ },
557+ coords = {
558+ 'latitude' : self .forcing .latitude , # Assuming latitude is defined in the forcing data
559+ 'longitude' : self .forcing .longitude , # Assuming longitude is defined in the forcing data
560+ }
561+ )
562+
563+ # Assign the elevation_band as a coordinate variable
564+ ds = ds .assign_coords (elevation_band = np .arange (1 , len (data ) + 1 ))
565+
566+ # Add attributes
567+ ds ['mean_elev' ].attrs = {
568+ 'long_name' : 'Mid-point elevation of each elevation band' ,
569+ 'units' : 'm asl'
570+ }
571+ ds ['area_frac' ].attrs = {
572+ 'long_name' : 'Fraction of the catchment covered by each elevation band' ,
573+ 'units' : 'dimensionless'
574+ }
575+ ds ['prec_frac' ].attrs = {
576+ 'long_name' : 'Fraction of catchment precipitation that falls on each elevation band - same as area_frac' ,
577+ 'units' : 'dimensionless'
578+ }
579+ ds ['elevation_band' ].attrs = {
580+ 'long_name' : 'elevation_band' ,
581+ 'units' : 'dimensionless'
582+ }
583+ ds ['latitude' ].attrs = {
584+ 'long_name' : 'latitude' ,
585+ 'units' : 'degreesN'
586+ }
587+ ds ['longitude' ].attrs = {
588+ 'long_name' : 'longitude' ,
589+ 'units' : 'degreesE'
590+ }
591+ ds ['elevation_band' ].attrs = {
592+ 'long_name' : 'elevation_band' ,
593+ 'units' : 'dimensionless'
594+ }
595+
596+ self .elev_bands = ds
597+
598+ return
599+
518600 def init_model_file (
519601 self ,
520602 base_path : str | PathLike , # type: ignore
603+ model_n : int ,
521604 ) -> None :
522605 """Initialize the model file for the given model number."""
523606
@@ -550,13 +633,14 @@ def init_model_file(
550633 }
551634
552635 # create the content of the model file
553- self . fm_catch = render_settings (
636+ fm_catch = render_settings (
554637 paths = paths_dict ,
555638 dates = date_dict ,
639+ model = model_n
556640 )
557641
558642 # return the rendered content
559- return
643+ return fm_catch
560644
561645 def _format_dates (
562646 self ,
0 commit comments