
PyroModelGuideWarmup fails on GPU - probably needs to be run manually before trainer.fit() #2616

@vitkl


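PyroModelGuideWarmup fails on GPU, probably because recent PyTorch Lightning versions (2.1.4 here) call Callback.setup() in the accelerator environment (see _call_setup_hook in the traceback below). The warmup minibatch therefore ends up on cuda:0, while the model's own tensors (here linear.weight, sampled from its prior) are still on the CPU, so F.linear fails with a device mismatch.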
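One possible direction for a fix, sketched under assumptions: instead of relying on the device Lightning reports during setup, move the warmup batch to wherever the module's own parameters or buffers currently live (the CPU at this point, according to the error). The helper names infer_module_device and move_batch are hypothetical and not part of scvi-tools or Lightning. With a torch.nn.Module, parameters() and buffers() are the standard calls. The code is duck-typed and does not import torch itself.

```python
from itertools import chain


def infer_module_device(module, default="cpu"):
    """Return the device of the first parameter or buffer of ``module``.

    Hypothetical helper: while Lightning runs Callback.setup() the module has
    not necessarily been moved yet, so its own tensors show where a warmup
    batch has to live. Falls back to ``default`` for tensor-free modules.
    """
    for tensor in chain(module.parameters(), module.buffers()):
        return tensor.device
    return default


def move_batch(tensors, device):
    """Move every tensor in a batch dict to ``device`` (hypothetical helper)."""
    return {key: value.to(device) for key, value in tensors.items()}
```

Inside the callback, the batch would then be moved with move_batch(tensors, infer_module_device(pl_module.module)) before the guide is called. This assumes pl_module.module is the Pyro module, as with the LowLevelPyroTrainingPlan in the failing test. The sketch leaves one question open: is this the right fix, or should the warmup instead run after Lightning has moved the model?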

This test fails on GPU:

pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
(cell2state_cuda118_torch22) vk7@farm22-gpu0203:.../software/tests/scvi-tools$ pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.13, pytest-8.1.1, pluggy-1.4.0
rootdir: .../software/tests/scvi-tools
configfile: pyproject.toml
plugins: cov-4.1.0, anyio-4.3.0
collected 1 item                                                                                                                                          

tests/model/test_pyro.py F                                                                                                                          [100%]

======================================================================== FAILURES =========================================================================
_________________________________________________________ test_pyro_bayesian_regression_low_level _________________________________________________________

self = BayesianRegressionPyroModel(
  (linear): PyroLinear(in_features=100, out_features=1, bias=True)
)
x = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')
y = tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
 ...      [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]], device='cuda:0')
ind_x = tensor([ 29, 272, 379, 251, 149, 339, 147, 137, 197, 275, 139, 323, 365, 322,
        362,  59,  99, 281, 397,  31,  7... 301,
         92, 378, 221, 280, 349,  46,  83, 222,  48, 180, 279, 395,  53,  87,
        386,   7], device='cuda:0')

    def forward(self, x, y, ind_x):
        obs_plate = self.create_plates(x, y, ind_x)
    
        sigma = pyro.sample("sigma", dist.Exponential(self.one))
    
>       mean = self.linear(x).squeeze(-1)

tests/model/test_pyro.py:98: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')

    def forward(self, input: Tensor) -> Tensor:
>       return F.linear(input, self.weight, self.bias)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError

The above exception was the direct cause of the following exception:

accelerator = 'gpu', devices = 'auto'

    def test_pyro_bayesian_regression_low_level(
        accelerator: str,
        devices: list | str | int,
    ):
        adata = synthetic_iid()
        adata_manager = _create_indices_adata_manager(adata)
        train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
        pyro.clear_param_store()
        model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
        plan = LowLevelPyroTrainingPlan(model)
        plan.n_obs_training = len(train_dl.indices)
        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=2,
            callbacks=[PyroModelGuideWarmup(train_dl)],
        )
>       trainer.fit(plan, train_dl)

tests/model/test_pyro.py:203: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scvi/train/_trainer.py:219: in fit
    super().fit(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544: in fit
    call._call_and_handle_interrupt(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580: in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:950: in _run
    call._call_setup_hook(self)  # allow user to setup lightning_module in accelerator environment
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:93: in _call_setup_hook
    _call_callback_hooks(trainer, "setup", stage=fn)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:208: in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
scvi/model/base/_pyromixin.py:72: in setup
    pyro_guide(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:510: in forward
    self._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:460: in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157: in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:216: in get_trace
    self(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198: in __call__
    raise exc from e
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:191: in __call__
    ret = self.fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
tests/model/test_pyro.py:98: in forward
    mean = self.linear(x).squeeze(-1)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')

    def forward(self, input: Tensor) -> Tensor:
>       return F.linear(input, self.weight, self.bias)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
E            Trace Shapes:        
E             Param Sites:        
E            Sample Sites:        
E               sigma dist |      
E                    value |      
E       linear.weight dist | 1 100
E                    value | 1 100
E         linear.bias dist | 1    
E                    value | 1

.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError
------------------------------------------------------------------ Captured stderr setup ------------------------------------------------------------------
Seed set to 0
------------------------------------------------------------------ Captured stderr call -------------------------------------------------------------------
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
-------------------------------------------------------------------- Captured log call --------------------------------------------------------------------
WARNING  jax._src.xla_bridge:xla_bridge.py:742 An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
==================================================================== warnings summary =====================================================================
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14
  /nfs/team283/vk7/software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40: Deprecated call to `pkg_resources.declare_namespace('lightning.fabric')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('lightning')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(parent)

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37: Deprecated call to `pkg_resources.declare_namespace('lightning.pytorch')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= short test summary info =================================================================
FAILED tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1...
============================================================== 1 failed, 5 warnings in 2.00s ==============================================================

Versions:

scvi 1.1.2
lightning 2.1.4
torch 2.2.1+cu118
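Until the callback is fixed, the workaround suggested in the title could look roughly like this. Leave PyroModelGuideWarmup out of the Trainer callbacks and push one minibatch through the guide yourself before trainer.fit(), keeping the batch on the same device as the module. The function name warmup_guide and its get_fn_args parameter are hypothetical. The code assumes the dataloader yields dicts of tensors.

```python
def warmup_guide(guide, dataloader, get_fn_args, device="cpu"):
    """Run one minibatch through ``guide`` so its parameters get created.

    Hypothetical stand-in for PyroModelGuideWarmup, meant to be called
    manually before trainer.fit(). ``get_fn_args`` turns a batch dict into
    the (args, kwargs) the model and guide expect.
    """
    for tensors in dataloader:
        # Keep the batch on the same device as the (not yet moved) module.
        batch = {key: value.to(device) for key, value in tensors.items()}
        args, kwargs = get_fn_args(batch)
        guide(*args, **kwargs)
        return args, kwargs  # a single minibatch is enough for the warmup
    raise ValueError("dataloader yielded no batches")
```

In the failing test, warmup_guide(plan.module.guide, train_dl, plan.module._get_fn_args_from_batch) would be called after plan is built and before trainer.fit(plan, train_dl), with the callbacks argument dropped from Trainer. plan.module.guide and _get_fn_args_from_batch are assumptions about scvi-tools internals. The traceback only shows that the callback calls pyro_guide(*args, **kwargs) in scvi/model/base/_pyromixin.py. Lightning would then presumably move the initialized guide parameters to the GPU together with the rest of the module, but this is untested.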
