
Commit 0ea9456

jorobledo and vepe99 authored
Prior to(device) (#1505)
* add to(device) method for BoxUniform and test
* add to(device) to PytorchReturnTypeWrapper
* add process device to tests
* add to(device) for MultipleIndependent
* fix names, docstring, and not-implemented to(device) on CustomPriorWrapper
* add type hints and process_prior for BoxUniform
* add torch.device to process_device
* add type hints for functions in user_input_checks_utils.py
* add docstring for prior '.to()' methods
* type hint added in BoxUniform
* add test for MultivariateNormal, Binomial, and Normal; also fix format
* add assertion messages in prior tests
* add torch.device in check_device
* fix type of dist in get_distribution_parameters
* add docstrings for prior_device_tests
* change return type of process_device to also return torch.device
* fix torch.device type when processing
* add test for torch.device as input of process_device

Co-authored-by: vepe99 <viterbogiuseppe99@gmail.com>
1 parent 64c72b7 commit 0ea9456
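The user-facing change: priors such as `BoxUniform` gain an in-place `.to(device)` method, and `process_device` now also accepts `torch.device` objects. A minimal usage sketch (not part of the diff below; it assumes a CUDA-capable machine):

import torch
from sbi.utils.torchutils import BoxUniform

# The prior starts on CPU and is moved in place, mirroring the new BoxUniform.to().
prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
prior.to("cuda")
theta = prior.sample((100,))        # samples now live on the GPU
log_probs = prior.log_prob(theta)   # so do the log-probabilities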

File tree

4 files changed: +214 -6 lines changed


sbi/utils/torchutils.py

Lines changed: 35 additions & 4 deletions
@@ -15,7 +15,7 @@
 from sbi.utils.typechecks import is_nonnegative_int, is_positive_int
 
 
-def process_device(device: str) -> str:
+def process_device(device: Union[str, torch.device]) -> str:
     """Set and return the default device to cpu or gpu (cuda, mps).
 
     Args:
@@ -52,6 +52,8 @@ def process_device(device: str) -> str:
     # Else, check whether the custom device is valid.
     else:
         check_device(device)
+    if isinstance(device, torch.device):
+        device = device.type
 
     return device
 
@@ -61,7 +63,7 @@ def gpu_available() -> bool:
     return torch.cuda.is_available() or torch.backends.mps.is_available()
 
 
-def check_device(device: str) -> None:
+def check_device(device: Union[str, torch.device]) -> None:
     """Check whether the device is valid.
 
     Args:
@@ -274,8 +276,8 @@ def gaussian_kde_log_eval(samples, query):
 class BoxUniform(Independent):
     def __init__(
         self,
-        low: Tensor,
-        high: Tensor,
+        low: Union[Tensor, Array],
+        high: Union[Tensor, Array],
         reinterpreted_batch_ndims: int = 1,
         device: Optional[str] = None,
     ):
@@ -313,6 +315,11 @@ def __init__(
         # Device handling
         device = low.device.type if device is None else device
         device = process_device(device)
+        self.device = device
+        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
+
+        self.low = torch.as_tensor(low, dtype=torch.float32, device=device)
+        self.high = torch.as_tensor(high, dtype=torch.float32, device=device)
 
         super().__init__(
             Uniform(
@@ -327,6 +334,30 @@ def __init__(
             reinterpreted_batch_ndims,
         )
 
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Moves the distribution to the specified device **in place**.
+
+        Args:
+            device: Target device (e.g., "cpu", "cuda", "mps").
+        """
+        # Update the device attribute
+        self.device = device
+        device = process_device(device)
+
+        # Move tensors to the new device
+        self.low = self.low.to(device=device)
+        self.high = self.high.to(device=device)
+
+        super().__init__(
+            Uniform(
+                low=self.low,
+                high=self.high,
+                validate_args=False,
+            ),
+            self.reinterpreted_batch_ndims,
+        )
+
 
 def ensure_theta_batched(theta: Tensor) -> Tensor:
     r"""

sbi/utils/user_input_checks_utils.py

Lines changed: 69 additions & 1 deletion
@@ -2,13 +2,35 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 import warnings
-from typing import Dict, Optional, Sequence
+from typing import Dict, Optional, Sequence, Union
 
 import torch
 from torch import Tensor, float32
 from torch.distributions import Distribution, constraints
 
 
+def get_distribution_parameters(
+    dist: torch.distributions.Distribution, device: Union[str, torch.device]
+) -> Dict:
+    """Used to get the tensors of the parameters in torch distributions.
+
+    Returns the tensors relocated to device.
+    """
+    params = {param: getattr(dist, param).to(device) for param in dist.arg_constraints}
+    # torch.distributions.MultivariateNormal calculates the precision
+    # matrix from the covariance, and stores it in the arg_constraints.
+    # When reinstantiating, we must provide only one of them.
+    if isinstance(dist, torch.distributions.MultivariateNormal):
+        params["precision_matrix"] = None
+        params["scale_tril"] = None
+    # torch.distributions.Binomial calculates logits
+    # from probabilities, and stores them in the arg_constraints.
+    # When reinstantiating, we must provide only one of them.
+    elif isinstance(dist, torch.distributions.Binomial):
+        params["logits"] = None
+    return params
+
+
 class CustomPriorWrapper(Distribution):
     def __init__(
         self,
@@ -83,6 +105,18 @@ def _set_mean_and_variance(self):
                 stacklevel=2,
             )
 
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Move the distribution to the specified device. Not implemented for this class.
+
+        Raises:
+            NotImplementedError.
+        """
+        raise NotImplementedError(
+            "This class is not supported on the GPU. Use on cpu or use \
+            any of `PytorchReturnTypeWrapper`, `BoxUniform`, or `MultipleIndependent`."
+        )
+
     @property
     def mean(self):
         return torch.as_tensor(
@@ -118,6 +152,7 @@ def __init__(
         )
 
         self.prior = prior
+        self.device = None
         self.return_type = return_type
 
     def log_prob(self, value) -> Tensor:
@@ -150,6 +185,20 @@ def variance(self):
     def support(self):
         return self.prior.support
 
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Move the distribution to the specified device.
+
+        Moves the distribution parameters to the specified device
+        and updates the device attribute.
+
+        Args:
+            device: device to move the distribution to.
+        """
+        params = get_distribution_parameters(self.prior, device)
+        self.prior = type(self.prior)(**params)
+        self.device = device
+
 
 class MultipleIndependent(Distribution):
     """Wrap a sequence of PyTorch distributions into a joint PyTorch distribution.
@@ -181,6 +230,7 @@ def __init__(
         [d.set_default_validate_args(validate_args) for d in dists]
 
         self.dists = dists
+        self.device = None
         # numel() instead of event_shape because for all dists both is possible,
         # event_shape=[1] or batch_shape=[1]
         self.dims_per_dist = [d.sample().numel() for d in self.dists]
@@ -319,6 +369,24 @@ def support(self):
             reinterpreted_batch_ndims=1,
         )
 
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Move the distribution to the specified device.
+        If the distribution has a `to` method, it is used. Otherwise, the
+        parameters of the distribution are moved to the specified device.
+
+        Args:
+            device: device to move the distribution to.
+        """
+        for i in range(len(self.dists)):
+            # type: ignore below because the missing `to` relates to torch, not sbi
+            if hasattr(self.dists[i], "to"):
+                self.dists[i].to(device)  # type: ignore
+            else:
+                params = get_distribution_parameters(self.dists[i], device)
+                self.dists[i] = type(self.dists[i])(**params)  # type: ignore
+        self.device = device
+
 
 def build_support(
     lower_bound: Optional[Tensor] = None, upper_bound: Optional[Tensor] = None
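`MultipleIndependent.to` above delegates to a component's own `.to` when it exists (e.g. `BoxUniform`) and otherwise rebuilds the plain torch distribution from the parameters returned by `get_distribution_parameters`. A hedged sketch modeled on the new `test_MultipleIndependent` test further below (again assuming a CUDA device is available):

import torch
from torch.distributions import Binomial, Normal
from sbi.utils.torchutils import BoxUniform
from sbi.utils.user_input_checks_utils import MultipleIndependent

prior = MultipleIndependent([
    BoxUniform(torch.zeros(1), torch.ones(1)),          # moved via its own .to()
    Normal(torch.tensor([0.0]), torch.tensor([0.5])),   # re-instantiated on the target device
    Binomial(torch.tensor([10]), torch.tensor([0.5])),  # probs moved, derived logits dropped
])
prior.to("cuda")
theta = prior.sample((100,))
assert theta.device.type == "cuda"
assert prior.log_prob(theta).device.type == "cuda"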

tests/inference_on_device_test.py

Lines changed: 4 additions & 1 deletion
@@ -444,12 +444,15 @@ def allow_iid_x(self) -> bool:
         ("gpu", None),
         ("cpu", "cpu"),
         ("gpu", "gpu"),
+        (torch.device("cpu"), torch.device("cpu")),
         pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
         pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
     ],
 )
 def test_boxuniform_device_handling(arg_device, device):
-    """Test mismatch between device passed via low / high and device kwarg."""
+    """Test mismatch between device passed via low / high and device kwarg.
+
+    Also tests torch.device as argument of process_device."""
 
     arg_device = process_device(arg_device)
     device = process_device(device)

tests/prior_device_tests.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import pytest
+import torch
+from torch.distributions import Beta, Binomial, Gamma, MultivariateNormal, Normal
+
+from sbi.utils.torchutils import BoxUniform, process_device
+from sbi.utils.user_input_checks_utils import (
+    MultipleIndependent,
+    PytorchReturnTypeWrapper,
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_BoxUniform(device: str):
+    """Test moving BoxUniform prior between devices."""
+    device = process_device(device)
+    low = torch.tensor([0.0])
+    high = torch.tensor([1.0])
+    prior = BoxUniform(low, high)
+    sample = prior.sample((1,))
+    assert prior.device == "cpu", "Prior is not initially on cpu."
+    assert sample.device.type == "cpu", "Sample is not initially on cpu."
+    log_probs = prior.log_prob(sample)
+    assert log_probs.device.type == "cpu", "Log probs are not initially on cpu."
+
+    prior.to(device)
+    assert prior.device == device, f"Prior was not moved to {device}."
+    assert prior.low.device.type == device.strip(":0"), (
+        f"BoxUniform low tensor is not on {device}."
+    )
+    assert prior.high.device.type == device.strip(":0"), (
+        f"BoxUniform high tensor is not on {device}."
+    )
+
+    sample_device = prior.sample((100,))
+    assert sample_device.device.type == device.strip(":0"), (
+        f"Sample tensor is not on {device}."
+    )
+    log_probs = prior.log_prob(sample_device)
+    assert log_probs.device.type == device.strip(":0"), (
+        f"log_prob tensor is not on {device}."
+    )
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+@pytest.mark.parametrize(
+    "prior",
+    [
+        Normal(loc=0.0, scale=1.0),
+        Binomial(total_count=10, probs=torch.tensor([0.5])),
+        MultivariateNormal(torch.tensor([0.1, 0.0]), covariance_matrix=torch.eye(2)),
+    ],
+)
+def test_PytorchReturnTypeWrapper(device: str, prior: torch.distributions):
+    """Test moving PytorchReturnTypeWrapper objects between devices.
+
+    Asserts that samples, prior, and log_probs are on device.
+    """
+    device = process_device(device)
+    prior = PytorchReturnTypeWrapper(prior)
+
+    prior.to(device)
+    assert prior.device == device, f"Prior was not correctly moved to {device}."
+
+    sample_device = prior.sample((100,))
+    assert sample_device.device.type == device.strip(":0"), (
+        f"Sample was not correctly moved to {device}."
+    )
+    log_probs = prior.log_prob(sample_device)
+    assert log_probs.device.type == device.strip(":0"), (
+        f"log_prob was not correctly moved to {device}."
+    )
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_MultipleIndependent(device: str):
+    """Test moving MultipleIndependent objects between devices.
+
+    Asserts that samples, prior, and log_probs are on device.
+    Uses Gamma, Beta, Normal, and Binomial from
+    torch.distributions and BoxUniform from sbi.
+    """
+    device = process_device(device)
+    dists = [
+        Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
+        Beta(torch.tensor([2.0]), torch.tensor([2.0])),
+        BoxUniform(torch.zeros(1), torch.ones(1)),
+        Normal(torch.tensor([0.0]), torch.tensor([0.5])),
+        Binomial(torch.tensor([10]), torch.tensor([0.5])),
+    ]
+
+    prior = MultipleIndependent(dists)
+
+    prior.to(device)
+    assert prior.device == device, f"Prior was not correctly moved to {device}."
+
+    sample_device = prior.sample((100,))
+    assert sample_device.device.type == device.strip(":0"), (
+        f"Sample was not correctly moved to {device}."
+    )
+    log_probs = prior.log_prob(sample_device)
+    assert log_probs.device.type == device.strip(":0"), (
+        f"log_prob was not correctly moved to {device}."
+    )
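Assuming the `gpu` marker used above is registered in the project's pytest configuration, these tests would typically be selected on a GPU machine with something like `pytest tests/prior_device_tests.py -m gpu`.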
