@@ -39,34 +39,51 @@ def process_prior(
3939 prior : Union [Sequence [Distribution ], Distribution , rv_frozen , multi_rv_frozen ],
4040 custom_prior_wrapper_kwargs : Optional [Dict ] = None ,
4141) -> Tuple [Distribution , int , bool ]:
42- """Return PyTorch distribution-like prior from user-provided prior.
42+ """
43+ Return PyTorch distribution-like prior from user-provided prior.
4344
44- NOTE: If the prior argument is a sequence of PyTorch distributions, they will be
45- interpreted as independent prior dimensions wrapped in a `MultipleIndependent`
46- pytorch Distribution. In case the elements are not PyTorch distributions, make sure
47- to use process_prior on each element in the list beforehand.
45+ NOTE: If the prior argument is a sequence of PyTorch distributions, \
46+ they will be interpreted as independent prior dimensions wrapped in a \
47+ :class:`MultipleIndependent` PyTorch Distribution. In case the elements \
48+ are not PyTorch distributions, make sure to use :func:`process_prior` \
49+ on each element in the list beforehand.
4850
49- NOTE: returns a tuple (processed_prior, num_params, whether_prior_returns_numpy).
50- The last two entries in the tuple can be passed on to `process_simulator` to prepare
51- the simulator as well. For example, it will take care of casting parameters to numpy
52- or adding a batch dimension to the simulator output, if needed.
51+ NOTE: Returns a tuple `(processed_prior, num_params, \
52+ whether_prior_returns_numpy)`. The last two entries in the tuple\
53+ can be passed on to `process_simulator` to prepare the simulator. For \
54+ example, it ensures parameters are cast to numpy or adds a batch \
55+ dimension to the simulator output, if needed.
5356
5457 Args:
55- prior: Prior object with `.sample()` and `.log_prob()` as provided by the user,
56- or a sequence of such objects.
57- custom_prior_wrapper_kwargs: kwargs to be passed to the class that wraps a
58- custom prior into a pytorch Distribution, e.g., for passing bounds for a
59- prior with bounded support (lower_bound, upper_bound), or argument
60- constraints.
61- (arg_constraints), see pytorch.distributions.Distribution for more info.
58+ prior (:class:`~torch.distributions.distribution.Distribution` \
59+ or Sequence[:class:`~torch.distributions.distribution.Distribution`]):
60+ Prior object with `.sample()` and `.log_prob()`, or a sequence \
61+ of such objects.
62+ custom_prior_wrapper_kwargs (dict, optional):
63+ Additional arguments passed to the wrapper class that processes the prior
64+ into a PyTorch Distribution, such as bounds (`lower_bound`, `upper_bound`)
65+ or argument constraints (`arg_constraints`).
6266
6367 Raises:
64- AttributeError: If prior objects lacks `.sample()` or `.log_prob()`.
68+ AttributeError: If prior objects lack `.sample()` or `.log_prob()`.
6569
6670 Returns:
67- prior: Prior that emits samples and evaluates log prob as PyTorch Tensors.
68- theta_numel: Number of parameters - elements in a single sample from the prior.
69- prior_returns_numpy: Whether the return type of the prior was a Numpy array.
71+ Tuple[torch.distributions.Distribution, int, bool]:
72+ - `prior`: A PyTorch-compatible prior.
73+ - `theta_numel`: Dimensionality of a single sample from the prior.
74+ - `prior_returns_numpy`: Whether the prior originally returned NumPy arrays.
75+
76+ Example:
77+ --------
78+
79+ ::
80+
81+ import torch
82+ from torch.distributions import Uniform
83+ from sbi.utils.user_input_checks import process_prior
84+
85+ prior = Uniform(torch.zeros(1), torch.ones(1))
86+ prior, theta_numel, prior_returns_numpy = process_prior(prior)
7087 """
7188
7289 # If prior is a sequence, assume independent components and check as PyTorch prior.
@@ -456,14 +473,33 @@ def process_simulator(
456473 """Returns a simulator that meets the requirements for usage in sbi.
457474
458475 Args:
459- user_simulator: simulator provided by the user, possibly written in numpy.
460- prior: prior as pytorch distribution or processed with `process_prior`.
461- is_numpy_simulator: whether the simulator needs theta in numpy types, returned
476+ user_simulator (Callable):
477+ simulator provided by the user, possibly written in numpy.
478+ prior (torch.distributions.Distribution):
479+ prior as pytorch distribution or processed with :func:`process_prior`.
480+ is_numpy_simulator (bool):
481+ whether the simulator needs theta in numpy types, returned
462482 from `process_prior`.
463483
464484 Returns:
465- simulator: processed simulator that returns `torch.Tensor` can handle batches
466- of parameters.
485+ Callable:
486+ simulator: processed simulator that returns :class:`torch.Tensor` \
487+ and can handle batches of parameters.
488+
489+ Example:
490+ --------
491+
492+ ::
493+
494+ import torch
495+ from sbi.utils.user_input_checks import process_simulator
496+ from torch.distributions import Uniform
497+ from sbi.utils.user_input_checks import process_prior
498+
499+ prior = Uniform(torch.zeros(1), torch.ones(1))
500+ prior, theta_numel, prior_returns_numpy = process_prior(prior)
501+ simulator = lambda theta: theta + 1
502+ simulator = process_simulator(simulator, prior, prior_returns_numpy)
467503 """
468504
469505 assert isinstance (user_simulator , Callable ), "Simulator must be a function."
0 commit comments