Skip to content

Commit cb3fc1f

Browse files
authored
Add cohorts snapshot tests with syrupy (#379)
* Add cohorts snapshot tests with syrupy * Fix. * fix again * Rework CI * [revery] * improve * fix mypy? * Revert "[revery]" This reverts commit 7664e5e. * Try again * fix mypy
1 parent 4d03d70 commit cb3fc1f

13 files changed

+22362
-50
lines changed

.github/workflows/ci-additional.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ jobs:
128128
129129
- name: Run mypy
130130
run: |
131-
python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
131+
mkdir .mypy_cache
132+
python -m mypy --install-types --non-interactive --cache-dir=.mypy_cache/ --cobertura-xml-report mypy_report
132133
133134
- name: Upload mypy coverage to Codecov
134135
uses: codecov/[email protected]

.github/workflows/ci.yaml

+10-44
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ concurrency:
1616

1717
jobs:
1818
test:
19-
name: Test (${{ matrix.python-version }}, ${{ matrix.os }})
19+
name: Test (${{matrix.env}}, ${{ matrix.python-version }}, ${{ matrix.os }})
2020
runs-on: ${{ matrix.os }}
2121
defaults:
2222
run:
@@ -25,10 +25,18 @@ jobs:
2525
fail-fast: false
2626
matrix:
2727
os: ["ubuntu-latest"]
28+
env: ["environment"]
2829
python-version: ["3.9", "3.12"]
2930
include:
3031
- os: "windows-latest"
32+
env: "environment"
3133
python-version: "3.12"
34+
- os: "ubuntu-latest"
35+
env: "no-dask" # "no-xarray", "no-numba"
36+
python-version: "3.12"
37+
- os: "ubuntu-latest"
38+
env: "minimal-requirements"
39+
python-version: "3.9"
3240
steps:
3341
- uses: actions/checkout@v4
3442
with:
@@ -39,7 +47,7 @@ jobs:
3947
- name: Set up conda environment
4048
uses: mamba-org/setup-micromamba@v1
4149
with:
42-
environment-file: ci/environment.yml
50+
environment-file: ci/${{ matrix.env }}.yml
4351
environment-name: flox-tests
4452
init-shell: bash
4553
cache-environment: true
@@ -81,48 +89,6 @@ jobs:
8189
path: .hypothesis/
8290
key: cache-hypothesis-${{ runner.os }}-${{ matrix.python-version }}-${{ github.run_id }}
8391

84-
optional-deps:
85-
name: ${{ matrix.env }}
86-
runs-on: "ubuntu-latest"
87-
defaults:
88-
run:
89-
shell: bash -l {0}
90-
strategy:
91-
fail-fast: false
92-
matrix:
93-
python-version: ["3.12"]
94-
env: ["no-dask"] # "no-xarray", "no-numba"
95-
include:
96-
- env: "minimal-requirements"
97-
python-version: "3.9"
98-
steps:
99-
- uses: actions/checkout@v4
100-
with:
101-
fetch-depth: 0 # Fetch all history for all branches and tags.
102-
- name: Set up conda environment
103-
uses: mamba-org/setup-micromamba@v1
104-
with:
105-
environment-file: ci/${{ matrix.env }}.yml
106-
environment-name: flox-tests
107-
init-shell: bash
108-
cache-environment: true
109-
create-args: |
110-
python=${{ matrix.python-version }}
111-
- name: Install flox
112-
run: |
113-
python -m pip install --no-deps -e .
114-
- name: Run tests
115-
run: |
116-
python -m pytest -n auto --cov=./ --cov-report=xml
117-
- name: Upload code coverage to Codecov
118-
uses: codecov/[email protected]
119-
with:
120-
file: ./coverage.xml
121-
flags: unittests
122-
env_vars: RUNNER_OS
123-
name: codecov-umbrella
124-
fail_ci_if_error: false
125-
12692
xarray-groupby:
12793
name: xarray-groupby
12894
runs-on: ubuntu-latest

asv_bench/__init__.py

Whitespace-only changes.

ci/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies:
1818
- pytest-cov
1919
- pytest-pretty
2020
- pytest-xdist
21+
- syrupy
2122
- xarray
2223
- pre-commit
2324
- numpy_groupies>=0.9.19

ci/minimal-requirements.yml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- pytest-cov
1010
- pytest-pretty
1111
- pytest-xdist
12+
- syrupy
1213
- numpy==1.22
1314
- scipy==1.9.0
1415
- numpy_groupies==0.9.19

ci/no-dask.yml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-cov
1414
- pytest-pretty
1515
- pytest-xdist
16+
- syrupy
1617
- xarray
1718
- numpydoc
1819
- pre-commit

ci/no-numba.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies:
1818
- pytest-cov
1919
- pytest-pretty
2020
- pytest-xdist
21+
- syrupy
2122
- xarray
2223
- pre-commit
2324
- numpy_groupies>=0.9.19

ci/no-xarray.yml

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ channels:
33
- conda-forge
44
dependencies:
55
- codecov
6+
- syrupy
67
- pandas
78
- numpy>=1.22
89
- scipy
@@ -11,6 +12,7 @@ dependencies:
1112
- pytest-cov
1213
- pytest-pretty
1314
- pytest-xdist
15+
- syrupy
1416
- dask-core
1517
- numpydoc
1618
- pre-commit

ci/upstream-dev-env.yml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
# - scipy
1212
- pytest-pretty
1313
- pytest-xdist
14+
- syrupy
1415
- pip
1516
# for cftime
1617
- cython>=0.29.20

flox/core.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,9 @@ def find_group_cohorts(
394394
chunks_per_label = chunks_per_label[present_labels_mask]
395395

396396
label_chunks = {
397-
present_labels[idx]: bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
397+
present_labels[idx].item(): bitmask.indices[
398+
slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])
399+
]
398400
for idx in range(bitmask.shape[LABEL_AXIS])
399401
}
400402

@@ -485,7 +487,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
485487

486488
# Iterate over labels, beginning with those with most chunks
487489
logger.debug("find_group_cohorts: merging cohorts")
488-
order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
490+
order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
489491
merged_cohorts = {}
490492
merged_keys = set()
491493
# TODO: we can optimize this to loop over chunk_cohorts instead
@@ -495,11 +497,11 @@ def invert(x) -> tuple[np.ndarray, ...]:
495497
slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
496498
]
497499
cohort_ = present_labels[cohidx]
498-
cohort = [elem for elem in cohort_ if elem not in merged_keys]
500+
cohort = [elem.item() for elem in cohort_ if elem not in merged_keys]
499501
if not cohort:
500502
continue
501503
merged_keys.update(cohort)
502-
allchunks = (label_chunks[member] for member in cohort)
504+
allchunks = (label_chunks[member].tolist() for member in cohort)
503505
chunk = tuple(set(itertools.chain(*allchunks)))
504506
merged_cohorts[chunk] = cohort
505507

0 commit comments

Comments
 (0)