@@ -1645,20 +1645,36 @@ def standardize_dims(self, var, dims={}):
1645
1645
dims [name_map [dim ]] = dims .pop (dim )
1646
1646
return dims
1647
1647
1648
+ def clear_cache (self ):
1649
+ """Clear any cached data.
1650
+
1651
+ The default method does nothing but can be reimplemented by subclasses
1652
+ to clear data has been computed. """
1653
+ pass
1654
+
1648
1655
1649
1656
class UGridDecoder (CFDecoder ):
1650
1657
"""
1651
- Decoder for UGrid data sets
1658
+ Decoder for UGrid data sets"""
1652
1659
1653
- Warnings
1654
- --------
1655
- Currently only triangles are supported."""
1660
+ #: mapping from grid name to the :class:`gridded.pyugrid.ugrid.UGrid`
1661
+ # object representing it
1662
+ _grids = {}
1663
+
1664
+ def __init__ (self , * args , ** kwargs ):
1665
+ super ().__init__ (* args , ** kwargs )
1666
+ self ._grids = {}
1667
+
1668
+ def clear_cache (self ):
1669
+ """Clear the cache and remove the UGRID instances."""
1670
+ self ._grids .clear ()
1656
1671
1657
1672
def is_unstructured (self , * args , ** kwargs ):
1658
1673
"""Reimpletemented to return always True. Any ``*args`` and ``**kwargs``
1659
1674
are ignored"""
1660
1675
return True
1661
1676
1677
+ @docstrings .get_sections (base = "UGridDecoder.get_mesh" )
1662
1678
def get_mesh (self , var , coords = None ):
1663
1679
"""Get the mesh variable for the given `var`
1664
1680
@@ -1681,6 +1697,109 @@ def get_mesh(self, var, coords=None):
1681
1697
coords = self .ds .coords
1682
1698
return coords .get (mesh , self .ds .coords .get (mesh ))
1683
1699
1700
+ @docstrings .with_indent (8 )
1701
+ def get_ugrid (self , var , coords = None , loc = "infer" ):
1702
+ """Get the :class:`~gridded.pyugrid.ugrid.UGrid` mesh object.
1703
+
1704
+ This method creates a :class:`gridded.pyugrid.ugrid.UGrid` object for
1705
+ a given variable, depending on the corresponding ``'mesh'`` attribute.
1706
+
1707
+ Parameters
1708
+ ----------
1709
+ %(UGridDecoder.get_mesh.parameters)s
1710
+ dual: {"infer", "node", "edge", "face"}
1711
+ If "node" or "edge", the dual grid will be computed.
1712
+
1713
+ Returns
1714
+ -------
1715
+ gridded.pyugrid.ugrid.UGrid
1716
+ The UGrid object representing the mesh.
1717
+ """
1718
+ from gridded .pyugrid .ugrid import UGrid
1719
+
1720
+ def get_coord (cname , raise_error = True ):
1721
+ try :
1722
+ ret = coords [cname ]
1723
+ except KeyError :
1724
+ if cname not in self .ds .coords :
1725
+ if raise_error :
1726
+ raise
1727
+ return None
1728
+ else :
1729
+ ret = self .ds .coords [cname ]
1730
+ try :
1731
+ idims = var .psy .idims
1732
+ except AttributeError : # got xarray.Variable
1733
+ idims = {}
1734
+ ret = ret .isel (** {
1735
+ d : sl for d , sl in idims .items () if d in ret .dims }
1736
+ )
1737
+ if "start_index" in ret .attrs :
1738
+ return ret .values - int (ret .start_index )
1739
+ else :
1740
+ return ret .values
1741
+
1742
+ mesh = self .get_mesh (var , coords )
1743
+
1744
+ if mesh .name in self ._grids :
1745
+ grid = self ._grids [mesh .name ]
1746
+ else :
1747
+ required_parameters = ["faces" ]
1748
+
1749
+ parameters = {
1750
+ "faces" : "face_node_connectivity" ,
1751
+ "face_face_connectivity" : "face_face_connectivity" ,
1752
+ "edges" : "edge_node_connectivity" ,
1753
+ "boundaries" : "boundary_node_connectivity" ,
1754
+ "face_coordinates" : "face_coordinates" ,
1755
+ "edge_coordinates" : "edge_coordinates" ,
1756
+ "boundary_coordinates" : "boundary_coordinates" ,
1757
+ }
1758
+
1759
+ x_nodes , y_nodes = self .get_nodes (mesh , coords )
1760
+
1761
+ kws = {
1762
+ "node_lon" : x_nodes , "node_lat" : y_nodes , "mesh_name" : mesh .name
1763
+ }
1764
+
1765
+ coords = coords or {}
1766
+
1767
+ for key , attr in parameters .items ():
1768
+ if attr in mesh .attrs :
1769
+ kws [key ] = get_coord (
1770
+ mesh .attrs [attr ], key in required_parameters
1771
+ )
1772
+
1773
+ # now we have to turn NaN into masked integer arrays
1774
+ connectivity_parameters = [
1775
+ "faces" , "face_face_connectivity" , "edges" , "boundaries"
1776
+ ]
1777
+ for param in connectivity_parameters :
1778
+ if kws .get (param ) is not None :
1779
+ arr = kws [param ]
1780
+ mask = np .isnan (arr )
1781
+ if mask .any ():
1782
+ arr = np .where (mask , - 999 , arr ).astype (int )
1783
+ kws [param ] = np .ma .masked_where (mask , arr )
1784
+
1785
+ grid = UGrid (** kws )
1786
+ self ._grids [mesh .name ] = grid
1787
+
1788
+ # create the dual mesh if necessary
1789
+ if loc == "infer" :
1790
+ loc = self .infer_location (var , coords , grid )
1791
+
1792
+ if loc in ["node" , "edge" ]:
1793
+ dual_name = grid .mesh_name + "_dual_" + loc
1794
+ if dual_name in self ._grids :
1795
+ grid = self ._grids [dual_name ]
1796
+ else :
1797
+ grid = grid .create_dual_mesh (loc )
1798
+ grid .mesh_name = dual_name
1799
+ self ._grids [dual_name ] = grid
1800
+
1801
+ return grid
1802
+
1684
1803
@classmethod
1685
1804
@docstrings .dedent
1686
1805
def can_decode (cls , ds , var ):
@@ -1807,37 +1926,55 @@ def get_coord(coord):
1807
1926
if vert is None :
1808
1927
raise ValueError ("Could not find the nodes variables for the %s "
1809
1928
"coordinate!" % axis )
1810
- loc = var .attrs .get ('location' , 'face' )
1811
- if loc == 'node' :
1812
- # we assume a triangular grid and use matplotlibs triangulation
1813
- from matplotlib .tri import Triangulation
1814
- xvert , yvert = nodes
1815
- triangles = Triangulation (xvert , yvert )
1816
- if axis == 'x' :
1817
- bounds = triangles .x [triangles .triangles ]
1818
- else :
1819
- bounds = triangles .y [triangles .triangles ]
1820
- elif loc in ['edge' , 'face' ]:
1821
- connectivity = get_coord (
1822
- mesh .attrs .get ('%s_node_connectivity' % loc , '' ))
1823
- if connectivity is None :
1824
- raise ValueError (
1825
- "Could not find the connectivity information!" )
1826
- connectivity = connectivity .values
1827
- bounds = vert .values [
1828
- np .where (np .isnan (connectivity ), connectivity [:, :1 ],
1829
- connectivity ).astype (int )]
1830
- else :
1831
- raise ValueError (
1832
- "Could not interprete location attribute (%s) of mesh "
1833
- "variable %s!" % (loc , mesh .name ))
1929
+
1930
+ grid = self .get_ugrid (var , coords )
1931
+
1932
+ faces = grid .faces
1933
+ if np .ma .isMA (faces ) and faces .mask .any ():
1934
+ isnull = faces .mask
1935
+ faces = faces .filled (- 999 ).astype (int )
1936
+ for i in range (faces .shape [1 ]):
1937
+ mask = isnull [:, i ]
1938
+ if mask .any ():
1939
+ for j in range (i , faces .shape [1 ]):
1940
+ faces [mask , j ] = faces [mask , j - i ]
1941
+
1942
+ node = grid .nodes [..., 0 if axis == "x" else 1 ]
1943
+ bounds = node [faces ]
1944
+
1945
+ loc = self .infer_location (var , coords )
1946
+
1834
1947
dim0 = '__face' if loc == 'node' else var .dims [- 1 ]
1835
1948
return xr .DataArray (
1836
1949
bounds ,
1837
1950
coords = {key : val for key , val in coords .items ()
1838
1951
if (dim0 , ) == val .dims },
1839
1952
dims = (dim0 , '__bnds' , ),
1840
- name = vert .name + '_bnds' , attrs = vert .attrs .copy ())
1953
+ name = vert .name + '_bnds' , attrs = vert .attrs .copy ())
1954
+
1955
+ @docstrings .with_indent (8 )
1956
+ def infer_location (self , var , coords = None , grid = None ):
1957
+ """Infer the location for the variable.
1958
+
1959
+ Parameters
1960
+ ----------
1961
+ %(UGridDecoder.get_mesh.parameters)s
1962
+ grid: gridded.pyugrid.ugrid.UGrid
1963
+ The grid for this variable. If None, it will be created using the
1964
+ :meth:`get_ugrid` method (if necessary)
1965
+
1966
+ Returns
1967
+ -------
1968
+ str
1969
+ ``"node"``, ``"face"`` or ``"edge"``
1970
+ """
1971
+ if not var .attrs .get ('location' ):
1972
+ if grid is None :
1973
+ grid = self .get_ugrid (var , coords , loc = "face" )
1974
+ loc = grid .infer_location (var )
1975
+ else :
1976
+ loc = var .attrs ["location" ]
1977
+ return loc
1841
1978
1842
1979
@staticmethod
1843
1980
@docstrings .dedent
@@ -1880,20 +2017,71 @@ def decode_coords(ds, gridfile=None):
1880
2017
ds ._coord_names .update (extra_coords .intersection (ds .variables ))
1881
2018
return ds
1882
2019
1883
- def get_nodes (self , coord , coords ):
2020
+ def get_nodes (self , coord , coords = None ):
1884
2021
"""Get the variables containing the definition of the nodes
1885
2022
1886
2023
Parameters
1887
2024
----------
1888
2025
coord: xarray.Coordinate
1889
2026
The mesh variable
1890
- coords: dict
1891
- The coordinates to use to get node coordinates"""
2027
+ coords: dict, optional
2028
+ The coordinates to use to get node coordinates """
2029
+ if coords is None :
2030
+ coords = {}
1892
2031
def get_coord (coord ):
1893
2032
return coords .get (coord , self .ds .coords .get (coord ))
1894
2033
return list (map (get_coord ,
1895
2034
coord .attrs .get ('node_coordinates' , '' ).split ()[:2 ]))
1896
2035
2036
+ @docstrings .with_indent (8 )
2037
+ def get_xname (self , var , coords = None ):
2038
+ """Get the name of the spatial dimension
2039
+
2040
+ Parameters
2041
+ ----------
2042
+ %(CFDecoder.get_y.parameters)s
2043
+
2044
+ Returns
2045
+ -------
2046
+ str
2047
+ The dimension name
2048
+ """
2049
+
2050
+ def get_dim (name ):
2051
+ coord = coords .get (
2052
+ name , ds .coords .get (name , ds .variables .get (name ))
2053
+ )
2054
+ if coord is None :
2055
+ raise KeyError (f"Missing { loc } coordinate { name } " )
2056
+ else :
2057
+ return coord .dims [0 ]
2058
+
2059
+ ds = self .ds
2060
+ loc = self .infer_location (var , coords )
2061
+ mesh = self .get_mesh (var , coords )
2062
+ coords = coords or ds .coords
2063
+ if loc == "node" :
2064
+ return get_dim (mesh .node_coordinates .split ()[0 ])
2065
+ elif loc == "edge" :
2066
+ return get_dim (mesh .edge_node_connectivity )
2067
+ else :
2068
+ return get_dim (mesh .face_node_connectivity )
2069
+
2070
+ @docstrings .with_indent (8 )
2071
+ def get_yname (self , var , coords = None ):
2072
+ """Get the name of the spatial dimension
2073
+
2074
+ Parameters
2075
+ ----------
2076
+ %(CFDecoder.get_y.parameters)s
2077
+
2078
+ Returns
2079
+ -------
2080
+ str
2081
+ The dimension name
2082
+ """
2083
+ return self .get_xname (var , coords ) # x- and y-dimensions are the same
2084
+
1897
2085
@docstrings .dedent
1898
2086
def get_x (self , var , coords = None ):
1899
2087
"""
@@ -1911,18 +2099,25 @@ def get_x(self, var, coords=None):
1911
2099
# first we try the super class
1912
2100
ret = super (UGridDecoder , self ).get_x (var , coords )
1913
2101
# but if that doesn't work because we get the variable name in the
1914
- # dimension of `var`, we use the means of the triangles
2102
+ # dimension of `var`, we use the means of the faces
1915
2103
if ret is None or ret .name in var .dims or (hasattr (var , 'mesh' ) and
1916
2104
ret .name == var .mesh ):
1917
- bounds = self .get_cell_node_coord (var , axis = 'x' , coords = coords )
1918
- if bounds is not None :
1919
- centers = bounds .mean (axis = - 1 )
1920
- x = self .get_nodes (self .get_mesh (var , coords ), coords )[0 ]
2105
+ loc = self .infer_location (var , coords )
2106
+ x = self .get_nodes (self .get_mesh (var , coords ), coords )[0 ]
2107
+ if loc == "node" :
2108
+ return x
2109
+ else :
2110
+ grid = self .get_ugrid (var , coords , loc )
2111
+ if grid .face_coordinates is None :
2112
+ grid .build_face_coordinates ()
1921
2113
try :
1922
2114
cls = xr .IndexVariable
1923
2115
except AttributeError : # xarray < 0.9
1924
2116
cls = xr .Coordinate
1925
- return cls (x .name , centers , attrs = x .attrs .copy ())
2117
+ return cls (
2118
+ x .name , grid .face_coordinates [..., 1 ],
2119
+ attrs = x .attrs .copy ()
2120
+ )
1926
2121
else :
1927
2122
return ret
1928
2123
@@ -1946,15 +2141,22 @@ def get_y(self, var, coords=None):
1946
2141
# dimension of `var`, we use the means of the triangles
1947
2142
if ret is None or ret .name in var .dims or (hasattr (var , 'mesh' ) and
1948
2143
ret .name == var .mesh ):
1949
- bounds = self .get_cell_node_coord (var , axis = 'y' , coords = coords )
1950
- if bounds is not None :
1951
- centers = bounds .mean (axis = - 1 )
1952
- y = self .get_nodes (self .get_mesh (var , coords ), coords )[1 ]
2144
+ loc = self .infer_location (var , coords )
2145
+ y = self .get_nodes (self .get_mesh (var , coords ), coords )[1 ]
2146
+ if loc == "node" :
2147
+ return y
2148
+ else :
2149
+ grid = self .get_ugrid (var , coords , loc )
2150
+ if grid .face_coordinates is None :
2151
+ grid .build_face_coordinates ()
1953
2152
try :
1954
2153
cls = xr .IndexVariable
1955
2154
except AttributeError : # xarray < 0.9
1956
2155
cls = xr .Coordinate
1957
- return cls (y .name , centers , attrs = y .attrs .copy ())
2156
+ return cls (
2157
+ y .name , grid .face_coordinates [..., 1 ],
2158
+ attrs = y .attrs .copy ()
2159
+ )
1958
2160
else :
1959
2161
return ret
1960
2162
@@ -2818,6 +3020,7 @@ def filter_attrs(item):
2818
3020
dims = self ._new_dims
2819
3021
method = self .method
2820
3022
if dims :
3023
+ self .decoder .clear_cache ()
2821
3024
if VARIABLELABEL in self .arr .coords :
2822
3025
self ._update_concatenated (dims , method )
2823
3026
else :
0 commit comments