Skip to content

Commit c6f827b

Browse files
authored
[python] Fix pytorch imports (#1129)
1 parent 8fb99b3 commit c6f827b

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
from somacore import AxisQuery
1313
from tiledbsoma import Experiment, _factory
1414
from tiledbsoma._collection import CollectionBase
15-
from torch.utils.data._utils.worker import WorkerInfo
1615

1716
# conditionally import torch, as it will not be available in all test environments
1817
try:
1918
from torch import Tensor, float32
19+
from torch.utils.data._utils.worker import WorkerInfo
2020

2121
from cellxgene_census.experimental.ml.pytorch import (
2222
ExperimentDataPipe,
@@ -583,17 +583,15 @@ def test_experiment_dataloader__multiprocess_dense_matrix__ok() -> None:
583583

584584

585585
@pytest.mark.experimental
586-
@patch("cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe")
587-
def test_experiment_dataloader__unsupported_params__fails(
588-
dummy_exp_data_pipe: ExperimentDataPipe,
589-
) -> None:
590-
with pytest.raises(ValueError):
591-
experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
592-
with pytest.raises(ValueError):
593-
experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
594-
with pytest.raises(ValueError):
595-
experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
596-
with pytest.raises(ValueError):
597-
experiment_dataloader(dummy_exp_data_pipe, sampler=[])
598-
with pytest.raises(ValueError):
599-
experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x)
586+
def test_experiment_dataloader__unsupported_params__fails() -> None:
587+
with patch("cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe") as dummy_exp_data_pipe:
588+
with pytest.raises(ValueError):
589+
experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
590+
with pytest.raises(ValueError):
591+
experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
592+
with pytest.raises(ValueError):
593+
experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
594+
with pytest.raises(ValueError):
595+
experiment_dataloader(dummy_exp_data_pipe, sampler=[])
596+
with pytest.raises(ValueError):
597+
experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x)

0 commit comments

Comments
 (0)