diff --git a/polaris/ocean/__init__.py b/polaris/ocean/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/polaris/tasks/ocean/add_tasks.py b/polaris/tasks/ocean/add_tasks.py index 975e3d277e..8bb4ac5ae3 100644 --- a/polaris/tasks/ocean/add_tasks.py +++ b/polaris/tasks/ocean/add_tasks.py @@ -1,6 +1,7 @@ from polaris.tasks.ocean.baroclinic_channel import add_baroclinic_channel_tasks from polaris.tasks.ocean.barotropic_gyre import add_barotropic_gyre_tasks from polaris.tasks.ocean.cosine_bell import add_cosine_bell_tasks +from polaris.tasks.ocean.customizable_viz import add_customizable_viz_tasks from polaris.tasks.ocean.external_gravity_wave import ( add_external_gravity_wave_tasks as add_external_gravity_wave_tasks, ) @@ -44,6 +45,7 @@ def add_ocean_tasks(component): add_single_column_tasks(component=component) # spherical tasks + add_customizable_viz_tasks(component=component) add_cosine_bell_tasks(component=component) add_external_gravity_wave_tasks(component=component) add_geostrophic_tasks(component=component) diff --git a/polaris/tasks/ocean/customizable_viz/__init__.py b/polaris/tasks/ocean/customizable_viz/__init__.py new file mode 100644 index 0000000000..4d1beacd59 --- /dev/null +++ b/polaris/tasks/ocean/customizable_viz/__init__.py @@ -0,0 +1,51 @@ +import os + +from polaris import Task +from polaris.config import PolarisConfigParser as PolarisConfigParser +from polaris.tasks.ocean.customizable_viz.viz_horiz_field import ( + VizHorizField as VizHorizField, +) +from polaris.tasks.ocean.customizable_viz.viz_transect import ( + VizTransect as VizTransect, +) + + +def add_customizable_viz_tasks(component): + customizable_viz_task = CustomizableViz(component=component) + component.add_task(customizable_viz_task) + + +class CustomizableViz(Task): + """ + A customizable visualization task for MPAS-Ocean output + """ + + def __init__(self, component): + basedir = 'customizable_viz' + name = 'customizable_viz' + super().__init__(component=component, name=name, subdir=basedir) + + config_filename = 'customizable_viz.cfg' + config = PolarisConfigParser( + filepath=os.path.join(component.name, config_filename) + ) + config.add_from_package( + 'polaris.tasks.ocean.customizable_viz', config_filename + ) + self.set_shared_config(config, link=config_filename) + + viz_step = VizHorizField( + component=component, + name='viz_horiz_field', + indir=self.subdir, + ) + viz_step.set_shared_config(config, link=config_filename) + self.add_step(viz_step, run_by_default=True) + + transect_step = VizTransect( + component=component, + name='viz_transect', + indir=self.subdir, + ) + transect_step.set_shared_config(config, link=config_filename) + self.add_step(transect_step, run_by_default=False) diff --git a/polaris/tasks/ocean/customizable_viz/customizable_viz.cfg b/polaris/tasks/ocean/customizable_viz/customizable_viz.cfg new file mode 100644 index 0000000000..af8bfa985b --- /dev/null +++ b/polaris/tasks/ocean/customizable_viz/customizable_viz.cfg @@ -0,0 +1,57 @@ +[customizable_viz] + +# Mesh file, absolute file path +mesh_file = /lcrc/group/e3sm/data/inputdata/ocn/mpas-o/IcoswISC30E3r5/mpaso.IcoswISC30E3r5.rstFromG-chrysalis.20231121.nc + +# Data file, absolute file path +input_file = /lcrc/group/e3sm/data/inputdata/ocn/mpas-o/IcoswISC30E3r5/mpaso.IcoswISC30E3r5.rstFromG-chrysalis.20231121.nc + +[customizable_viz_horiz_field] + +# Projection to use for the horizontal field plot, must be supported by mosaic +projection = PlateCarree + +# Fields to plot, comma-separated +variables = layerThickness + +# Bounding box for the plot, global by default +min_latitude = -90 + +max_latitude = 90 + +min_longitude = 0 + +max_longitude = 360 + +# Optional additional arguments to provide to the colormap norm +norm_args = {} + +# the type of norm used in the colormap +norm_type = linear + +# Note: for some projections, choosing a different central longitude may +# result in cells spanning the plot boundary +central_longitude = 180. + +# The depth in m below the surface to use for layer selection +# Only one vertical level will be selected and its depth may vary +z_target = 0. + +[customizable_viz_transect] + +# Fields to plot, comma-separated +variables = salinity + +# The start and end coordinates for the transect +# where x is longitude and y is latitude +x_start = -5.0 +y_start = -65.0 +x_end = 10.0 +y_end = -65.0 + +# Optional limits for colormap scaling +vmin = None +vmax = None + +# Color to use for interfaces between vertical levels +layer_interface_color = None diff --git a/polaris/tasks/ocean/customizable_viz/viz_horiz_field.py b/polaris/tasks/ocean/customizable_viz/viz_horiz_field.py new file mode 100644 index 0000000000..947514cf4a --- /dev/null +++ b/polaris/tasks/ocean/customizable_viz/viz_horiz_field.py @@ -0,0 +1,197 @@ +import cmocean # noqa: F401 +import numpy as np + +from polaris.ocean.model import OceanIOStep +from polaris.viz import ( + determine_time_variable, + get_viz_defaults, + plot_global_mpas_field, +) + + +class VizHorizField(OceanIOStep): + def __init__(self, component, name, indir): + super().__init__(component=component, name=name, indir=indir) + + def runtime_setup(self): + section = self.config['customizable_viz'] + self.mesh_file = section.get('mesh_file') + self.input_file = section.get('input_file') + + section_name = 'customizable_viz_horiz_field' + self.variables = self.config.getlist( + section_name, 'variables', dtype=str + ) + if not self.variables: + raise ValueError( + f'No variables specified in the {section_name} section of ' + 'the config file.' + ) + + def run(self): # noqa:C901 + section_name = 'customizable_viz_horiz_field' + section = self.config[section_name] + + # Determine the projection from the config file + projection_name = section.get('projection') + central_longitude = section.getfloat('central_longitude') + + # Descriptor is none for the first variable and assigned thereafter + descriptor = None + + ds_mesh = self.open_model_dataset( + self.mesh_file, decode_timedelta=False + ) + min_latitude = section.getfloat('min_latitude') + max_latitude = section.getfloat('max_latitude') + min_longitude = section.getfloat('min_longitude') + max_longitude = section.getfloat('max_longitude') + lat_cell = np.rad2deg(ds_mesh['latCell']) + lon_cell = np.rad2deg(ds_mesh['lonCell']) + if min_longitude < 0.0 and lon_cell.min().values > 0.0: + max_longitude_copy = max_longitude + max_longitude = 360.0 - min_longitude + min_longitude = max_longitude_copy + cell_indices = np.where( + (ds_mesh.maxLevelCell > 0) + & (lat_cell >= min_latitude) + & (lat_cell <= max_latitude) + & (lon_cell >= min_longitude) + & (lon_cell <= max_longitude) + ) + if len(cell_indices[0]) == 0: + raise ValueError( + f'No cells of {ds_mesh.sizes["nCells"]} cells found within the' + ' specified lat/lon bounds. Please adjust the min/max ' + 'latitude/longitude values to be within the bounds of the ' + f'dataset: latitude ' + f'{lat_cell.min().values},{lat_cell.max().values} \n' + f'longitude {lon_cell.min().values},{lon_cell.max().values}' + ) + print( + f'Using {len(cell_indices[0])} cells of ' + f'{ds_mesh.sizes["nCells"]} cells in the mesh' + ) + ds_mesh = ds_mesh.isel(nCells=cell_indices[0]) + z_target = section.getfloat('z_target') + z_bottom = ds_mesh['restingThickness'].cumsum(dim='nVertLevels') + dz = z_bottom.mean(dim='nCells') - z_target + z_index = np.argmin(np.abs(dz.values)) + if dz[z_index] > 0 and z_index > 0: + z_index -= 1 + z_mean = z_bottom.mean(dim='nCells')[z_index].values + print( + f'Using z_index {z_index} for z_target {z_target} ' + f'with mean depth {z_mean} ' + ) + + ds = self.open_model_dataset(self.input_file, decode_timedelta=False) + + if 'Time' in ds.sizes: + t_index = 0 + # TODO support different time selection from config file + ds = ds.isel(Time=t_index) + + prefix, time_variable = determine_time_variable(ds) + if time_variable is not None: + start_time = ds[time_variable].values[0] + start_time = start_time.decode() + time_stamp = f'_{start_time.split("_")[0]}' + else: + time_stamp = '' + + ds = ds.isel(nCells=cell_indices[0]) + if ds.sizes['nCells'] != ds_mesh.sizes['nCells']: + raise ValueError( + f'Number of cells in the mesh {ds_mesh.sizes["nCells"]} ' + f'and input {ds.sizes["nCells"]} do not match. ' + ) + viz_dict = get_viz_defaults() + + for var_name in self.variables: + if 'accumulated' in var_name: + full_var_name = var_name + else: + full_var_name = f'{prefix}{var_name}' + if full_var_name not in ds.keys(): + if f'{prefix}activeTracers_{var_name}' in ds.keys(): + full_var_name = f'{prefix}activeTracers_{var_name}' + elif var_name == 'columnThickness': + ds[full_var_name] = ds.bottomDepth + ds.ssh + else: + print( + f'Skipping {full_var_name}, ' + f'not found in {self.input_file}' + ) + continue + print(f'Plotting {full_var_name}') + filename_suffix = '' + mpas_field = ds[full_var_name] + if 'nVertLevels' in mpas_field.sizes: + mpas_field = mpas_field.isel(nVertLevels=z_index) + if z_index != 0: + filename_suffix = f'_z{z_index}' + + if self.config.has_option(section_name, 'colormap_name'): + cmap = self.config.get(section_name, 'colormap_name') + else: + if var_name in viz_dict.keys(): + cmap = viz_dict[var_name]['colormap'] + else: + cmap = viz_dict['default']['colormap'] + self.config.set(section_name, 'colormap_name', value=cmap) + + if self.config.has_option(section_name, 'colormap_range_percent'): + colormap_range_percent = self.config.getfloat( + section_name, 'colormap_range_percent' + ) + else: + colormap_range_percent = 0.0 + + if colormap_range_percent > 0.0: + vmin = np.percentile(mpas_field.values, colormap_range_percent) + vmax = np.percentile( + mpas_field.values, 100.0 - colormap_range_percent + ) + else: + vmin = mpas_field.min().values + vmax = mpas_field.max().values + + if self.config.has_option( + section_name, 'vmin' + ) and self.config.has_option(section_name, 'vmax'): + vmin = section.getfloat('vmin') + vmax = section.getfloat('vmax') + elif ( + cmap == 'cmo.balance' + or 'vertVelocityTop' in var_name + or 'Tendency' in var_name + or 'Flux' in var_name + ): + vmax = max(abs(vmin), abs(vmax)) + vmin = -vmax + + self.config.set( + section_name, + 'norm_args', + value='{"vmin": ' + str(vmin) + ', "vmax": ' + str(vmax) + '}', + ) + + if var_name in viz_dict.keys(): + units = viz_dict[var_name]['units'] + else: + units = viz_dict['default']['units'] + + descriptor = plot_global_mpas_field( + mesh_filename=self.mesh_file, + da=mpas_field, + out_filename=f'{var_name}_horiz{time_stamp}{filename_suffix}.png', + config=self.config, + colormap_section='customizable_viz_horiz_field', + descriptor=descriptor, + colorbar_label=f'{var_name} [{units}]', + plot_land=True, + projection_name=projection_name, + central_longitude=central_longitude, + cell_indices=cell_indices[0], + ) diff --git a/polaris/tasks/ocean/customizable_viz/viz_transect.py b/polaris/tasks/ocean/customizable_viz/viz_transect.py new file mode 100644 index 0000000000..a88e327a1d --- /dev/null +++ b/polaris/tasks/ocean/customizable_viz/viz_transect.py @@ -0,0 +1,102 @@ +import cmocean # noqa: F401 +import numpy as np +import xarray as xr +from mpas_tools.ocean.viz.transect import compute_transect, plot_transect + +from polaris.ocean.model import OceanIOStep as OceanIOStep +from polaris.viz import ( + determine_time_variable, + get_viz_defaults, +) + + +class VizTransect(OceanIOStep): + def __init__(self, component, name, indir): + super().__init__(component=component, name=name, indir=indir) + + def runtime_setup(self): + section = self.config['customizable_viz'] + self.mesh_file = section.get('mesh_file') + self.input_file = section.get('input_file') + section_name = 'customizable_viz_transect' + self.variables = self.config.getlist( + section_name, 'variables', dtype=str + ) + if not self.variables: + raise ValueError( + f'No variables specified in the {section_name} section of ' + 'the config file.' + ) + + def run(self): + section_name = 'customizable_viz_transect' + section = self.config[section_name] + layer_interface_color = section.get('layer_interface_color') + x_start = section.getfloat('x_start') + x_end = section.getfloat('x_end') + y_start = section.getfloat('y_start') + y_end = section.getfloat('y_end') + + x = xr.DataArray(data=[x_start, x_end]) + y = xr.DataArray(data=[y_start, y_end]) + + ds_mesh = self.open_model_dataset(self.mesh_file) + ds = self.open_model_dataset(self.input_file, decode_timedelta=False) + + # TODO support time selection from config file + t_index = 0 + ds = ds.isel(Time=t_index) + prefix, time_variable = determine_time_variable(ds) + if time_variable is not None: + start_time = ds[time_variable].values[0] + start_time = start_time.decode() + time_stamp = f'_{start_time.split("_")[0]}' + else: + time_stamp = '' + + # Transect is constructed for nVertLevels quantities + if 'nVertLevelsP1' in ds.sizes: + ds = ds.isel(nVertLevelsP1=slice(0, -1)) + ds_transect = compute_transect( + x=x, + y=y, + ds_horiz_mesh=ds_mesh, + layer_thickness=ds[f'{prefix}layerThickness'], + bottom_depth=ds_mesh.bottomDepth, + min_level_cell=ds_mesh.minLevelCell - 1, + max_level_cell=ds_mesh.maxLevelCell - 1, + spherical=True, + ) + + viz_dict = get_viz_defaults() + for var_name in self.variables: + mpas_field = ds[f'{prefix}{var_name}'] + if self.config.has_option(section_name, 'vmin'): + vmin = section.getfloat('vmin') + else: + vmin = np.percentile(mpas_field.values, 5) + if self.config.has_option(section_name, 'vmax'): + vmax = section.getfloat('vmax') + else: + vmax = np.percentile(mpas_field.values, 95) + if vmax > 0.0 and vmin < 0.0: + vmax = max(abs(vmax), abs(vmin)) + vmin = -vmax + if var_name in viz_dict.keys(): + cmap = viz_dict[var_name]['colormap'] + units = viz_dict[var_name]['units'] + else: + cmap = viz_dict['default']['colormap'] + units = viz_dict['default']['units'] + plot_transect( + ds_transect=ds_transect, + mpas_field=mpas_field, + title=f'{var_name}', + out_filename=f'{var_name}_transect{time_stamp}.png', + interface_color=layer_interface_color, + vmin=vmin, + vmax=vmax, + cmap=cmap, + colorbar_label=units, + color_start_and_end=True, + ) diff --git a/polaris/viz/__init__.py b/polaris/viz/__init__.py index 49a18f7d7f..40863a1dec 100644 --- a/polaris/viz/__init__.py +++ b/polaris/viz/__init__.py @@ -1,3 +1,12 @@ +from polaris.viz.helper import ( + determine_time_variable as determine_time_variable, +) +from polaris.viz.helper import ( + get_projection as get_projection, +) +from polaris.viz.helper import ( + get_viz_defaults as get_viz_defaults, +) from polaris.viz.planar import plot_horiz_field as plot_horiz_field from polaris.viz.spherical import ( plot_global_lat_lon_field as plot_global_lat_lon_field, diff --git a/polaris/viz/helper.py b/polaris/viz/helper.py new file mode 100644 index 0000000000..6fc2e2d546 --- /dev/null +++ b/polaris/viz/helper.py @@ -0,0 +1,64 @@ +import cartopy.crs as ccrs + +projections = { + 'PlateCarree': ccrs.PlateCarree, + 'LambertCylindrical': ccrs.LambertCylindrical, + 'Mercator': ccrs.Mercator, + 'Miller': ccrs.Miller, + 'Robinson': ccrs.Robinson, + 'Stereographic': ccrs.Stereographic, + 'RotatedPole': ccrs.RotatedPole, + 'InterruptedGoodeHomolosine': ccrs.InterruptedGoodeHomolosine, + 'EckertI': ccrs.EckertI, + 'EckertII': ccrs.EckertII, + 'EckertIII': ccrs.EckertIII, + 'EckertIV': ccrs.EckertIV, + 'EckertV': ccrs.EckertV, + 'EckertVI': ccrs.EckertVI, + 'EqualEarth': ccrs.EqualEarth, + 'NorthPolarStereo': ccrs.NorthPolarStereo, + 'SouthPolarStereo': ccrs.SouthPolarStereo, +} + + +def get_projection(name: str, **kwargs): + """Return a Cartopy projection by string name.""" + if name not in projections: + raise ValueError( + f"Unknown projection '{name}'. Available: {list(projections)}" + ) + return projections[name](**kwargs) + + +def get_viz_defaults(): + # indexed by mpas-ocean variable name in instantaneous output + viz_dict = { + 'bottomDepth': {'colormap': 'cmo.deep', 'units': r'm'}, + 'layerThickness': {'colormap': 'cmo.thermal', 'units': r'm'}, + 'temperature': {'colormap': 'cmo.thermal', 'units': r'$^{\circ}$C'}, + 'salinity': {'colormap': 'cmo.haline', 'units': r'g/kg'}, + 'density': {'colormap': 'cmo.dense', 'units': r'kg/m$^3$'}, + 'ssh': {'colormap': 'cmo.delta', 'units': r'm'}, + 'vertVelocityTop': {'colormap': 'cmo.balance', 'units': r'm/s'}, + 'normalVelocity': {'colormap': 'cmo.balance', 'units': r'm/s'}, + 'velocityZonal': {'colormap': 'cmo.balance', 'units': r'm/s'}, + 'velocityMeridional': {'colormap': 'cmo.balance', 'units': r'm/s'}, + 'landIceFraction': {'colormap': 'cmo.ice', 'units': r''}, + 'seaIceFraction': {'colormap': 'cmo.ice', 'units': r''}, + 'default': {'colormap': 'cmo.dense', 'units': r''}, + } + return viz_dict + + +def determine_time_variable(ds): + prefix = '' + time_variable = None + if 'timeSeriesStatsMonthly' in ds.keys(): + prefix = 'timeMonthly_avg_' + time_variable = 'xtime_startMonthly' + elif 'xtime' in ds.keys(): + time_variable = 'xtime' + elif 'Time' in ds.keys(): + prefix = 'timeMonthly_avg_' + time_variable = 'Time' + return prefix, time_variable diff --git a/polaris/viz/spherical.py b/polaris/viz/spherical.py index 116f5cebf7..2aa413f71d 100644 --- a/polaris/viz/spherical.py +++ b/polaris/viz/spherical.py @@ -5,19 +5,22 @@ import matplotlib.colors as cols import matplotlib.pyplot as plt import mosaic +import numpy as np import xarray as xr +from cartopy.geodesic import Geodesic from mpl_toolkits.axes_grid1.inset_locator import inset_axes from pyremap.descriptor.utility import interp_extrap_corner +from polaris.viz.helper import get_projection from polaris.viz.style import use_mplstyle def plot_global_mpas_field( - mesh_filename, da, out_filename, config, colormap_section, + mesh_filename=None, title=None, plot_land=True, colorbar_label='', @@ -26,6 +29,9 @@ def plot_global_mpas_field( dpi=200, patch_edge_color=None, descriptor=None, + projection_name='PlateCarree', + cell_indices=None, + enforce_aspect_ratio=False, ): """ Plots a data set as a longitude-latitude map @@ -87,10 +93,20 @@ def plot_global_mpas_field( use_mplstyle() transform = cartopy.crs.Geodetic() - projection = cartopy.crs.PlateCarree(central_longitude=central_longitude) + projection = get_projection( + projection_name, central_longitude=central_longitude + ) - mesh_ds = xr.open_dataset(mesh_filename) if descriptor is None: + if mesh_filename is None: + raise ValueError( + 'Either mesh_filename or descriptor must be given' + ' as parameters to Descriptor' + ) + mesh_ds = xr.open_dataset(mesh_filename) + mesh_ds.attrs['is_periodic'] = 'NO' + if cell_indices is not None: + mesh_ds = mesh_ds.isel(nCells=cell_indices) descriptor = mosaic.Descriptor( mesh_ds, projection=projection, @@ -131,6 +147,20 @@ def plot_global_mpas_field( pc, ax=ax, label=colorbar_label, extend='both', shrink=0.6 ) + if enforce_aspect_ratio: + min_latitude = np.rad2deg(mesh_ds.latCell.min().values) + max_latitude = np.rad2deg(mesh_ds.latCell.max().values) + min_longitude = np.rad2deg(mesh_ds.lonCell.min().values) + max_longitude = np.rad2deg(mesh_ds.lonCell.max().values) + geod = Geodesic() + x_distance = geod.inverse( + [min_longitude, min_latitude], [max_longitude, min_latitude] + )[0, 0] + y_distance = geod.inverse( + [min_longitude, min_latitude], [min_longitude, max_latitude] + )[0, 0] + ax.set_aspect(y_distance / x_distance) + if ticks is not None: cbar.set_ticks(ticks) cbar.set_ticklabels([f'{tick}' for tick in ticks]) @@ -385,7 +415,7 @@ def _add_land_lakes_coastline(ax, ice_shelves=True): 'antarctic_ice_shelves_polys', '50m', edgecolor='k', - facecolor=land_color, + facecolor='none', linewidth=0.5, ) ax.add_feature(ice_50m, zorder=3)