PyroModelGuideWarmup fails on GPU, probably because Callback.setup() is now called in the accelerator environment in the latest PyTorch Lightning: during the setup hook the warmup batch ends up on cuda:0 while the Pyro module's parameters are still on the CPU (see the RuntimeError in the traceback below).
This test fails on GPU (full output below; a hedged sketch of one possible workaround follows the version info at the end):
pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
(cell2state_cuda118_torch22) vk7@farm22-gpu0203:.../software/tests/scvi-tools$ pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.13, pytest-8.1.1, pluggy-1.4.0
rootdir: .../software/tests/scvi-tools
configfile: pyproject.toml
plugins: cov-4.1.0, anyio-4.3.0
collected 1 item
tests/model/test_pyro.py F [100%]
======================================================================== FAILURES =========================================================================
_________________________________________________________ test_pyro_bayesian_regression_low_level _________________________________________________________
self = BayesianRegressionPyroModel(
(linear): PyroLinear(in_features=100, out_features=1, bias=True)
)
x = tensor([[ 6., 25., 3., ..., 10., 22., 13.],
[14., 3., 14., ..., 0., 6., 14.],
[19., 0., 0., ....0., 8.],
[ 0., 9., 2., ..., 14., 6., 0.],
[ 0., 0., 0., ..., 13., 9., 12.]], device='cuda:0')
y = tensor([[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
... [0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]], device='cuda:0')
ind_x = tensor([ 29, 272, 379, 251, 149, 339, 147, 137, 197, 275, 139, 323, 365, 322,
362, 59, 99, 281, 397, 31, 7... 301,
92, 378, 221, 280, 349, 46, 83, 222, 48, 180, 279, 395, 53, 87,
386, 7], device='cuda:0')
def forward(self, x, y, ind_x):
obs_plate = self.create_plates(x, y, ind_x)
sigma = pyro.sample("sigma", dist.Exponential(self.one))
> mean = self.linear(x).squeeze(-1)
tests/model/test_pyro.py:98:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25., 3., ..., 10., 22., 13.],
[14., 3., 14., ..., 0., 6., 14.],
[19., 0., 0., ....0., 8.],
[ 0., 9., 2., ..., 14., 6., 0.],
[ 0., 0., 0., ..., 13., 9., 12.]], device='cuda:0')
def forward(self, input: Tensor) -> Tensor:
> return F.linear(input, self.weight, self.bias)
E RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError
The above exception was the direct cause of the following exception:
accelerator = 'gpu', devices = 'auto'
def test_pyro_bayesian_regression_low_level(
accelerator: str,
devices: list | str | int,
):
adata = synthetic_iid()
adata_manager = _create_indices_adata_manager(adata)
train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
pyro.clear_param_store()
model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
plan = LowLevelPyroTrainingPlan(model)
plan.n_obs_training = len(train_dl.indices)
trainer = Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=2,
callbacks=[PyroModelGuideWarmup(train_dl)],
)
> trainer.fit(plan, train_dl)
tests/model/test_pyro.py:203:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scvi/train/_trainer.py:219: in fit
super().fit(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544: in fit
call._call_and_handle_interrupt(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44: in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580: in _fit_impl
self._run(model, ckpt_path=ckpt_path)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:950: in _run
call._call_setup_hook(self) # allow user to setup lightning_module in accelerator environment
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:93: in _call_setup_hook
_call_callback_hooks(trainer, "setup", stage=fn)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:208: in _call_callback_hooks
fn(trainer, trainer.lightning_module, *args, **kwargs)
scvi/model/base/_pyromixin.py:72: in setup
pyro_guide(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
return forward_call(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:510: in forward
self._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:460: in _setup_prototype
super()._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157: in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:216: in get_trace
self(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198: in __call__
raise exc from e
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:191: in __call__
ret = self.fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
return forward_call(*args, **kwargs)
tests/model/test_pyro.py:98: in forward
mean = self.linear(x).squeeze(-1)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25., 3., ..., 10., 22., 13.],
[14., 3., 14., ..., 0., 6., 14.],
[19., 0., 0., ....0., 8.],
[ 0., 9., 2., ..., 14., 6., 0.],
[ 0., 0., 0., ..., 13., 9., 12.]], device='cuda:0')
def forward(self, input: Tensor) -> Tensor:
> return F.linear(input, self.weight, self.bias)
E RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
E Trace Shapes:
E Param Sites:
E Sample Sites:
E sigma dist |
E value |
E linear.weight dist | 1 100
E value | 1 100
E linear.bias dist | 1
E value | 1
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError
------------------------------------------------------------------ Captured stderr setup ------------------------------------------------------------------
Seed set to 0
------------------------------------------------------------------ Captured stderr call -------------------------------------------------------------------
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
-------------------------------------------------------------------- Captured log call --------------------------------------------------------------------
WARNING jax._src.xla_bridge:xla_bridge.py:742 An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
==================================================================== warnings summary =====================================================================
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14
/nfs/team283/vk7/software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40: Deprecated call to `pkg_resources.declare_namespace('lightning.fabric')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('lightning')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(parent)
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37: Deprecated call to `pkg_resources.declare_namespace('lightning.pytorch')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= short test summary info =================================================================
FAILED tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1...
============================================================== 1 failed, 5 warnings in 2.00s ==============================================================
Versions:
scvi 1.1.2
lightning 2.1.4
torch 2.2.1+cu118
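
For reference, here is a minimal sketch of one possible workaround (not a proposed fix for scvi-tools): subclass the warmup callback and resolve the device from the Pyro module's parameters rather than from pl_module.device, so the warmup batch stays on the same device as the still-unmoved weights. The import path is taken from the traceback above; the attribute names self.dataloader, pl_module.module and _get_fn_args_from_batch are assumptions about scvi-tools internals and may need adjusting.

import torch
from scvi.model.base._pyromixin import PyroModelGuideWarmup  # module path taken from the traceback above


class DeviceSafePyroModelGuideWarmup(PyroModelGuideWarmup):
    """Run the guide warmup on the device where the Pyro parameters actually live."""

    def setup(self, trainer, pl_module, stage=None):
        if stage != "fit":
            return
        pyro_guide = pl_module.module.guide  # assumed training-plan layout
        for tensors in self.dataloader:  # assumed to be stored by the parent __init__
            # Resolve the device from the module's parameters instead of pl_module.device,
            # which in recent Lightning already reports the accelerator before the
            # parameters have been moved there.
            param = next(pl_module.module.parameters(), None)
            device = param.device if param is not None else torch.device("cpu")
            tens = {k: t.to(device) for k, t in tensors.items()}
            args, kwargs = pl_module.module._get_fn_args_from_batch(tens)  # assumed helper
            pyro_guide(*args, **kwargs)
            break

With this, the test's trainer would be built with callbacks=[DeviceSafePyroModelGuideWarmup(train_dl)]; the underlying behaviour of PyroModelGuideWarmup.setup() on GPU would still need a proper fix upstream.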