Skip to content

Commit d5d3e19

Browse files
ocefpaf authored and shoyer committed
Add a filter_by_attrs method to Dataset (#844)
* get_variables_by_attributes * tests * Changelog entry * Review actions * Improved whats_new * Changing to data variables only * Return an empty Dataset * Review actions * More review actions * Fix docstring indentation * Create an example dataset * api docs * Rename to filter_by_attrs
1 parent a128e27 commit d5d3e19

File tree

4 files changed

+125
-0
lines changed

4 files changed

+125
-0
lines changed

doc/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ Dataset methods
395395
Dataset.close
396396
Dataset.load
397397
Dataset.chunk
398+
Dataset.filter_by_attrs
398399

399400
DataArray methods
400401
-----------------

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ Enhancements
8080
allowing more control on the colorbar (:issue:`872`).
8181
By `Fabien Maussion <https://github.com/fmaussion>`_.
8282

83+
- New Dataset method :py:meth:`filter_by_attrs`, akin to
84+
``netCDF4.Dataset.get_variables_by_attributes``, to easily filter
85+
data variables by their attributes.
86+
By `Filipe Fernandes <https://github.com/ocefpaf>`_.
87+
8388
Bug fixes
8489
~~~~~~~~~
8590

xarray/core/dataset.py

+83
Original file line numberDiff line numberDiff line change
@@ -2176,5 +2176,88 @@ def real(self):
21762176
def imag(self):
21772177
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)
21782178

2179+
def filter_by_attrs(self, **kwargs):
    """Returns a ``Dataset`` with variables that match specific conditions.

    Can pass in ``key=value`` or ``key=callable``. A data variable is
    returned only if it satisfies *all* of the given conditions, i.e. for
    every key either the attribute equals the value or the callable
    returns True. If using a callable note that it should accept a single
    parameter only, the attribute value.

    Parameters
    ----------
    **kwargs : key=value
        key : str
            Attribute name.
        value : callable or obj
            If value is a callable, it should return a boolean in the form
            of bool = func(attr) where attr is da.attrs.get(key), so it
            receives ``None`` when the attribute is missing.
            Otherwise, value will be compared to each
            DataArray's attrs.get(key).

    Returns
    -------
    new : Dataset
        New dataset with variables filtered by attribute, obtained by
        indexing this dataset with the list of matching data variable
        names. If no conditions are given, no data variable is selected.

    Examples
    --------
    >>> # Create an example dataset:
    >>> import numpy as np
    >>> import pandas as pd
    >>> import xarray as xr
    >>> temp = 15 + 8 * np.random.randn(2, 2, 3)
    >>> precip = 10 * np.random.rand(2, 2, 3)
    >>> lon = [[-99.83, -99.32], [-99.79, -99.23]]
    >>> lat = [[42.25, 42.21], [42.63, 42.59]]
    >>> dims = ['x', 'y', 'time']
    >>> temp_attr = dict(standard_name='air_potential_temperature')
    >>> precip_attr = dict(standard_name='convective_precipitation_flux')
    >>> ds = xr.Dataset({
    ...         'temperature': (dims, temp, temp_attr),
    ...         'precipitation': (dims, precip, precip_attr)},
    ...     coords={
    ...         'lon': (['x', 'y'], lon),
    ...         'lat': (['x', 'y'], lat),
    ...         'time': pd.date_range('2014-09-06', periods=3),
    ...         'reference_time': pd.Timestamp('2014-09-05')})
    >>> # Get variables matching a specific standard_name.
    >>> ds.filter_by_attrs(standard_name='convective_precipitation_flux')
    <xarray.Dataset>
    Dimensions:         (time: 3, x: 2, y: 2)
    Coordinates:
      * x               (x) int64 0 1
      * time            (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08
        lat             (x, y) float64 42.25 42.21 42.63 42.59
      * y               (y) int64 0 1
        reference_time  datetime64[ns] 2014-09-05
        lon             (x, y) float64 -99.83 -99.32 -99.79 -99.23
    Data variables:
        precipitation   (x, y, time) float64 4.178 2.307 6.041 6.046 0.06648 ...
    >>> # Get all variables that have a standard_name attribute.
    >>> standard_name = lambda v: v is not None
    >>> ds.filter_by_attrs(standard_name=standard_name)
    <xarray.Dataset>
    Dimensions:         (time: 3, x: 2, y: 2)
    Coordinates:
        lon             (x, y) float64 -99.83 -99.32 -99.79 -99.23
        lat             (x, y) float64 42.25 42.21 42.63 42.59
      * x               (x) int64 0 1
      * y               (y) int64 0 1
      * time            (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08
        reference_time  datetime64[ns] 2014-09-05
    Data variables:
        temperature     (x, y, time) float64 25.86 20.82 6.954 23.13 10.25 11.68 ...
        precipitation   (x, y, time) float64 5.702 0.9422 2.075 1.178 3.284 ...

    """
    selection = []
    for var_name, variable in self.data_vars.items():
        # Start from False when there are no conditions so that a call
        # without kwargs keeps selecting nothing.
        matches = bool(kwargs)
        for attr_name, pattern in kwargs.items():
            # A missing attribute is seen as None by both the callable and
            # the equality comparison.
            attr_value = variable.attrs.get(attr_name)
            if not ((callable(pattern) and pattern(attr_value))
                    or attr_value == pattern):
                # One failed condition rules the variable out; the
                # remaining conditions need not be evaluated.
                matches = False
                break
        if matches:
            # Appended at most once per variable, so the selection never
            # contains duplicate names.
            selection.append(var_name)
    return self[selection]
21792262

21802263
ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)

xarray/test/test_dataset.py

+36
Original file line numberDiff line numberDiff line change
@@ -2494,3 +2494,39 @@ def test_setattr_raises(self):
24942494
ds.foo = 2
24952495
with self.assertRaisesRegexp(AttributeError, 'cannot set attr'):
24962496
ds.other = 2
2497+
2498+
def test_filter_by_attrs(self):
    """Check Dataset.filter_by_attrs with value and callable conditions."""
    precip = dict(standard_name='convective_precipitation_flux')
    temp0 = dict(standard_name='air_potential_temperature', height='0 m')
    temp10 = dict(standard_name='air_potential_temperature', height='10 m')
    ds = Dataset({'temperature_0': (['t'], [0], temp0),
                  'temperature_10': (['t'], [0], temp10),
                  'precipitation': (['t'], [0], precip)},
                 coords={'time': (['t'], [0], dict(axis='T'))})

    # Test return empty Dataset.
    new_ds = ds.filter_by_attrs(standard_name='invalid_standard_name')
    self.assertFalse(bool(new_ds.data_vars))

    # Test return one DataArray.
    new_ds = ds.filter_by_attrs(
        standard_name='convective_precipitation_flux')
    self.assertEqual(new_ds['precipitation'].standard_name,
                     'convective_precipitation_flux')
    self.assertDatasetEqual(new_ds['precipitation'], ds['precipitation'])

    # Test return more than one DataArray.
    new_ds = ds.filter_by_attrs(standard_name='air_potential_temperature')
    self.assertEqual(len(new_ds.data_vars), 2)
    for var in new_ds.data_vars:
        self.assertEqual(new_ds[var].standard_name,
                         'air_potential_temperature')

    # Test callable: only the two temperature variables define ``height``.
    new_ds = ds.filter_by_attrs(height=lambda v: v is not None)
    self.assertEqual(len(new_ds.data_vars), 2)
    for var in new_ds.data_vars:
        self.assertEqual(new_ds[var].standard_name,
                         'air_potential_temperature')

    # Test a plain value on an attribute that only some variables define.
    new_ds = ds.filter_by_attrs(height='10 m')
    self.assertEqual(len(new_ds.data_vars), 1)
    for var in new_ds.data_vars:
        self.assertEqual(new_ds[var].height, '10 m')

0 commit comments

Comments
 (0)