Skip to content

Commit e26fb39

Browse files
authored
numpy 2.0 compat (#311)
* numpy 2.0 compat * Update upstream-dev-ci * Update again * Fix env building * Update type-ignore * Add mamba * Better version check * Remove numbagg from upstream * Allow deps * add dateutil * add list deps * mamba -> micromamba * remove env cache * change order * remove * update env * Try again * Try building cftime * Try again * Small updates * Remove netCDF4 * fix type ignore
1 parent 26a9541 commit e26fb39

10 files changed

+103
-27
lines changed

.github/workflows/upstream-dev-ci.yaml

+38-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ on:
77
types: [opened, reopened, synchronize, labeled]
88
branches:
99
- main
10+
paths:
11+
- ".github/workflows/upstream-dev-ci.yaml"
12+
- "ci/upstream-dev-env.yml"
1013
schedule:
1114
- cron: "0 0 * * *" # Daily “At 00:00” UTC
1215
workflow_dispatch: # allows you to trigger the workflow run manually
@@ -41,16 +44,49 @@ jobs:
4144
- name: Set up conda environment
4245
uses: mamba-org/setup-micromamba@v1
4346
with:
44-
environment-file: ci/upstream-dev-env.yml
4547
environment-name: flox-tests
4648
init-shell: bash
47-
cache-environment: true
49+
# cache-environment: true
50+
# micromamba list does not list pip dependencies, so install mamba
4851
create-args: >-
52+
mamba
53+
pip
4954
python=${{ matrix.python-version }}
5055
pytest-reportlog
56+
57+
- name: Install upstream dev dependencies
58+
run: |
59+
# install cython for building cftime without build isolation
60+
micromamba install -f ci/upstream-dev-env.yml
61+
micromamba remove --force numpy scipy pandas cftime
62+
python -m pip install \
63+
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
64+
--no-deps \
65+
--pre \
66+
--upgrade \
67+
numpy \
68+
scipy \
69+
pandas \
70+
xarray
71+
# without build isolation for packages compiling against numpy
72+
# TODO: remove once there are `numpy>=2.0` builds for cftime
73+
python -m pip install \
74+
--no-deps \
75+
--upgrade \
76+
--no-build-isolation \
77+
git+https://github.com/Unidata/cftime
78+
python -m pip install \
79+
git+https://github.com/dask/dask \
80+
git+https://github.com/ml31415/numpy-groupies
81+
5182
- name: Install flox
5283
run: |
5384
python -m pip install --no-deps -e .
85+
86+
- name: List deps
87+
run: |
88+
# micromamba list does not list pip dependencies
89+
mamba list
5490
- name: Run Tests
5591
if: success()
5692
id: status

ci/environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ channels:
44
dependencies:
55
- asv
66
- cachey
7+
- cftime
78
- codecov
89
- dask-core
9-
- netcdf4
1010
- pandas
1111
- numpy>=1.22
1212
- scipy

ci/minimal-requirements.yml

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ channels:
33
- conda-forge
44
dependencies:
55
- codecov
6-
- netcdf4
76
- pip
87
- pytest
98
- pytest-cov

ci/no-dask.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ channels:
33
- conda-forge
44
dependencies:
55
- codecov
6-
- netcdf4
76
- pandas
7+
- cftime
88
- numpy>=1.22
99
- scipy
1010
- pip

ci/no-numba.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ channels:
44
dependencies:
55
- asv
66
- cachey
7+
- cftime
78
- codecov
89
- dask-core
9-
- netcdf4
1010
- pandas
1111
- numpy>=1.22
1212
- scipy

ci/no-xarray.yml

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ channels:
33
- conda-forge
44
dependencies:
55
- codecov
6-
- netcdf4
76
- pandas
87
- numpy>=1.22
98
- scipy

ci/upstream-dev-env.yml

+17-11
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,25 @@ channels:
44
dependencies:
55
- cachey
66
- codecov
7-
- netcdf4
87
- pooch
98
- toolz
10-
- numba
11-
- scipy
12-
- pytest
13-
- pytest-cov
9+
# - numpy
10+
# - pandas
11+
# - scipy
1412
- pytest-pretty
1513
- pytest-xdist
1614
- pip
17-
- pip:
18-
- git+https://github.com/pydata/xarray
19-
- git+https://github.com/pandas-dev/pandas
20-
- git+https://github.com/dask/dask
21-
- git+https://github.com/ml31415/numpy-groupies
22-
- git+https://github.com/numbagg/numbagg
15+
# for cftime
16+
- cython>=0.29.20
17+
- py-cpuinfo
18+
# - numba
19+
- pytest
20+
- pytest-cov
21+
# for upstream pandas
22+
- python-dateutil
23+
- pytz
24+
# - pip:
25+
# - git+https://github.com/pydata/xarray
26+
# - git+https://github.com/dask/dask
27+
# - git+https://github.com/ml31415/numpy-groupies
28+
# # - git+https://github.com/numbagg/numbagg

flox/core.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@
3838
from .cache import memoize
3939
from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
4040

41+
if module_available("numpy", minversion="2.0.0"):
42+
from numpy.lib.array_utils import ( # type: ignore[import-not-found]
43+
normalize_axis_tuple,
44+
)
45+
else:
46+
from numpy.core.numeric import normalize_axis_tuple # type: ignore[attr-defined]
47+
4148
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
4249

4350
if TYPE_CHECKING:
@@ -2179,8 +2186,7 @@ def groupby_reduce(
21792186
if axis is None:
21802187
axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
21812188
else:
2182-
# TODO: How come this function doesn't exist according to mypy?
2183-
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore[attr-defined]
2189+
axis_ = normalize_axis_tuple(axis, array.ndim)
21842190
nax = len(axis_)
21852191

21862192
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)

tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def LooseVersion(vstring):
4545
return packaging.version.Version(vstring)
4646

4747

48+
has_cftime, requires_cftime = _importorskip("cftime")
4849
has_dask, requires_dask = _importorskip("dask")
4950
has_numba, requires_numba = _importorskip("numba")
5051
has_numbagg, requires_numbagg = _importorskip("numbagg")

tests/test_xarray.py

+36-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99
from flox.xarray import rechunk_for_blockwise, xarray_reduce
1010

11-
from . import assert_equal, has_dask, raise_if_dask_computes, requires_dask
11+
from . import (
12+
assert_equal,
13+
has_dask,
14+
raise_if_dask_computes,
15+
requires_cftime,
16+
requires_dask,
17+
)
1218

1319
if has_dask:
1420
import dask
@@ -178,10 +184,18 @@ def test_validate_expected_groups(expected_groups):
178184
)
179185

180186

187+
@requires_cftime
181188
@requires_dask
182189
def test_xarray_reduce_single_grouper(engine):
183190
# DataArray
184-
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 9})
191+
ds = xr.Dataset(
192+
{"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
193+
coords={
194+
"time": xr.date_range(
195+
"1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
196+
)
197+
},
198+
)
185199
actual = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean", engine=engine)
186200
expected = ds.Tair.groupby("time.month").mean()
187201
xr.testing.assert_allclose(actual, expected)
@@ -355,7 +369,14 @@ def test_xarray_groupby_bins(chunks, engine):
355369
def test_func_is_aggregation():
356370
from flox.aggregations import mean
357371

358-
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 9})
372+
ds = xr.Dataset(
373+
{"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
374+
coords={
375+
"time": xr.date_range(
376+
"1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
377+
)
378+
},
379+
)
359380
expected = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean")
360381
actual = xarray_reduce(ds.Tair, ds.time.dt.month, func=mean)
361382
xr.testing.assert_allclose(actual, expected)
@@ -392,10 +413,18 @@ def test_func_is_aggregation():
392413
@requires_dask
393414
@pytest.mark.parametrize("method", ["cohorts", "map-reduce"])
394415
def test_groupby_bins_indexed_coordinate(method):
395-
ds = (
396-
xr.tutorial.open_dataset("air_temperature")
397-
.isel(time=slice(100))
398-
.chunk({"time": 20, "lat": 5})
416+
ds = xr.Dataset(
417+
{
418+
"air": (
419+
("time", "lat", "lon"),
420+
dask.array.random.random((125, 25, 53), chunks=(20, 5, -1)),
421+
)
422+
},
423+
coords={
424+
"time": pd.date_range("2013-01-01", "2013-02-01", freq="6H"),
425+
"lat": np.arange(75.0, 14.9, -2.5),
426+
"lon": np.arange(200.0, 331.0, 2.5),
427+
},
399428
)
400429
bins = [40, 50, 60, 70]
401430
expected = ds.groupby_bins("lat", bins=bins).mean(keep_attrs=True, dim=...)

0 commit comments

Comments
 (0)