
Commit 2b54c5e

Authored by Illviljan, pre-commit-ci[bot], and dcherian
Fix mypy errors in xarray.py, xrutils.py, cache.py (#144)
* update dim typing
* Fix mypy errors in xarray.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* start mypy ci
* Use T_DataArray and T_Dataset
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Add mypy ignores
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* correct typing a bit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* test newer flake8 if ellipsis passes there
* Allow ellipsis in flake8
* Update core.py
* Update xarray.py
* Update setup.cfg
* Update xarray.py
* Update xarray.py
* Update xarray.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update xarray.py
* Update pyproject.toml
* Update xarray.py
* Update xarray.py
* hopefully no more pytest errors.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* make sure expected_groups doesn't have None
* Update flox/xarray.py (Co-authored-by: Deepak Cherian <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* ds_broad and longer comment
* Use same for loop for similar things.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix xrutils.py
* fix errors in cache.py
* Turn off mypy check
* Update flox/xarray.py (Co-authored-by: Deepak Cherian <[email protected]>)
* Update flox/xarray.py (Co-authored-by: Deepak Cherian <[email protected]>)
* Use if else format to avoid tuple creation
* Update xarray.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
1 parent: af3e3ce · commit: 2b54c5e

File tree: 5 files changed, +102 -70 lines


flox/cache.py (+1 -1)
@@ -8,4 +8,4 @@
     cache = cachey.Cache(1e6)
     memoize = partial(cache.memoize, key=dask.base.tokenize)
 except ImportError:
-    memoize = lambda x: x
+    memoize = lambda x: x  # type: ignore
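For context, a minimal sketch of the optional-dependency pattern this file uses; names follow flox/cache.py, but the surrounding imports are reconstructed rather than quoted from the commit:

from functools import partial

try:
    import cachey
    import dask.base

    cache = cachey.Cache(1e6)
    memoize = partial(cache.memoize, key=dask.base.tokenize)
except ImportError:
    # In the try branch mypy infers `memoize` as functools.partial, so
    # re-binding it to a bare lambda is an incompatible assignment;
    # the ignore silences that one diagnostic while keeping the fallback.
    memoize = lambda x: x  # type: ignore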

flox/core.py (+12 -3)
@@ -5,7 +5,16 @@
 import operator
 from collections import namedtuple
 from functools import partial, reduce
-from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Sequence,
+    Union,
+)

 import numpy as np
 import numpy_groupies as npg
@@ -1282,8 +1291,8 @@ def _assert_by_is_aligned(shape, by):


 def _convert_expected_groups_to_index(
-    expected_groups: tuple, isbin: bool, sort: bool
-) -> pd.Index | None:
+    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
+) -> tuple[pd.Index | None]:
     out = []
     for ex, isbin_ in zip(expected_groups, isbin):
         if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
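The signature change widens `_convert_expected_groups_to_index` to take one `isbin` flag per group and return a tuple, so callers can convert a single group and unwrap the 1-tuple. A toy illustration of the new contract (this is the shape of the API, not flox's actual implementation):

from __future__ import annotations

from typing import Iterable, Sequence

import pandas as pd

def convert_to_index(
    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
) -> tuple[pd.Index | None, ...]:
    # One isbin flag per entry; a tuple comes back so a caller can
    # convert a single group via convert_to_index((x,), (flag,), sort)[0].
    out: list[pd.Index | None] = []
    for ex, isbin_ in zip(expected_groups, isbin):
        if ex is None:
            out.append(None)
        elif isbin_:
            out.append(pd.IntervalIndex.from_breaks(ex))
        else:
            out.append(pd.Index(sorted(ex) if sort else ex))
    return tuple(out)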

flox/xarray.py (+86 -65)
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Hashable, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union

 import numpy as np
 import pandas as pd
@@ -19,7 +19,10 @@
 from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

 if TYPE_CHECKING:
-    from xarray import DataArray, Dataset, Resample
+    from xarray.core.resample import Resample
+    from xarray.core.types import T_DataArray, T_Dataset
+
+    Dims = Union[str, Iterable[Hashable], None]


 def _get_input_core_dims(group_names, dim, ds, grouper_dims):
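`T_DataArray` and `T_Dataset` are bound TypeVars from `xarray.core.types`, not concrete classes; switching to them lets mypy tie a function's return type to its argument type. A minimal sketch of the idea (the `first_along` function is illustrative, not from the commit):

from __future__ import annotations

from typing import TYPE_CHECKING, Hashable, Iterable, Union

if TYPE_CHECKING:
    from xarray.core.types import T_DataArray

    Dims = Union[str, Iterable[Hashable], None]

def first_along(obj: T_DataArray, dim: str) -> T_DataArray:
    # With a plain `DataArray` annotation mypy erases subclasses; the
    # TypeVar promises "the same type that came in goes out".
    return obj.isel({dim: 0})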
@@ -51,13 +54,13 @@ def lookup_order(dimension):


 def xarray_reduce(
-    obj: Dataset | DataArray,
-    *by: DataArray | Iterable[str] | Iterable[DataArray],
+    obj: T_Dataset | T_DataArray,
+    *by: T_DataArray | Hashable,
     func: str | Aggregation,
     expected_groups=None,
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
-    dim: Hashable = None,
+    dim: Dims | ellipsis = None,
     split_out: int = 1,
     fill_value=None,
     method: str = "map-reduce",
@@ -203,8 +206,11 @@ def xarray_reduce(
     if keep_attrs is None:
         keep_attrs = True

-    if isinstance(isbin, bool):
-        isbin = (isbin,) * nby
+    if isinstance(isbin, Sequence):
+        isbins = isbin
+    else:
+        isbins = (isbin,) * nby
+
     if expected_groups is None:
         expected_groups = (None,) * nby
     if isinstance(expected_groups, (np.ndarray, list)):  # TODO: test for list
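The rename from `isbin` to `isbins` is itself the mypy fix: rebinding a parameter annotated `bool | Sequence[bool]` to a tuple makes its type unstable across the function body, while a fresh name gets a clean, narrow type. The same pattern in isolation, as a minimal sketch with a hypothetical helper:

from typing import Sequence, Union

def normalize_isbin(isbin: Union[bool, Sequence[bool]], nby: int) -> Sequence[bool]:
    # isinstance() narrows the union: in the first branch mypy knows
    # `isbin` is already a Sequence[bool]; otherwise it must be a bool.
    if isinstance(isbin, Sequence):
        isbins = isbin
    else:
        isbins = (isbin,) * nby
    return isbins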
@@ -217,78 +223,86 @@ def xarray_reduce(
         raise NotImplementedError

     # eventually drop the variables we are grouping by
-    maybe_drop = [b for b in by if isinstance(b, str)]
+    maybe_drop = [b for b in by if isinstance(b, Hashable)]
     unindexed_dims = tuple(
         b
-        for b, isbin_ in zip(by, isbin)
-        if isinstance(b, str) and not isbin_ and b in obj.dims and b not in obj.indexes
+        for b, isbin_ in zip(by, isbins)
+        if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes
     )

-    by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by)  # type: ignore
+    by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by)

     grouper_dims = []
-    for g in by:
+    for g in by_da:
         for d in g.dims:
             if d not in grouper_dims:
                 grouper_dims.append(d)

-    if isinstance(obj, xr.DataArray):
-        ds = obj._to_temp_dataset()
-    else:
+    if isinstance(obj, xr.Dataset):
         ds = obj
+    else:
+        ds = obj._to_temp_dataset()

     ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

     if dim is Ellipsis:
         if nby > 1:
             raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
-        dim = tuple(obj.dims)
-        if by[0].name in ds.dims and not isbin[0]:
-            dim = tuple(d for d in dim if d != by[0].name)
+        name_ = by_da[0].name
+        if name_ in ds.dims and not isbins[0]:
+            dim_tuple = tuple(d for d in obj.dims if d != name_)
+        else:
+            dim_tuple = tuple(obj.dims)
     elif dim is not None:
-        dim = _atleast_1d(dim)
+        dim_tuple = _atleast_1d(dim)
     else:
-        dim = tuple()
+        dim_tuple = tuple()

     # broadcast all variables against each other along all dimensions in `by` variables
     # don't exclude `dim` because it need not be a dimension in any of the `by` variables!
     # in the case where dim is Ellipsis, and by.ndim < obj.ndim
     # then we also broadcast `by` to all `obj.dims`
     # TODO: avoid this broadcasting
-    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim)
-    ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims)
+    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
+    ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)

-    if not dim:
-        dim = tuple(by[0].dims)
+    # all members of by_broad have the same dimensions
+    # so we just pull by_broad[0].dims if dim is None
+    if not dim_tuple:
+        dim_tuple = tuple(by_broad[0].dims)

-    if any(d not in grouper_dims and d not in obj.dims for d in dim):
+    if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
         raise ValueError(f"Cannot reduce over absent dimensions {dim}.")

-    dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
-    if dims_not_in_groupers == tuple(dim) and not any(isbin):
+    dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
+    if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
         # reducing along a dimension along which groups do not vary
         # This is really just a normal reduction.
         # This is not right when binning so we exclude.
-        if skipna and isinstance(func, str):
-            dsfunc = func[3:]
+        if isinstance(func, str):
+            dsfunc = func[3:] if skipna else func
         else:
-            dsfunc = func
+            raise NotImplementedError(
+                "func must be a string when reducing along a dimension not present in `by`"
+            )
         # TODO: skipna needs test
-        result = getattr(ds, dsfunc)(dim=dim, skipna=skipna)
+        result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
         if isinstance(obj, xr.DataArray):
             return obj._from_temp_dataset(result)
         else:
             return result

-    axis = tuple(range(-len(dim), 0))
-    group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin))
-
-    group_shape = [None] * len(by)
-    expected_groups = list(expected_groups)
+    axis = tuple(range(-len(dim_tuple), 0))

     # Set expected_groups and convert to index since we need coords, sizes
     # for output xarray objects
-    for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)):
+    expected_groups = list(expected_groups)
+    group_names: tuple[Any, ...] = ()
+    group_sizes: dict[Any, int] = {}
+    for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
+        group_name = b_.name if not isbin_ else f"{b_.name}_bins"
+        group_names += (group_name,)
+
         if isbin_ and isinstance(expect, int):
             raise NotImplementedError(
                 "flox does not support binning into an integer number of bins yet."
@@ -297,13 +311,21 @@ def xarray_reduce(
         if isbin_:
             raise ValueError(
                 f"Please provided bin edges for group variable {idx} "
-                f"named {group_names[idx]} in expected_groups."
+                f"named {group_name} in expected_groups."
             )
-        expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True)
-
-    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort)
-    group_shape = tuple(len(e) for e in expected_groups)
-    group_sizes = dict(zip(group_names, group_shape))
+            expect_ = _get_expected_groups(b_.data, sort=sort, raise_if_dask=True)
+        else:
+            expect_ = expect
+        expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]
+
+        # The if-check is for type hinting mainly, it narrows down the return
+        # type of _convert_expected_groups_to_index to pure pd.Index:
+        if expect_index is not None:
+            expected_groups[idx] = expect_index
+            group_sizes[group_name] = len(expect_index)
+        else:
+            # This will never be reached
+            raise ValueError("expect_index cannot be None")

     def wrapper(array, *by, func, skipna, **kwargs):
         # Handle skipna here because I need to know dtype to make a good default choice.
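The `expect_index is not None` check above exists purely so mypy narrows `pd.Index | None` to `pd.Index` before `len()` is called; the else branch is asserted unreachable. The same narrowing in isolation, with a hypothetical helper name:

from typing import Optional

import pandas as pd

def require_index(maybe_index: Optional[pd.Index]) -> pd.Index:
    # After this check mypy treats `maybe_index` as a plain pd.Index,
    # so len() and indexing type-check cleanly.
    if maybe_index is not None:
        return maybe_index
    raise ValueError("expect_index cannot be None")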
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs):
     if isinstance(obj, xr.Dataset):
         # broadcasting means the group dim gets added to ds, so we check the original obj
         for k, v in obj.data_vars.items():
-            is_missing_dim = not (any(d in v.dims for d in dim))
+            is_missing_dim = not (any(d in v.dims for d in dim_tuple))
             if is_missing_dim:
                 missing_dim[k] = v

-    input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
+    input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
     input_core_dims += [input_core_dims[-1]] * (nby - 1)

     actual = xr.apply_ufunc(
         wrapper,
-        ds.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
-        *by,
+        ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
+        *by_broad,
         input_core_dims=input_core_dims,
         # for xarray's test_groupby_duplicate_coordinate_labels
-        exclude_dims=set(dim),
+        exclude_dims=set(dim_tuple),
         output_core_dims=[group_names],
         dask="allowed",
         dask_gufunc_kwargs=dict(output_sizes=group_sizes),
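For readers unfamiliar with `xr.apply_ufunc`: the reduced dimensions are declared as input core dims and each group becomes a new output core dim of known size. A self-contained toy call in the same shape (the `fake_groupby` function and dimension names are invented for illustration):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="x")

def fake_groupby(arr):
    # stand-in for flox's wrapper: collapse "x" into two group sums
    return np.stack([arr[:3].sum(), arr[3:].sum()])

actual = xr.apply_ufunc(
    fake_groupby,
    da,
    input_core_dims=[["x"]],        # consumed dimension, moved to the end
    output_core_dims=[["x_bins"]],  # new dimension created per group
)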
@@ -379,27 +401,27 @@ def wrapper(array, *by, func, skipna, **kwargs):
             "engine": engine,
             "reindex": reindex,
             "expected_groups": tuple(expected_groups),
-            "isbin": isbin,
+            "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
         },
     )

     # restore non-dim coord variables without the core dimension
     # TODO: shouldn't apply_ufunc handle this?
-    for var in set(ds.variables) - set(ds.dims):
-        if all(d not in ds[var].dims for d in dim):
-            actual[var] = ds[var]
+    for var in set(ds_broad.variables) - set(ds_broad.dims):
+        if all(d not in ds_broad[var].dims for d in dim_tuple):
+            actual[var] = ds_broad[var]

-    for name, expect, by_ in zip(group_names, expected_groups, by):
+    for name, expect, by_ in zip(group_names, expected_groups, by_broad):
         # Can't remove this till xarray handles IntervalIndex
         if isinstance(expect, pd.IntervalIndex):
             expect = expect.to_numpy()
         if isinstance(actual, xr.Dataset) and name in actual:
             actual = actual.drop_vars(name)
         # When grouping by MultiIndex, expect is an pd.Index wrapping
         # an object array of tuples
-        if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
-            levelnames = ds.indexes[name].names
+        if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex):
+            levelnames = ds_broad.indexes[name].names
             expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
         actual[name] = expect
         if Version(xr.__version__) > Version("2022.03.0"):
@@ -414,18 +436,17 @@ def wrapper(array, *by, func, skipna, **kwargs):

     if nby == 1:
         for var in actual:
-            if isinstance(obj, xr.DataArray):
-                template = obj
-            else:
+            if isinstance(obj, xr.Dataset):
                 template = obj[var]
+            else:
+                template = obj
+
             if actual[var].ndim > 1:
-                actual[var] = _restore_dim_order(actual[var], template, by[0])
+                actual[var] = _restore_dim_order(actual[var], template, by_broad[0])

     if missing_dim:
         for k, v in missing_dim.items():
-            missing_group_dims = {
-                dim: size for dim, size in group_sizes.items() if dim not in v.dims
-            }
+            missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims}
             # The expand_dims is for backward compat with xarray's questionable behaviour
             if missing_group_dims:
                 actual[k] = v.expand_dims(missing_group_dims).variable
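`expand_dims` with a mapping adds the missing group dimension with an explicit size, which is how variables that never had any of the reduced dims are restored. A toy example with invented names:

import numpy as np
import xarray as xr

v = xr.DataArray(np.arange(3.0), dims="y")
missing_group_dims = {"x_bins": 2}

# dims become ("x_bins", "y"); .variable mirrors the assignment above
expanded = v.expand_dims(missing_group_dims).variable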
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs):


 def rechunk_for_cohorts(
-    obj: DataArray | Dataset,
+    obj: T_DataArray | T_Dataset,
     dim: str,
-    labels: DataArray,
+    labels: T_DataArray,
     force_new_chunk_at,
     chunksize: int | None = None,
     ignore_old_chunks: bool = False,
@@ -486,7 +507,7 @@ def rechunk_for_cohorts(
     )


-def rechunk_for_blockwise(obj: DataArray | Dataset, dim: str, labels: DataArray):
+def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
     embarassingly parallel group reductions.

flox/xrutils.py (+1 -1)

@@ -19,7 +19,7 @@

     dask_array_type = dask.array.Array
 except ImportError:
-    dask_array_type = ()
+    dask_array_type = ()  # type: ignore


 def asarray(data, xp=np):
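The empty-tuple fallback works because `isinstance(x, ())` is always `False`; the `# type: ignore` is needed because a tuple clashes with the `type` mypy inferred from the try branch. A sketch of the pattern (the `is_dask_array` helper is illustrative, not quoted from flox):

try:
    import dask.array

    dask_array_type = dask.array.Array
except ImportError:
    dask_array_type = ()  # type: ignore

def is_dask_array(x) -> bool:
    # With dask missing, this is isinstance(x, ()) and always False.
    return isinstance(x, dask_array_type)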

setup.cfg (+2 -0)

@@ -57,3 +57,5 @@ per-file-ignores =
 exclude=
     .eggs
     doc
+builtins =
+    ellipsis
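`ellipsis` (lowercase) is the runtime type of the `...` literal, but it is not an accessible builtin name, so pyflakes flags the new `dim: Dims | ellipsis` annotation in flox/xarray.py as an undefined name (F821). Registering it under flake8's `builtins` option whitelists the name. A quick check of the fact being relied on:

# The annotation only works because `from __future__ import annotations`
# keeps it unevaluated; at runtime the name `ellipsis` is absent.
assert type(...).__name__ == "ellipsis"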
