
Commit 0d5a71a

Add examples to documentation (#1548)
* Example BoxUniform
* add imports BoxUniform example
* update documentation and add example for process_prior, process_simulator, BoxUniform, and MultipleIndependent

1 parent c2bb760 · commit 0d5a71a

File tree: 4 files changed, +142 −59 lines changed

* docs/conf.py
* sbi/utils/torchutils.py
* sbi/utils/user_input_checks.py
* sbi/utils/user_input_checks_utils.py

docs/conf.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@

 intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
+    "torch": ("https://pytorch.org/docs/stable/", None),
 }

 source_suffix = {'.rst': 'restructuredtext', '.myst': 'myst-nb', '.ipynb': 'myst-nb'}
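For context, the new "torch" entry is what lets the `:class:`~torch.distributions...`` roles added in the docstrings below resolve to the hosted PyTorch documentation. A minimal sketch of a docstring that depends on this mapping (the function name is hypothetical, not part of the commit):

    def sample_prior():
        """Draw from the prior.

        See :class:`~torch.distributions.uniform.Uniform` for details. With
        the "torch" intersphinx entry above, Sphinx resolves this role to a
        link into https://pytorch.org/docs/stable/ instead of reporting a
        broken cross-reference.
        """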

sbi/utils/torchutils.py

Lines changed: 47 additions & 13 deletions
@@ -285,23 +285,48 @@ def __init__(
     ):
         """Multidimensional uniform distribution defined on a box.

-        A `Uniform` distribution initialized with e.g. a parameter vector low or high of
-        length 3 will result in a /batch/ dimension of length 3. A log_prob evaluation
-        will then output three numbers, one for each of the independent Uniforms in
-        the batch. Instead, a `BoxUniform` initialized in the same way has three
-        /event/ dimensions, and returns a scalar log_prob corresponding to whether
-        the evaluated point is in the box defined by low and high or outside.
-
-        Refer to torch.distributions.Uniform and torch.distributions.Independent for
-        further documentation.
+        A :class:`~torch.distributions.uniform.Uniform` distribution initialized \
+        with e.g. a parameter vector low or high of length 3 will result \
+        in a *batch* dimension of length 3. A log_prob evaluation will then \
+        output three numbers, one for each of the independent Uniforms in the \
+        batch. Instead, a :class:`BoxUniform` initialized in the same way has three \
+        *event* dimensions, and returns a scalar log_prob corresponding to whether \
+        the evaluated point is in the box defined by low and high or outside.
+
+        Refer to :class:`~torch.distributions.uniform.Uniform` \
+        and :class:`~torch.distributions.independent.Independent` for \
+        further documentation.

         Args:
             low: lower range (inclusive).
             high: upper range (exclusive).
-            reinterpreted_batch_ndims (int): the number of batch dims to
-                reinterpret as event dims.
-            device: device of the prior, inferred from low arg, defaults to "cpu",
-                should match the training device when used in SBI.
+            reinterpreted_batch_ndims (int): the number of batch dims to \
+                reinterpret as event dims.
+            device (Optional): device of the prior, inferred from low arg, \
+                defaults to "cpu", should match the training device when used in SBI.
+
+        Example:
+        --------
+
+        ::
+
+            import torch
+            from sbi.utils.torchutils import BoxUniform
+
+            # Define lower bounds
+            low = torch.tensor([0.0, 0.0, 0.0])
+
+            # Define upper bounds
+            high = torch.tensor([1.0, 1.0, 1.0])
+
+            box_uniform = BoxUniform(low, high)
+
+            # Sample from the box_uniform
+            N_samples = 100
+            sample = box_uniform.sample((N_samples,))
+
+            # Evaluate the log probability of the sample
+            log_prob = box_uniform.log_prob(sample)
         """

         # Type checks.
@@ -342,6 +367,15 @@ def to(self, device: Union[str, torch.device]) -> None:

         Args:
             device: Target device (e.g., "cpu", "cuda", "mps").
+
+        Example:
+        --------
+
+        ::
+
+            device = "cuda"
+            prior = BoxUniform(low=torch.zeros(2), high=torch.ones(2))
+            prior.to(device)  # in-place
         """
         # Update the device attribute
         self.device = device
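The batch-versus-event distinction that the new docstring describes can be checked directly with stock torch.distributions. A minimal sketch, using `Independent` as a stand-in for what `BoxUniform` does internally (an assumption consistent with the docstring's reference to it, not a claim about sbi's implementation):

    import torch
    from torch.distributions import Independent, Uniform

    low, high = torch.zeros(3), torch.ones(3)

    plain = Uniform(low, high)      # batch_shape=[3], event_shape=[]
    boxed = Independent(plain, 1)   # batch_shape=[],  event_shape=[3]

    x = torch.rand(3)
    print(plain.log_prob(x).shape)  # torch.Size([3]): one value per dimension
    print(boxed.log_prob(x).shape)  # torch.Size([]):  one joint value for the box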

sbi/utils/user_input_checks.py

Lines changed: 61 additions & 25 deletions
@@ -39,34 +39,51 @@ def process_prior(
     prior: Union[Sequence[Distribution], Distribution, rv_frozen, multi_rv_frozen],
     custom_prior_wrapper_kwargs: Optional[Dict] = None,
 ) -> Tuple[Distribution, int, bool]:
-    """Return PyTorch distribution-like prior from user-provided prior.
+    """
+    Return PyTorch distribution-like prior from user-provided prior.

-    NOTE: If the prior argument is a sequence of PyTorch distributions, they will be
-    interpreted as independent prior dimensions wrapped in a `MultipleIndependent`
-    pytorch Distribution. In case the elements are not PyTorch distributions, make sure
-    to use process_prior on each element in the list beforehand.
+    NOTE: If the prior argument is a sequence of PyTorch distributions, \
+    they will be interpreted as independent prior dimensions wrapped in a \
+    :class:`MultipleIndependent` PyTorch Distribution. In case the elements \
+    are not PyTorch distributions, make sure to use :func:`process_prior` \
+    on each element in the list beforehand.

-    NOTE: returns a tuple (processed_prior, num_params, whether_prior_returns_numpy).
-    The last two entries in the tuple can be passed on to `process_simulator` to prepare
-    the simulator as well. For example, it will take care of casting parameters to numpy
-    or adding a batch dimension to the simulator output, if needed.
+    NOTE: Returns a tuple `(processed_prior, num_params, \
+    whether_prior_returns_numpy)`. The last two entries in the tuple \
+    can be passed on to `process_simulator` to prepare the simulator. For \
+    example, it ensures parameters are cast to numpy or adds a batch \
+    dimension to the simulator output, if needed.

     Args:
-        prior: Prior object with `.sample()` and `.log_prob()` as provided by the user,
-            or a sequence of such objects.
-        custom_prior_wrapper_kwargs: kwargs to be passed to the class that wraps a
-            custom prior into a pytorch Distribution, e.g., for passing bounds for a
-            prior with bounded support (lower_bound, upper_bound), or argument
-            constraints.
-            (arg_constraints), see pytorch.distributions.Distribution for more info.
+        prior (:class:`~torch.distributions.distribution.Distribution` \
+        or Sequence[:class:`~torch.distributions.distribution.Distribution`]):
+            Prior object with `.sample()` and `.log_prob()`, or a sequence \
+            of such objects.
+        custom_prior_wrapper_kwargs (dict, optional):
+            Additional arguments passed to the wrapper class that processes the prior
+            into a PyTorch Distribution, such as bounds (`lower_bound`, `upper_bound`)
+            or argument constraints (`arg_constraints`).

     Raises:
-        AttributeError: If prior objects lacks `.sample()` or `.log_prob()`.
+        AttributeError: If prior objects lack `.sample()` or `.log_prob()`.

     Returns:
-        prior: Prior that emits samples and evaluates log prob as PyTorch Tensors.
-        theta_numel: Number of parameters - elements in a single sample from the prior.
-        prior_returns_numpy: Whether the return type of the prior was a Numpy array.
+        Tuple[torch.distributions.Distribution, int, bool]:
+            - `prior`: A PyTorch-compatible prior.
+            - `theta_numel`: Dimensionality of a single sample from the prior.
+            - `prior_returns_numpy`: Whether the prior originally returned NumPy arrays.
+
+    Example:
+    --------
+
+    ::
+
+        import torch
+        from torch.distributions import Uniform
+        from sbi.utils.user_input_checks import process_prior
+
+        prior = Uniform(torch.zeros(1), torch.ones(1))
+        prior, theta_numel, prior_returns_numpy = process_prior(prior)
     """

     # If prior is a sequence, assume independent components and check as PyTorch prior.
@@ -456,14 +473,33 @@ def process_simulator(
     """Returns a simulator that meets the requirements for usage in sbi.

     Args:
-        user_simulator: simulator provided by the user, possibly written in numpy.
-        prior: prior as pytorch distribution or processed with `process_prior`.
-        is_numpy_simulator: whether the simulator needs theta in numpy types, returned
+        user_simulator (Callable):
+            simulator provided by the user, possibly written in numpy.
+        prior (torch.distributions.Distribution):
+            prior as pytorch distribution or processed with :func:`process_prior`.
+        is_numpy_simulator (bool):
+            whether the simulator needs theta in numpy types, returned
             from `process_prior`.

     Returns:
-        simulator: processed simulator that returns `torch.Tensor` can handle batches
-            of parameters.
+        Callable:
+            simulator: processed simulator that returns :class:`torch.Tensor` \
+            and can handle batches of parameters.
+
+    Example:
+    --------
+
+    ::
+
+        import torch
+        from sbi.utils.user_input_checks import process_simulator
+        from torch.distributions import Uniform
+        from sbi.utils.user_input_checks import process_prior
+
+        prior = Uniform(torch.zeros(1), torch.ones(1))
+        prior, theta_numel, prior_returns_numpy = process_prior(prior)
+        simulator = lambda theta: theta + 1
+        simulator = process_simulator(simulator, prior, prior_returns_numpy)
     """

     assert isinstance(user_simulator, Callable), "Simulator must be a function."
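The NOTE about casting parameters to NumPy can be made concrete with a SciPy prior, whose samples come back as NumPy arrays. A sketch under the behavior the docstrings above describe (the wrapping details are inferred from those docstrings, not from the implementation):

    import numpy as np
    import torch
    from scipy import stats
    from sbi.utils.user_input_checks import process_prior, process_simulator

    # A SciPy prior: its samples and log-probs are NumPy arrays.
    prior = stats.uniform(loc=0.0, scale=1.0)
    prior, theta_numel, prior_returns_numpy = process_prior(prior)

    # A simulator written against NumPy.
    def numpy_simulator(theta: np.ndarray) -> np.ndarray:
        return theta + 0.1 * np.random.randn(*theta.shape)

    simulator = process_simulator(numpy_simulator, prior, prior_returns_numpy)

    # The processed simulator consumes and returns torch.Tensors in batches.
    theta = prior.sample((10,))
    x = simulator(theta)  # torch.Tensor with a leading batch dimension of 10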

sbi/utils/user_input_checks_utils.py

Lines changed: 33 additions & 21 deletions
@@ -209,31 +209,43 @@ def to(self, device: Union[str, torch.device]) -> None:


 class MultipleIndependent(Distribution):
-    """Wrap a sequence of PyTorch distributions into a joint PyTorch distribution.
-
-    Every element of the sequence is treated as independent from the other elements.
-    Single elements can be multivariate with dependent dimensions.
-
-    Example:
-
-    ::
-
-        import torch
-        from torch.distributions import Gamma, Beta, MultivariateNormal
-        prior = MultipleIndependent([
-            Gamma(torch.zeros(1), torch.ones(1)),
-            Beta(torch.zeros(1), torch.ones(1)),
-            MultivariateNormal(torch.ones(2), torch.tensor([[1, .1], [.1, 1.]]))
-        ])
-    """
+    """Wrap a sequence of PyTorch distributions into a joint PyTorch distribution."""

     def __init__(
         self,
         dists: Sequence[Distribution],
-        validate_args=None,
+        validate_args: Optional[bool] = None,
         arg_constraints: Optional[Dict[str, constraints.Constraint]] = None,
         device: Optional[str] = None,
     ):
+        """Joint distribution of multiple independent :class:`torch.distributions`.
+
+        Every element of the sequence is treated as independent from the \
+        other elements. Single elements can be multivariate with dependent dimensions.
+
+        Args:
+            dists: Sequence of PyTorch distributions.
+            validate_args (Optional): If True, the distribution checks its parameters.
+            arg_constraints (Optional): Dictionary of constraints for the parameters \
+                of the distribution.
+            device (Optional): Device to move the distribution to. If None, \
+                the distribution is moved to the CPU.
+
+        Example:
+        --------
+
+        ::
+
+            import torch
+            from torch.distributions import Gamma, Beta, MultivariateNormal
+            from sbi.utils.user_input_checks_utils import MultipleIndependent
+
+            prior = MultipleIndependent([
+                Gamma(torch.zeros(1), torch.ones(1)),
+                Beta(torch.zeros(1), torch.ones(1)),
+                MultivariateNormal(torch.ones(2), torch.tensor([[1, .1], [.1, 1.]]))
+            ])
+        """
         self._check_distributions(dists)
         if validate_args is not None:
             [d.set_default_validate_args(validate_args) for d in dists]
@@ -280,9 +292,9 @@ def _check_distribution(self, dist: Distribution):
         )
         assert isinstance(
             dist, Distribution
-        ), """priors passed to MultipleIndependent must be PyTorch distributions. Make
-            sure to process custom priors individually using process_prior before
-            passing them in a list to process_prior."""
+        ), """priors passed to MultipleIndependent must be PyTorch distributions. Make \
+            sure to process custom priors individually using :func:`process_prior` \
+            before passing them in a list to :func:`process_prior`."""
         # Make sure batch shape is smaller or equal to 1.
         assert dist.batch_shape in (
             torch.Size([1]),
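One caveat with the docstring example in this hunk: `Gamma` and `Beta` require strictly positive concentration parameters, so `torch.zeros(1)` as the first argument fails PyTorch's argument validation. A runnable variant with valid parameters (a sketch, not part of the commit; the shape comments assume MultipleIndependent concatenates its components' event dimensions):

    import torch
    from torch.distributions import Beta, Gamma, MultivariateNormal
    from sbi.utils.user_input_checks_utils import MultipleIndependent

    prior = MultipleIndependent([
        Gamma(torch.ones(1), torch.ones(1)),  # concentration and rate must be > 0
        Beta(torch.ones(1), torch.ones(1)),   # both concentrations must be > 0
        MultivariateNormal(torch.ones(2), torch.tensor([[1.0, 0.1], [0.1, 1.0]])),
    ])

    theta = prior.sample((5,))        # expected shape (5, 4): 1 + 1 + 2 dimensions
    log_prob = prior.log_prob(theta)  # expected shape (5,): one joint value per sample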
