Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed polaris/ocean/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions polaris/tasks/ocean/add_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from polaris.tasks.ocean.baroclinic_channel import add_baroclinic_channel_tasks
from polaris.tasks.ocean.barotropic_gyre import add_barotropic_gyre_tasks
from polaris.tasks.ocean.cosine_bell import add_cosine_bell_tasks
from polaris.tasks.ocean.customizable_viz import add_customizable_viz_tasks
from polaris.tasks.ocean.external_gravity_wave import (
add_external_gravity_wave_tasks as add_external_gravity_wave_tasks,
)
Expand Down Expand Up @@ -44,6 +45,7 @@ def add_ocean_tasks(component):
add_single_column_tasks(component=component)

# spherical tasks
add_customizable_viz_tasks(component=component)
add_cosine_bell_tasks(component=component)
add_external_gravity_wave_tasks(component=component)
add_geostrophic_tasks(component=component)
Expand Down
51 changes: 51 additions & 0 deletions polaris/tasks/ocean/customizable_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os

from polaris import Task
from polaris.config import PolarisConfigParser as PolarisConfigParser
from polaris.tasks.ocean.customizable_viz.viz_horiz_field import (
VizHorizField as VizHorizField,
)
from polaris.tasks.ocean.customizable_viz.viz_transect import (
VizTransect as VizTransect,
)


def add_customizable_viz_tasks(component):
customizable_viz_task = CustomizableViz(component=component)
component.add_task(customizable_viz_task)


class CustomizableViz(Task):
"""
A customizable visualization task for MPAS-Ocean output
"""

def __init__(self, component):
basedir = 'customizable_viz'
name = 'customizable_viz'
super().__init__(component=component, name=name, subdir=basedir)

config_filename = 'customizable_viz.cfg'
config = PolarisConfigParser(
filepath=os.path.join(component.name, config_filename)
)
config.add_from_package(
'polaris.tasks.ocean.customizable_viz', config_filename
)
self.set_shared_config(config, link=config_filename)

viz_step = VizHorizField(
component=component,
name='viz_horiz_field',
indir=self.subdir,
)
viz_step.set_shared_config(config, link=config_filename)
self.add_step(viz_step, run_by_default=True)

transect_step = VizTransect(
component=component,
name='viz_transect',
indir=self.subdir,
)
transect_step.set_shared_config(config, link=config_filename)
self.add_step(transect_step, run_by_default=False)
57 changes: 57 additions & 0 deletions polaris/tasks/ocean/customizable_viz/customizable_viz.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[customizable_viz]

# Mesh file, absolute file path
mesh_file = /lcrc/group/e3sm/data/inputdata/ocn/mpas-o/IcoswISC30E3r5/mpaso.IcoswISC30E3r5.rstFromG-chrysalis.20231121.nc

# Data file, absolute file path
input_file = /lcrc/group/e3sm/data/inputdata/ocn/mpas-o/IcoswISC30E3r5/mpaso.IcoswISC30E3r5.rstFromG-chrysalis.20231121.nc

[customizable_viz_horiz_field]

# Projection to use for the horizontal field plot, must be supported by mosaic
projection = PlateCarree

# Fields to plot, comma-separated
variables = layerThickness

# Bounding box for the plot, global by default
min_latitude = -90

max_latitude = 90

min_longitude = 0

max_longitude = 360

# Optional additional arguments to provide to the colormap norm
norm_args = {}

# the type of norm used in the colormap
norm_type = linear

# Note: for some projections, choosing a different central longitude may
# result in cells spanning the plot boundary
central_longitude = 180.

# The depth in m below the surface to use for layer selection
# Only one vertical level will be selected and its depth may vary
z_target = 0.

[customizable_viz_transect]

# Fields to plot, comma-separated
variables = salinity

# The start and end coordinates for the transect
# where x is longitude and y is latitude
x_start = -5.0
y_start = -65.0
x_end = 10.0
y_end = -65.0

# Optional limits for colormap scaling
vmin = None
vmax = None

# Color to use for interfaces between vertical levels
layer_interface_color = None
197 changes: 197 additions & 0 deletions polaris/tasks/ocean/customizable_viz/viz_horiz_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import cmocean # noqa: F401
import numpy as np

from polaris.ocean.model import OceanIOStep
from polaris.viz import (
determine_time_variable,
get_viz_defaults,
plot_global_mpas_field,
)


class VizHorizField(OceanIOStep):
def __init__(self, component, name, indir):
super().__init__(component=component, name=name, indir=indir)

def runtime_setup(self):
section = self.config['customizable_viz']
self.mesh_file = section.get('mesh_file')
self.input_file = section.get('input_file')

section_name = 'customizable_viz_horiz_field'
self.variables = self.config.getlist(
section_name, 'variables', dtype=str
)
if not self.variables:
raise ValueError(
f'No variables specified in the {section_name} section of '
'the config file.'
)

def run(self): # noqa:C901
section_name = 'customizable_viz_horiz_field'
section = self.config[section_name]

# Determine the projection from the config file
projection_name = section.get('projection')
central_longitude = section.getfloat('central_longitude')

# Descriptor is none for the first variable and assigned thereafter
descriptor = None

ds_mesh = self.open_model_dataset(
self.mesh_file, decode_timedelta=False
)
min_latitude = section.getfloat('min_latitude')
max_latitude = section.getfloat('max_latitude')
min_longitude = section.getfloat('min_longitude')
max_longitude = section.getfloat('max_longitude')
lat_cell = np.rad2deg(ds_mesh['latCell'])
lon_cell = np.rad2deg(ds_mesh['lonCell'])
if min_longitude < 0.0 and lon_cell.min().values > 0.0:
max_longitude_copy = max_longitude
max_longitude = 360.0 - min_longitude
min_longitude = max_longitude_copy
cell_indices = np.where(
(ds_mesh.maxLevelCell > 0)
& (lat_cell >= min_latitude)
& (lat_cell <= max_latitude)
& (lon_cell >= min_longitude)
& (lon_cell <= max_longitude)
)
Comment on lines +55 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This culling process is a little bit risky because cells with centers outside the range may nevertheless have vertices in the range and the boundary may look ragged with this approach.

Also, many projections don't have straight edges in latitude/longitude space so this may drop cells that should appear within the project bounds. Might it be better to have mosaic try to do this kind of optimization and not try to handle it on the Polaris side?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andrewdnolan, do you have thoughts on this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I could just check lonVertex/latVertex instead of cell centers?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is okay for cell-centered fields, though I might need to give it a bit more thought.

What we definitely need to do is iterate over the list of variables and make sure they are cell fields. (Only if this culling does subset the mesh, otherwise no restriction on them having to be cell-fields).

@xylar and I have talked about adding a more complete regional masking functionality (which would replicate the culler), which accepts a cell mask but would allow plotting cell, edge, or vertex fields. That's definitely something I can put at the top of the mosiac development priorities.

if len(cell_indices[0]) == 0:
raise ValueError(
f'No cells of {ds_mesh.sizes["nCells"]} cells found within the'
' specified lat/lon bounds. Please adjust the min/max '
'latitude/longitude values to be within the bounds of the '
f'dataset: latitude '
f'{lat_cell.min().values},{lat_cell.max().values} \n'
f'longitude {lon_cell.min().values},{lon_cell.max().values}'
)
print(
f'Using {len(cell_indices[0])} cells of '
f'{ds_mesh.sizes["nCells"]} cells in the mesh'
)
ds_mesh = ds_mesh.isel(nCells=cell_indices[0])
z_target = section.getfloat('z_target')
z_bottom = ds_mesh['restingThickness'].cumsum(dim='nVertLevels')
dz = z_bottom.mean(dim='nCells') - z_target
z_index = np.argmin(np.abs(dz.values))
if dz[z_index] > 0 and z_index > 0:
z_index -= 1
z_mean = z_bottom.mean(dim='nCells')[z_index].values
print(
f'Using z_index {z_index} for z_target {z_target} '
f'with mean depth {z_mean} '
)

ds = self.open_model_dataset(self.input_file, decode_timedelta=False)

if 'Time' in ds.sizes:
t_index = 0
# TODO support different time selection from config file
ds = ds.isel(Time=t_index)

prefix, time_variable = determine_time_variable(ds)
if time_variable is not None:
start_time = ds[time_variable].values[0]
start_time = start_time.decode()
time_stamp = f'_{start_time.split("_")[0]}'
else:
time_stamp = ''

ds = ds.isel(nCells=cell_indices[0])
if ds.sizes['nCells'] != ds_mesh.sizes['nCells']:
raise ValueError(
f'Number of cells in the mesh {ds_mesh.sizes["nCells"]} '
f'and input {ds.sizes["nCells"]} do not match. '
)
viz_dict = get_viz_defaults()

for var_name in self.variables:
if 'accumulated' in var_name:
full_var_name = var_name
else:
full_var_name = f'{prefix}{var_name}'
if full_var_name not in ds.keys():
if f'{prefix}activeTracers_{var_name}' in ds.keys():
full_var_name = f'{prefix}activeTracers_{var_name}'
elif var_name == 'columnThickness':
ds[full_var_name] = ds.bottomDepth + ds.ssh
else:
print(
f'Skipping {full_var_name}, '
f'not found in {self.input_file}'
)
continue
print(f'Plotting {full_var_name}')
filename_suffix = ''
mpas_field = ds[full_var_name]
if 'nVertLevels' in mpas_field.sizes:
mpas_field = mpas_field.isel(nVertLevels=z_index)
if z_index != 0:
filename_suffix = f'_z{z_index}'

if self.config.has_option(section_name, 'colormap_name'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since colormap only accepts a string, all variables requested will be plotted with the same colormap (if the colormap is specified). Same goes for vmin / vmax.

I think this is fine for now, especially given the viz_dict you've defined. Might be something to tackle later, cause it could cause some unexpected behavior.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wasn't really sure how much to build out the customization options here. I feel like commented lists aren't great because the user will eventually lose track of the correspondence between variable and options. I also wasn't sure how much we wanted this task to be a stepping stone to MPAS-A. Probably easier to have a conversation some time about philosophically where to go with this.

cmap = self.config.get(section_name, 'colormap_name')
else:
if var_name in viz_dict.keys():
cmap = viz_dict[var_name]['colormap']
else:
cmap = viz_dict['default']['colormap']
self.config.set(section_name, 'colormap_name', value=cmap)

if self.config.has_option(section_name, 'colormap_range_percent'):
colormap_range_percent = self.config.getfloat(
section_name, 'colormap_range_percent'
)
else:
colormap_range_percent = 0.0

if colormap_range_percent > 0.0:
vmin = np.percentile(mpas_field.values, colormap_range_percent)
vmax = np.percentile(
mpas_field.values, 100.0 - colormap_range_percent
)
else:
vmin = mpas_field.min().values
vmax = mpas_field.max().values

if self.config.has_option(
section_name, 'vmin'
) and self.config.has_option(section_name, 'vmax'):
vmin = section.getfloat('vmin')
vmax = section.getfloat('vmax')
elif (
cmap == 'cmo.balance'
or 'vertVelocityTop' in var_name
or 'Tendency' in var_name
or 'Flux' in var_name
):
vmax = max(abs(vmin), abs(vmax))
vmin = -vmax

self.config.set(
section_name,
'norm_args',
value='{"vmin": ' + str(vmin) + ', "vmax": ' + str(vmax) + '}',
)

if var_name in viz_dict.keys():
units = viz_dict[var_name]['units']
else:
units = viz_dict['default']['units']

descriptor = plot_global_mpas_field(
mesh_filename=self.mesh_file,
da=mpas_field,
out_filename=f'{var_name}_horiz{time_stamp}{filename_suffix}.png',
config=self.config,
colormap_section='customizable_viz_horiz_field',
descriptor=descriptor,
colorbar_label=f'{var_name} [{units}]',
plot_land=True,
projection_name=projection_name,
central_longitude=central_longitude,
cell_indices=cell_indices[0],
)
Loading
Loading