@@ -1236,7 +1236,7 @@ def test_subset_block_2d(flatblocks, expectidx):
1236
1236
1237
1237
1238
1238
@pytest .mark .parametrize (
1239
- "expected , reindex, func, expected_groups, any_by_dask" ,
1239
+ "dask_expected , reindex, func, expected_groups, any_by_dask" ,
1240
1240
[
1241
1241
# argmax only False
1242
1242
[False , None , "argmax" , None , False ],
@@ -1252,22 +1252,43 @@ def test_subset_block_2d(flatblocks, expectidx):
1252
1252
[True , None , "sum" , ([1 ], None ), True ],
1253
1253
],
1254
1254
)
1255
- def test_validate_reindex_map_reduce (expected , reindex , func , expected_groups , any_by_dask ):
1256
- actual = _validate_reindex (reindex , func , "map-reduce" , expected_groups , any_by_dask )
1257
- assert actual == expected
1255
+ def test_validate_reindex_map_reduce (
1256
+ dask_expected , reindex , func , expected_groups , any_by_dask
1257
+ ) -> None :
1258
+ actual = _validate_reindex (
1259
+ reindex , func , "map-reduce" , expected_groups , any_by_dask , is_dask_array = True
1260
+ )
1261
+ assert actual is dask_expected
1258
1262
1263
+ # always reindex with all numpy inputs
1264
+ actual = _validate_reindex (
1265
+ reindex , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1266
+ )
1267
+ assert actual
1268
+
1269
+ actual = _validate_reindex (
1270
+ True , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1271
+ )
1272
+ assert actual
1259
1273
1260
- def test_validate_reindex ():
1274
+
1275
+ def test_validate_reindex () -> None :
1261
1276
for method in ["map-reduce" , "cohorts" ]:
1262
1277
with pytest .raises (NotImplementedError ):
1263
- _validate_reindex (True , "argmax" , method , expected_groups = None , any_by_dask = False )
1278
+ _validate_reindex (
1279
+ True , "argmax" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1280
+ )
1264
1281
1265
1282
for method in ["blockwise" , "cohorts" ]:
1266
1283
with pytest .raises (ValueError ):
1267
- _validate_reindex (True , "sum" , method , expected_groups = None , any_by_dask = False )
1284
+ _validate_reindex (
1285
+ True , "sum" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1286
+ )
1268
1287
1269
1288
for func in ["sum" , "argmax" ]:
1270
- actual = _validate_reindex (None , func , method , expected_groups = None , any_by_dask = False )
1289
+ actual = _validate_reindex (
1290
+ None , func , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1291
+ )
1271
1292
assert actual is False
1272
1293
1273
1294
0 commit comments