Skip to content

Commit 8b1434d

Browse files
authored
Suppress warnings (#197)
* Suppress warnings Closes #188 * Silence a bunch of dask warnings * Fix mypy
1 parent f44738e commit 8b1434d

File tree

5 files changed

+62
-19
lines changed

5 files changed

+62
-19
lines changed

flox/aggregations.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import copy
4+
import warnings
45
from functools import partial
56

67
import numpy as np
@@ -48,9 +49,12 @@ def generic_aggregate(
4849

4950
group_idx = np.asarray(group_idx, like=array)
5051

51-
return method(
52-
group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
53-
)
52+
with warnings.catch_warnings():
53+
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
54+
result = method(
55+
group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
56+
)
57+
return result
5458

5559

5660
def _normalize_dtype(dtype, array_dtype, fill_value=None):
@@ -243,11 +247,18 @@ def __repr__(self):
243247
fill_value=1,
244248
final_fill_value=dtypes.NA,
245249
)
250+
251+
252+
def _mean_finalize(sum_, count):
253+
with np.errstate(invalid="ignore", divide="ignore"):
254+
return sum_ / count
255+
256+
246257
mean = Aggregation(
247258
"mean",
248259
chunk=("sum", "nanlen"),
249260
combine=("sum", "sum"),
250-
finalize=lambda sum_, count: sum_ / count,
261+
finalize=_mean_finalize,
251262
fill_value=(0, 0),
252263
dtypes=(None, np.intp),
253264
final_dtype=np.floating,
@@ -256,7 +267,7 @@ def __repr__(self):
256267
"nanmean",
257268
chunk=("nansum", "nanlen"),
258269
combine=("sum", "sum"),
259-
finalize=lambda sum_, count: sum_ / count,
270+
finalize=_mean_finalize,
260271
fill_value=(0, 0),
261272
dtypes=(None, np.intp),
262273
final_dtype=np.floating,
@@ -265,7 +276,8 @@ def __repr__(self):
265276

266277
# TODO: fix this for complex numbers
267278
def _var_finalize(sumsq, sum_, count, ddof=0):
268-
result = (sumsq - (sum_**2 / count)) / (count - ddof)
279+
with np.errstate(invalid="ignore", divide="ignore"):
280+
result = (sumsq - (sum_**2 / count)) / (count - ddof)
269281
result[count <= ddof] = np.nan
270282
return result
271283

@@ -352,6 +364,10 @@ def _zip_index(array_, idx_):
352364
)
353365

354366

367+
def _pick_second(*x):
368+
return x[1]
369+
370+
355371
argmax = Aggregation(
356372
"argmax",
357373
preprocess=argreduce_preprocess,
@@ -360,7 +376,7 @@ def _zip_index(array_, idx_):
360376
reduction_type="argreduce",
361377
fill_value=(dtypes.NINF, 0),
362378
final_fill_value=-1,
363-
finalize=lambda *x: x[1],
379+
finalize=_pick_second,
364380
dtypes=(None, np.intp),
365381
final_dtype=np.intp,
366382
)
@@ -373,7 +389,7 @@ def _zip_index(array_, idx_):
373389
reduction_type="argreduce",
374390
fill_value=(dtypes.INF, 0),
375391
final_fill_value=-1,
376-
finalize=lambda *x: x[1],
392+
finalize=_pick_second,
377393
dtypes=(None, np.intp),
378394
final_dtype=np.intp,
379395
)
@@ -386,7 +402,7 @@ def _zip_index(array_, idx_):
386402
reduction_type="argreduce",
387403
fill_value=(dtypes.NINF, -1),
388404
final_fill_value=-1,
389-
finalize=lambda *x: x[1],
405+
finalize=_pick_second,
390406
dtypes=(None, np.intp),
391407
final_dtype=np.intp,
392408
)
@@ -399,7 +415,7 @@ def _zip_index(array_, idx_):
399415
reduction_type="argreduce",
400416
fill_value=(dtypes.INF, -1),
401417
final_fill_value=-1,
402-
finalize=lambda *x: x[1],
418+
finalize=_pick_second,
403419
dtypes=(None, np.intp),
404420
final_dtype=np.intp,
405421
)

flox/core.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import itertools
55
import math
66
import operator
7+
import warnings
78
from collections import namedtuple
89
from functools import partial, reduce
910
from numbers import Integral
@@ -881,7 +882,9 @@ def _simple_combine(
881882
for idx, combine in enumerate(agg.combine):
882883
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_)
883884
assert array.ndim >= 2
884-
result = getattr(np, combine)(array, axis=axis_, keepdims=True)
885+
with warnings.catch_warnings():
886+
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
887+
result = getattr(np, combine)(array, axis=axis_, keepdims=True)
885888
if is_aggregate:
886889
# squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
887890
result = result.squeeze(axis=DUMMY_AXIS)

flox/xarray.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def xarray_reduce(
223223
raise NotImplementedError("sort must be True for xarray_reduce")
224224

225225
# eventually drop the variables we are grouping by
226-
maybe_drop = [b for b in by if isinstance(b, Hashable)]
226+
maybe_drop = set(b for b in by if isinstance(b, Hashable))
227227
unindexed_dims = tuple(
228228
b
229229
for b, isbin_ in zip(by, isbins)
@@ -243,6 +243,20 @@ def xarray_reduce(
243243
else:
244244
ds = obj._to_temp_dataset()
245245

246+
try:
247+
from xarray.indexes import PandasMultiIndex
248+
except ImportError:
249+
PandasMultiIndex = tuple() # type: ignore
250+
251+
more_drop = set()
252+
for var in maybe_drop:
253+
maybe_midx = ds._indexes.get(var, None)
254+
if isinstance(maybe_midx, PandasMultiIndex):
255+
idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim])
256+
idx_other_names = idx_coord_names - set(maybe_drop)
257+
more_drop.update(idx_other_names)
258+
maybe_drop.update(more_drop)
259+
246260
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
247261

248262
if dim is Ellipsis:

tests/test_core.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import itertools
4+
import warnings
45
from functools import partial, reduce
56
from typing import TYPE_CHECKING
67

@@ -204,11 +205,18 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
204205
for kwargs in finalize_kwargs:
205206
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
206207
with np.errstate(invalid="ignore", divide="ignore"):
207-
if "arg" in func and add_nan_by:
208-
array[..., nanmask] = np.nan
209-
expected = getattr(np, "nan" + func)(array, axis=-1, **kwargs)
210-
else:
211-
expected = getattr(np, func)(array[..., ~nanmask], axis=-1, **kwargs)
208+
with warnings.catch_warnings():
209+
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
210+
warnings.filterwarnings("ignore", r"Degrees of freedom <= 0 for slice")
211+
warnings.filterwarnings("ignore", r"Mean of empty slice")
212+
213+
# computing silences a bunch of dask warnings
214+
array_ = array.compute() if chunks is not None else array
215+
if "arg" in func and add_nan_by:
216+
array_[..., nanmask] = np.nan
217+
expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
218+
else:
219+
expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs)
212220
for _ in range(nby):
213221
expected = np.expand_dims(expected, -1)
214222

tests/test_xarray.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def test_xarray_resample_dataset_multiple_arrays(engine):
286286
# The separate computes are necessary here to force xarray
287287
# to compute all variables in result at the same time.
288288
expected = resampler.mean().compute()
289-
result = resample_reduce(resampler, "mean", engine=engine).compute()
289+
with pytest.warns(DeprecationWarning):
290+
result = resample_reduce(resampler, "mean", engine=engine).compute()
290291
xr.testing.assert_allclose(expected, result)
291292

292293

@@ -450,7 +451,8 @@ def test_datetime_array_reduce(use_cftime, func, engine):
450451
name="time",
451452
)
452453
expected = getattr(time.resample(time="YS"), func)()
453-
actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine)
454+
with pytest.warns(DeprecationWarning):
455+
actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine)
454456
assert_equal(expected, actual)
455457

456458

0 commit comments

Comments
 (0)