44import matplotlib .pyplot as plt
55import numpy as np
66
7+ from polaris .ocean .convergence import (
8+ get_resolution_for_task ,
9+ get_timestep_for_task ,
10+ )
711from polaris .ocean .model import OceanIOStep
8- from polaris .ocean .resolution import resolution_to_subdir
912from polaris .ocean .tasks .manufactured_solution .exact_solution import (
1013 ExactSolution ,
1114)
@@ -19,10 +22,31 @@ class Viz(OceanIOStep):
1922
2023 Attributes
2124 ----------
22- resolutions : list of float
23- The resolutions of the meshes that have been run
25+ dependencies_dict : dict of dict of polaris.Steps
26+ The dependencies of this step must be given as separate keys in the
27+ dict:
28+
29+ mesh : dict of polaris.Steps
30+ Keys of the dict correspond to `refinement_factors`
31+ Values of the dict are polaris.Steps, which must have the
32+ attribute `path`, the path to `base_mesh.nc` of that
33+ resolution
34+ init : dict of polaris.Steps
35+ Keys of the dict correspond to `refinement_factors`
36+ Values of the dict are polaris.Steps, which must have the
37+ attribute `path`, the path to `initial_state.nc` of that
38+ resolution
39+ forward : dict of polaris.Steps
40+ Keys of the dict correspond to `refinement_factors`
41+ Values of the dict are polaris.Steps, which must have the
42+ attribute `path`, the path to `forward.nc` of that
43+ resolution
44+
45+ refinement : str
46+ Refinement type. One of 'space', 'time' or 'both' indicating both
47+ space and time
2448 """
25- def __init__ (self , component , resolutions , taskdir ):
49+ def __init__ (self , component , dependencies , taskdir , refinement = 'both' ):
2650 """
2751 Create the step
2852
@@ -31,37 +55,82 @@ def __init__(self, component, resolutions, taskdir):
3155 component : polaris.Component
3256 The component the step belongs to
3357
34- resolutions : list of float
35- The resolutions of the meshes that have been run
58+ dependencies : dict of dict of polaris.Steps
59+ The dependencies of this step must be given as separate keys in the
60+ dict:
61+
62+ mesh : dict of polaris.Steps
63+ Keys of the dict correspond to `refinement_factors`
64+ Values of the dict are polaris.Steps, which must have the
65+ attribute `path`, the path to `base_mesh.nc` of that
66+ resolution
67+ init : dict of polaris.Steps
68+ Keys of the dict correspond to `refinement_factors`
69+ Values of the dict are polaris.Steps, which must have the
70+ attribute `path`, the path to `initial_state.nc` of that
71+ resolution
72+ forward : dict of polaris.Steps
73+ Keys of the dict correspond to `refinement_factors`
74+ Values of the dict are polaris.Steps, which must have the
75+ attribute `path`, the path to `forward.nc` of that
76+ resolution
3677
3778 taskdir : str
3879 The subdirectory that the task belongs to
80+
81+ refinement : str, optional
82+ Refinement type. One of 'space', 'time' or 'both' indicating both
83+ space and time
3984 """
4085 super ().__init__ (component = component , name = 'viz' , indir = taskdir )
41- self .resolutions = resolutions
86+ self .dependencies_dict = dependencies
87+ self .refinement = refinement
88+ self .add_output_file ('comparison.png' )
4289
43- for resolution in resolutions :
44- mesh_name = resolution_to_subdir (resolution )
90+ def setup (self ):
91+ """
92+ Add input files based on resolutions, which may have been changed by
93+ user config options
94+ """
95+ super ().setup ()
96+ config = self .config
97+ dependencies = self .dependencies_dict
98+
99+ if self .refinement == 'time' :
100+ option = 'refinement_factors_time'
101+ else :
102+ option = 'refinement_factors_space'
103+ refinement_factors = config .getlist ('convergence' , option ,
104+ dtype = float )
105+
106+ for refinement_factor in refinement_factors :
107+ base_mesh = dependencies ['mesh' ][refinement_factor ]
108+ init = dependencies ['init' ][refinement_factor ]
109+ forward = dependencies ['forward' ][refinement_factor ]
45110 self .add_input_file (
46- filename = f'mesh_ { mesh_name } .nc' ,
47- target = f'../init/ { mesh_name } /culled_mesh .nc' )
111+ filename = f'mesh_r { refinement_factor :02g } .nc' ,
112+ work_dir_target = f'{ base_mesh . path } /base_mesh .nc' )
48113 self .add_input_file (
49- filename = f'init_ { mesh_name } .nc' ,
50- target = f'../init/ { mesh_name } /initial_state.nc' )
114+ filename = f'init_r { refinement_factor :02g } .nc' ,
115+ work_dir_target = f'{ init . path } /initial_state.nc' )
51116 self .add_input_file (
52- filename = f'output_{ mesh_name } .nc' ,
53- target = f'../forward/{ mesh_name } /output.nc' )
54-
55- self .add_output_file ('comparison.png' )
117+ filename = f'output_r{ refinement_factor :02g} .nc' ,
118+ work_dir_target = f'{ forward .path } /output.nc' )
56119
57120 def run (self ):
58121 """
59122 Run this step of the test case
60123 """
61124 plt .switch_backend ('Agg' )
62125 config = self .config
63- resolutions = self .resolutions
64- nres = len (resolutions )
126+ if self .refinement == 'time' :
127+ option = 'refinement_factors_time'
128+ else :
129+ option = 'refinement_factors_space'
130+ refinement_factors = config .getlist ('convergence' , option ,
131+ dtype = float )
132+
133+ nres = len (refinement_factors )
65134
66135 section = config ['manufactured_solution' ]
67136 eta0 = section .getfloat ('ssh_amplitude' )
@@ -70,11 +139,14 @@ def run(self):
70139 fig , axes = plt .subplots (nrows = nres , ncols = 3 , figsize = (12 , 2 * nres ))
71140 rmse = []
72141 error_range = None
73- for i , res in enumerate (resolutions ):
74- mesh_name = resolution_to_subdir (res )
75- ds_mesh = self .open_model_dataset (f'mesh_{ mesh_name } .nc' )
76- ds_init = self .open_model_dataset (f'init_{ mesh_name } .nc' )
77- ds = self .open_model_dataset (f'output_{ mesh_name } .nc' )
142+
143+ for i , refinement_factor in enumerate (refinement_factors ):
144+ ds_mesh = self .open_model_dataset (
145+ f'mesh_r{ refinement_factor :02g} .nc' )
146+ ds_init = self .open_model_dataset (
147+ f'init_r{ refinement_factor :02g} .nc' )
148+ ds = self .open_model_dataset (
149+ f'output_r{ refinement_factor :02g} .nc' )
78150 exact = ExactSolution (config , ds_init )
79151
80152 t0 = datetime .datetime .strptime (ds .xtime .values [0 ].decode (),
@@ -110,8 +182,13 @@ def run(self):
110182 axes [0 , 2 ].set_title ('Error (Numerical - Analytical)' )
111183
112184 pad = 5
113- for ax , res in zip (axes [:, 0 ], resolutions ):
114- ax .annotate (f'{ res } km' , xy = (0 , 0.5 ),
185+ for ax , refinement_factor in zip (axes [:, 0 ], refinement_factors ):
186+ timestep , _ = get_timestep_for_task (
187+ config , refinement_factor , refinement = self .refinement )
188+ resolution = get_resolution_for_task (
189+ config , refinement_factor , refinement = self .refinement )
190+
191+ ax .annotate (f'{ resolution } km\n { timestep } s' , xy = (0 , 0.5 ),
115192 xytext = (- ax .yaxis .labelpad - pad , 0 ),
116193 xycoords = ax .yaxis .label , textcoords = 'offset points' ,
117194 size = 'large' , ha = 'right' , va = 'center' )
0 commit comments