|
12 | 12 | from somacore import AxisQuery
|
13 | 13 | from tiledbsoma import Experiment, _factory
|
14 | 14 | from tiledbsoma._collection import CollectionBase
|
15 |
| -from torch.utils.data._utils.worker import WorkerInfo |
16 | 15 |
|
17 | 16 | # conditionally import torch, as it will not be available in all test environments
|
18 | 17 | try:
|
19 | 18 | from torch import Tensor, float32
|
| 19 | + from torch.utils.data._utils.worker import WorkerInfo |
20 | 20 |
|
21 | 21 | from cellxgene_census.experimental.ml.pytorch import (
|
22 | 22 | ExperimentDataPipe,
|
@@ -583,17 +583,15 @@ def test_experiment_dataloader__multiprocess_dense_matrix__ok() -> None:
|
583 | 583 |
|
584 | 584 |
|
585 | 585 | @pytest.mark.experimental
|
586 |
| -@patch("cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe") |
587 |
| -def test_experiment_dataloader__unsupported_params__fails( |
588 |
| - dummy_exp_data_pipe: ExperimentDataPipe, |
589 |
| -) -> None: |
590 |
| - with pytest.raises(ValueError): |
591 |
| - experiment_dataloader(dummy_exp_data_pipe, shuffle=True) |
592 |
| - with pytest.raises(ValueError): |
593 |
| - experiment_dataloader(dummy_exp_data_pipe, batch_size=3) |
594 |
| - with pytest.raises(ValueError): |
595 |
| - experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) |
596 |
| - with pytest.raises(ValueError): |
597 |
| - experiment_dataloader(dummy_exp_data_pipe, sampler=[]) |
598 |
| - with pytest.raises(ValueError): |
599 |
| - experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x) |
| 586 | +def test_experiment_dataloader__unsupported_params__fails() -> None: |
| 587 | + with patch("cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe") as dummy_exp_data_pipe: |
| 588 | + with pytest.raises(ValueError): |
| 589 | + experiment_dataloader(dummy_exp_data_pipe, shuffle=True) |
| 590 | + with pytest.raises(ValueError): |
| 591 | + experiment_dataloader(dummy_exp_data_pipe, batch_size=3) |
| 592 | + with pytest.raises(ValueError): |
| 593 | + experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) |
| 594 | + with pytest.raises(ValueError): |
| 595 | + experiment_dataloader(dummy_exp_data_pipe, sampler=[]) |
| 596 | + with pytest.raises(ValueError): |
| 597 | + experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x) |
0 commit comments