Skip to content

Commit 161495a

Browse files
committed
improve UGRID support
with this PR, we implement the ability to visualize variables on nodes and edges via the dual mesh. We base the UGRID decoder on the gridded package, particularly on NOAA-ORR-ERD/gridded#61 See also - psyplot#29 - psyplot/psy-maps#32
1 parent 4aa4cb1 commit 161495a

File tree

1 file changed

+246
-43
lines changed

1 file changed

+246
-43
lines changed

psyplot/data.py

+246-43
Original file line numberDiff line numberDiff line change
@@ -1645,20 +1645,36 @@ def standardize_dims(self, var, dims={}):
16451645
dims[name_map[dim]] = dims.pop(dim)
16461646
return dims
16471647

1648+
def clear_cache(self):
1649+
"""Clear any cached data.
1650+
1651+
The default method does nothing but can be reimplemented by subclasses
1652+
to clear data has been computed. """
1653+
pass
1654+
16481655

16491656
class UGridDecoder(CFDecoder):
16501657
"""
1651-
Decoder for UGrid data sets
1658+
Decoder for UGrid data sets"""
16521659

1653-
Warnings
1654-
--------
1655-
Currently only triangles are supported."""
1660+
#: mapping from grid name to the :class:`gridded.pyugrid.ugrid.UGrid`
1661+
# object representing it
1662+
_grids = {}
1663+
1664+
def __init__(self, *args, **kwargs):
1665+
super().__init__(*args, **kwargs)
1666+
self._grids = {}
1667+
1668+
def clear_cache(self):
1669+
"""Clear the cache and remove the UGRID instances."""
1670+
self._grids.clear()
16561671

16571672
def is_unstructured(self, *args, **kwargs):
16581673
"""Reimpletemented to return always True. Any ``*args`` and ``**kwargs``
16591674
are ignored"""
16601675
return True
16611676

1677+
@docstrings.get_sections(base="UGridDecoder.get_mesh")
16621678
def get_mesh(self, var, coords=None):
16631679
"""Get the mesh variable for the given `var`
16641680
@@ -1681,6 +1697,109 @@ def get_mesh(self, var, coords=None):
16811697
coords = self.ds.coords
16821698
return coords.get(mesh, self.ds.coords.get(mesh))
16831699

1700+
@docstrings.with_indent(8)
1701+
def get_ugrid(self, var, coords=None, loc="infer"):
1702+
"""Get the :class:`~gridded.pyugrid.ugrid.UGrid` mesh object.
1703+
1704+
This method creates a :class:`gridded.pyugrid.ugrid.UGrid` object for
1705+
a given variable, depending on the corresponding ``'mesh'`` attribute.
1706+
1707+
Parameters
1708+
----------
1709+
%(UGridDecoder.get_mesh.parameters)s
1710+
dual: {"infer", "node", "edge", "face"}
1711+
If "node" or "edge", the dual grid will be computed.
1712+
1713+
Returns
1714+
-------
1715+
gridded.pyugrid.ugrid.UGrid
1716+
The UGrid object representing the mesh.
1717+
"""
1718+
from gridded.pyugrid.ugrid import UGrid
1719+
1720+
def get_coord(cname, raise_error=True):
1721+
try:
1722+
ret = coords[cname]
1723+
except KeyError:
1724+
if cname not in self.ds.coords:
1725+
if raise_error:
1726+
raise
1727+
return None
1728+
else:
1729+
ret = self.ds.coords[cname]
1730+
try:
1731+
idims = var.psy.idims
1732+
except AttributeError: # got xarray.Variable
1733+
idims = {}
1734+
ret = ret.isel(**{
1735+
d: sl for d, sl in idims.items() if d in ret.dims}
1736+
)
1737+
if "start_index" in ret.attrs:
1738+
return ret.values - int(ret.start_index)
1739+
else:
1740+
return ret.values
1741+
1742+
mesh = self.get_mesh(var, coords)
1743+
1744+
if mesh.name in self._grids:
1745+
grid = self._grids[mesh.name]
1746+
else:
1747+
required_parameters = ["faces"]
1748+
1749+
parameters = {
1750+
"faces": "face_node_connectivity",
1751+
"face_face_connectivity": "face_face_connectivity",
1752+
"edges": "edge_node_connectivity",
1753+
"boundaries": "boundary_node_connectivity",
1754+
"face_coordinates": "face_coordinates",
1755+
"edge_coordinates": "edge_coordinates",
1756+
"boundary_coordinates": "boundary_coordinates",
1757+
}
1758+
1759+
x_nodes, y_nodes = self.get_nodes(mesh, coords)
1760+
1761+
kws = {
1762+
"node_lon": x_nodes, "node_lat": y_nodes, "mesh_name": mesh.name
1763+
}
1764+
1765+
coords = coords or {}
1766+
1767+
for key, attr in parameters.items():
1768+
if attr in mesh.attrs:
1769+
kws[key] = get_coord(
1770+
mesh.attrs[attr], key in required_parameters
1771+
)
1772+
1773+
# now we have to turn NaN into masked integer arrays
1774+
connectivity_parameters = [
1775+
"faces", "face_face_connectivity", "edges", "boundaries"
1776+
]
1777+
for param in connectivity_parameters:
1778+
if kws.get(param) is not None:
1779+
arr = kws[param]
1780+
mask = np.isnan(arr)
1781+
if mask.any():
1782+
arr = np.where(mask, -999, arr).astype(int)
1783+
kws[param] = np.ma.masked_where(mask, arr)
1784+
1785+
grid = UGrid(**kws)
1786+
self._grids[mesh.name] = grid
1787+
1788+
# create the dual mesh if necessary
1789+
if loc == "infer":
1790+
loc = self.infer_location(var, coords, grid)
1791+
1792+
if loc in ["node", "edge"]:
1793+
dual_name = grid.mesh_name + "_dual_" + loc
1794+
if dual_name in self._grids:
1795+
grid = self._grids[dual_name]
1796+
else:
1797+
grid = grid.create_dual_mesh(loc)
1798+
grid.mesh_name = dual_name
1799+
self._grids[dual_name] = grid
1800+
1801+
return grid
1802+
16841803
@classmethod
16851804
@docstrings.dedent
16861805
def can_decode(cls, ds, var):
@@ -1807,37 +1926,55 @@ def get_coord(coord):
18071926
if vert is None:
18081927
raise ValueError("Could not find the nodes variables for the %s "
18091928
"coordinate!" % axis)
1810-
loc = var.attrs.get('location', 'face')
1811-
if loc == 'node':
1812-
# we assume a triangular grid and use matplotlibs triangulation
1813-
from matplotlib.tri import Triangulation
1814-
xvert, yvert = nodes
1815-
triangles = Triangulation(xvert, yvert)
1816-
if axis == 'x':
1817-
bounds = triangles.x[triangles.triangles]
1818-
else:
1819-
bounds = triangles.y[triangles.triangles]
1820-
elif loc in ['edge', 'face']:
1821-
connectivity = get_coord(
1822-
mesh.attrs.get('%s_node_connectivity' % loc, ''))
1823-
if connectivity is None:
1824-
raise ValueError(
1825-
"Could not find the connectivity information!")
1826-
connectivity = connectivity.values
1827-
bounds = vert.values[
1828-
np.where(np.isnan(connectivity), connectivity[:, :1],
1829-
connectivity).astype(int)]
1830-
else:
1831-
raise ValueError(
1832-
"Could not interprete location attribute (%s) of mesh "
1833-
"variable %s!" % (loc, mesh.name))
1929+
1930+
grid = self.get_ugrid(var, coords)
1931+
1932+
faces = grid.faces
1933+
if np.ma.isMA(faces) and faces.mask.any():
1934+
isnull = faces.mask
1935+
faces = faces.filled(-999).astype(int)
1936+
for i in range(faces.shape[1]):
1937+
mask = isnull[:, i]
1938+
if mask.any():
1939+
for j in range(i, faces.shape[1]):
1940+
faces[mask, j] = faces[mask, j - i]
1941+
1942+
node = grid.nodes[..., 0 if axis == "x" else 1]
1943+
bounds = node[faces]
1944+
1945+
loc = self.infer_location(var, coords)
1946+
18341947
dim0 = '__face' if loc == 'node' else var.dims[-1]
18351948
return xr.DataArray(
18361949
bounds,
18371950
coords={key: val for key, val in coords.items()
18381951
if (dim0, ) == val.dims},
18391952
dims=(dim0, '__bnds', ),
1840-
name=vert.name + '_bnds', attrs=vert.attrs.copy())
1953+
name=vert.name + '_bnds', attrs=vert.attrs.copy())
1954+
1955+
@docstrings.with_indent(8)
1956+
def infer_location(self, var, coords=None, grid=None):
1957+
"""Infer the location for the variable.
1958+
1959+
Parameters
1960+
----------
1961+
%(UGridDecoder.get_mesh.parameters)s
1962+
grid: gridded.pyugrid.ugrid.UGrid
1963+
The grid for this variable. If None, it will be created using the
1964+
:meth:`get_ugrid` method (if necessary)
1965+
1966+
Returns
1967+
-------
1968+
str
1969+
``"node"``, ``"face"`` or ``"edge"``
1970+
"""
1971+
if not var.attrs.get('location'):
1972+
if grid is None:
1973+
grid = self.get_ugrid(var, coords, loc="face")
1974+
loc = grid.infer_location(var)
1975+
else:
1976+
loc = var.attrs["location"]
1977+
return loc
18411978

18421979
@staticmethod
18431980
@docstrings.dedent
@@ -1880,20 +2017,71 @@ def decode_coords(ds, gridfile=None):
18802017
ds._coord_names.update(extra_coords.intersection(ds.variables))
18812018
return ds
18822019

1883-
def get_nodes(self, coord, coords):
2020+
def get_nodes(self, coord, coords=None):
18842021
"""Get the variables containing the definition of the nodes
18852022
18862023
Parameters
18872024
----------
18882025
coord: xarray.Coordinate
18892026
The mesh variable
1890-
coords: dict
1891-
The coordinates to use to get node coordinates"""
2027+
coords: dict, optional
2028+
The coordinates to use to get node coordinates """
2029+
if coords is None:
2030+
coords = {}
18922031
def get_coord(coord):
18932032
return coords.get(coord, self.ds.coords.get(coord))
18942033
return list(map(get_coord,
18952034
coord.attrs.get('node_coordinates', '').split()[:2]))
18962035

2036+
@docstrings.with_indent(8)
2037+
def get_xname(self, var, coords=None):
2038+
"""Get the name of the spatial dimension
2039+
2040+
Parameters
2041+
----------
2042+
%(CFDecoder.get_y.parameters)s
2043+
2044+
Returns
2045+
-------
2046+
str
2047+
The dimension name
2048+
"""
2049+
2050+
def get_dim(name):
2051+
coord = coords.get(
2052+
name, ds.coords.get(name, ds.variables.get(name))
2053+
)
2054+
if coord is None:
2055+
raise KeyError(f"Missing {loc} coordinate {name}")
2056+
else:
2057+
return coord.dims[0]
2058+
2059+
ds = self.ds
2060+
loc = self.infer_location(var, coords)
2061+
mesh = self.get_mesh(var, coords)
2062+
coords = coords or ds.coords
2063+
if loc == "node":
2064+
return get_dim(mesh.node_coordinates.split()[0])
2065+
elif loc == "edge":
2066+
return get_dim(mesh.edge_node_connectivity)
2067+
else:
2068+
return get_dim(mesh.face_node_connectivity)
2069+
2070+
@docstrings.with_indent(8)
2071+
def get_yname(self, var, coords=None):
2072+
"""Get the name of the spatial dimension
2073+
2074+
Parameters
2075+
----------
2076+
%(CFDecoder.get_y.parameters)s
2077+
2078+
Returns
2079+
-------
2080+
str
2081+
The dimension name
2082+
"""
2083+
return self.get_xname(var, coords) # x- and y-dimensions are the same
2084+
18972085
@docstrings.dedent
18982086
def get_x(self, var, coords=None):
18992087
"""
@@ -1911,18 +2099,25 @@ def get_x(self, var, coords=None):
19112099
# first we try the super class
19122100
ret = super(UGridDecoder, self).get_x(var, coords)
19132101
# but if that doesn't work because we get the variable name in the
1914-
# dimension of `var`, we use the means of the triangles
2102+
# dimension of `var`, we use the means of the faces
19152103
if ret is None or ret.name in var.dims or (hasattr(var, 'mesh') and
19162104
ret.name == var.mesh):
1917-
bounds = self.get_cell_node_coord(var, axis='x', coords=coords)
1918-
if bounds is not None:
1919-
centers = bounds.mean(axis=-1)
1920-
x = self.get_nodes(self.get_mesh(var, coords), coords)[0]
2105+
loc = self.infer_location(var, coords)
2106+
x = self.get_nodes(self.get_mesh(var, coords), coords)[0]
2107+
if loc == "node":
2108+
return x
2109+
else:
2110+
grid = self.get_ugrid(var, coords, loc)
2111+
if grid.face_coordinates is None:
2112+
grid.build_face_coordinates()
19212113
try:
19222114
cls = xr.IndexVariable
19232115
except AttributeError: # xarray < 0.9
19242116
cls = xr.Coordinate
1925-
return cls(x.name, centers, attrs=x.attrs.copy())
2117+
return cls(
2118+
x.name, grid.face_coordinates[..., 1],
2119+
attrs=x.attrs.copy()
2120+
)
19262121
else:
19272122
return ret
19282123

@@ -1946,15 +2141,22 @@ def get_y(self, var, coords=None):
19462141
# dimension of `var`, we use the means of the triangles
19472142
if ret is None or ret.name in var.dims or (hasattr(var, 'mesh') and
19482143
ret.name == var.mesh):
1949-
bounds = self.get_cell_node_coord(var, axis='y', coords=coords)
1950-
if bounds is not None:
1951-
centers = bounds.mean(axis=-1)
1952-
y = self.get_nodes(self.get_mesh(var, coords), coords)[1]
2144+
loc = self.infer_location(var, coords)
2145+
y = self.get_nodes(self.get_mesh(var, coords), coords)[1]
2146+
if loc == "node":
2147+
return y
2148+
else:
2149+
grid = self.get_ugrid(var, coords, loc)
2150+
if grid.face_coordinates is None:
2151+
grid.build_face_coordinates()
19532152
try:
19542153
cls = xr.IndexVariable
19552154
except AttributeError: # xarray < 0.9
19562155
cls = xr.Coordinate
1957-
return cls(y.name, centers, attrs=y.attrs.copy())
2156+
return cls(
2157+
y.name, grid.face_coordinates[..., 1],
2158+
attrs=y.attrs.copy()
2159+
)
19582160
else:
19592161
return ret
19602162

@@ -2818,6 +3020,7 @@ def filter_attrs(item):
28183020
dims = self._new_dims
28193021
method = self.method
28203022
if dims:
3023+
self.decoder.clear_cache()
28213024
if VARIABLELABEL in self.arr.coords:
28223025
self._update_concatenated(dims, method)
28233026
else:

0 commit comments

Comments
 (0)