Skip to content

Commit 299f3d0

Browse files
authored
refactor: fix docstrings and signatures of sample* methods (#1719)
*refactor: fix docstrings and signatures of sample* methods - remove deprecated args - make base class consistent - clarify formulations in docstring. * fix: remove unused sampling args from vi tests. * fix pyright issues * fix fromreview comments
1 parent 649f5d3 commit 299f3d0

File tree

7 files changed

+91
-136
lines changed

7 files changed

+91
-136
lines changed

sbi/inference/posteriors/base_posterior.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,47 @@ def sample(
106106
sample_shape: Shape = torch.Size(),
107107
x: Optional[Tensor] = None,
108108
show_progress_bars: bool = True,
109-
mcmc_method: Optional[str] = None,
110-
mcmc_parameters: Optional[Dict[str, Any]] = None,
109+
**kwargs: Any,
111110
) -> Tensor:
112-
"""See child classes for docstring."""
111+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
112+
113+
Args:
114+
sample_shape: Shape of samples to draw.
115+
x: Conditioning observation $x_o$. If not provided, uses the default
116+
`x` set via `.set_default_x()`.
117+
show_progress_bars: Whether to show a progress bar during sampling.
118+
**kwargs: Additional keyword arguments passed to the specific
119+
posterior's sampling method. See the docstring of the specific
120+
posterior class for available options.
121+
122+
Returns:
123+
Samples from the posterior with shape `(*sample_shape, *theta_shape)`.
124+
"""
113125
pass
114126

115127
@abstractmethod
116128
def sample_batched(
117129
self,
118130
sample_shape: Shape,
119131
x: Tensor,
120-
max_sampling_batch_size: int = 10_000,
121132
show_progress_bars: bool = True,
133+
**kwargs: Any,
122134
) -> Tensor:
123-
"""See child classes for docstring."""
135+
r"""Draw samples from the posteriors for a batch of different xs.
136+
137+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
138+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
139+
140+
Args:
141+
sample_shape: Shape of samples to draw for each observation.
142+
x: Batch of observations with shape `(batch_dim, *event_shape_x)`.
143+
show_progress_bars: Whether to show a progress bar during sampling.
144+
**kwargs: Additional keyword arguments passed to the specific
145+
posterior's sampling method.
146+
147+
Returns:
148+
Samples with shape `(*sample_shape, batch_dim, *theta_shape)`.
149+
"""
124150
pass
125151

126152
@property

sbi/inference/posteriors/direct_posterior.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,19 @@ def sample(
137137
sample_shape: Shape = torch.Size(),
138138
x: Optional[Tensor] = None,
139139
max_sampling_batch_size: int = 10_000,
140-
sample_with: Optional[str] = None,
141140
show_progress_bars: bool = True,
142141
reject_outside_prior: bool = True,
143142
max_sampling_time: Optional[float] = None,
144143
) -> Tensor:
145-
r"""Return samples from posterior distribution $p(\theta|x)$.
144+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
146145
147146
Args:
148147
sample_shape: Desired shape of samples that are drawn from posterior. If
149148
sample_shape is multidimensional we simply draw `sample_shape.numel()`
150149
samples and then reshape into the desired shape.
151-
sample_with: This argument only exists to keep backward-compatibility with
152-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
150+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
151+
set via `.set_default_x()`.
152+
max_sampling_batch_size: Maximum batch size for rejection sampling.
153153
show_progress_bars: Whether to show sampling progress monitor.
154154
reject_outside_prior: If True (default), rejection sampling is used to
155155
ensure samples lie within the prior support. If False, samples are drawn
@@ -180,13 +180,6 @@ def sample(
180180
else max_sampling_batch_size
181181
)
182182

183-
if sample_with is not None:
184-
raise ValueError(
185-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
186-
f"`sample_with` is no longer supported. You have to rerun "
187-
f"`.build_posterior(sample_with={sample_with}).`"
188-
)
189-
190183
if reject_outside_prior:
191184
# Normal rejection behavior.
192185
samples = rejection.accept_reject_sample(
@@ -218,9 +211,10 @@ def sample_batched(
218211
reject_outside_prior: bool = True,
219212
max_sampling_time: Optional[float] = None,
220213
) -> Tensor:
221-
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
222-
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
223-
manner.
214+
r"""Draw samples from the posteriors for a batch of different xs.
215+
216+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
217+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
224218
225219
Args:
226220
sample_shape: Desired shape of samples that are drawn from the posterior

sbi/inference/posteriors/importance_posterior.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,14 @@ def sample(
190190
method: Optional[str] = None,
191191
oversampling_factor: int = 32,
192192
max_sampling_batch_size: int = 10_000,
193-
sample_with: Optional[str] = None,
194193
show_progress_bars: bool = False,
195194
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
196-
"""Return samples from the approximate posterior distribution.
195+
"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
197196
198197
Args:
199198
sample_shape: Shape of samples that are drawn from posterior.
200-
x: Observed data.
199+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
200+
set via `.set_default_x()`.
201201
method: Either of [`sir`|`importance`]. This sets the behavior of the
202202
`.sample()` method. With `sir`, approximate posterior samples are
203203
generated with sampling importance resampling (SIR). With
@@ -212,13 +212,6 @@ def sample(
212212

213213
method = self.method if method is None else method
214214

215-
if sample_with is not None:
216-
raise ValueError(
217-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
218-
f"`sample_with` is no longer supported. You have to rerun "
219-
f"`.build_posterior(sample_with={sample_with}).`"
220-
)
221-
222215
self.potential_fn.set_x(self._x_else_default_x(x))
223216

224217
if method == "sir":
@@ -255,8 +248,6 @@ def _importance_sample(
255248
256249
Args:
257250
sample_shape: Desired shape of samples that are drawn from posterior.
258-
sample_with: This argument only exists to keep backward-compatibility with
259-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
260251
show_progress_bars: Whether to show sampling progress monitor.
261252
262253
Returns:
@@ -286,13 +277,10 @@ def _sir_sample(
286277
sample_shape: Desired shape of samples that are drawn from posterior. If
287278
sample_shape is multidimensional we simply draw `sample_shape.numel()`
288279
samples and then reshape into the desired shape.
289-
x: Observed data.
290-
sample_with: This argument only exists to keep backward-compatibility with
291-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
292-
oversampling_factor: Number of proposed samples form which only one is
280+
oversampling_factor: Number of proposed samples from which only one is
293281
selected based on its importance weight.
294-
max_sampling_batch_size: The batchsize of samples being drawn from
295-
the proposal at every iteration. Used only in `sir_sample()`.
282+
max_sampling_batch_size: The batch size of samples being drawn from
283+
the proposal at every iteration.
296284
show_progress_bars: Whether to show sampling progress monitor.
297285
298286
Returns:

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 25 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -257,29 +257,36 @@ def sample(
257257
num_chains: Optional[int] = None,
258258
init_strategy: Optional[str] = None,
259259
init_strategy_parameters: Optional[Dict[str, Any]] = None,
260-
init_strategy_num_candidates: Optional[int] = None,
261-
mcmc_parameters: Optional[Dict] = None,
262-
mcmc_method: Optional[str] = None,
263-
sample_with: Optional[str] = None,
264260
num_workers: Optional[int] = None,
265261
mp_context: Optional[str] = None,
266262
show_progress_bars: bool = True,
267263
) -> Tensor:
268-
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.
269-
270-
Check the `__init__()` method for a description of all arguments as well as
271-
their default values.
264+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
272265
273266
Args:
274267
sample_shape: Desired shape of samples that are drawn from posterior. If
275268
sample_shape is multidimensional we simply draw `sample_shape.numel()`
276269
samples and then reshape into the desired shape.
277-
mcmc_parameters: Dictionary that is passed only to support the API of
278-
`sbi` v0.17.2 or older.
279-
mcmc_method: This argument only exists to keep backward-compatibility with
280-
`sbi` v0.17.2 or older. Please use `method` instead.
281-
sample_with: This argument only exists to keep backward-compatibility with
282-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
270+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
271+
set via `.set_default_x()`.
272+
method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`,
273+
`hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`.
274+
If not provided, uses the method specified at initialization.
275+
thin: Thinning factor for the chain. If not provided, uses the value
276+
specified at initialization.
277+
warmup_steps: Number of warmup steps to discard. If not provided, uses
278+
the value specified at initialization.
279+
num_chains: Number of MCMC chains to run. If not provided, uses the
280+
value specified at initialization.
281+
init_strategy: Initialization strategy for chains (`proposal`, `sir`,
282+
or `resample`). If not provided, uses the value specified at
283+
initialization.
284+
init_strategy_parameters: Parameters for the initialization strategy.
285+
If not provided, uses the value specified at initialization.
286+
num_workers: Number of CPU cores for parallelization. If not provided,
287+
uses the value specified at initialization.
288+
mp_context: Multiprocessing context (`fork` or `spawn`). If not provided,
289+
uses the value specified at initialization.
283290
show_progress_bars: Whether to show sampling progress monitor.
284291
285292
Returns:
@@ -301,46 +308,6 @@ def sample(
301308
if init_strategy_parameters is None
302309
else init_strategy_parameters
303310
)
304-
if init_strategy_num_candidates is not None:
305-
warn(
306-
f"Passing `init_strategy_num_candidates` is deprecated as of sbi \
307-
v0.19.0. Instead, use e.g., \
308-
`init_strategy_parameters={'num_candidate_samples': 1000}`",
309-
stacklevel=2,
310-
)
311-
self.init_strategy_parameters["num_candidate_samples"] = (
312-
init_strategy_num_candidates
313-
)
314-
if sample_with is not None:
315-
raise ValueError(
316-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
317-
"`sample_with` is no longer supported. You have to rerun "
318-
f"`.build_posterior(sample_with={sample_with}).`"
319-
)
320-
if mcmc_method is not None:
321-
warn(
322-
"You passed `mcmc_method` to `.sample()`. As of sbi v0.18.0, this "
323-
"is deprecated and will be removed in a future release. Use `method` "
324-
"instead of `mcmc_method`.",
325-
stacklevel=2,
326-
)
327-
method = mcmc_method
328-
if mcmc_parameters:
329-
warn(
330-
"You passed `mcmc_parameters` to `.sample()`. As of sbi v0.18.0, this "
331-
"is deprecated and will be removed in a future release. Instead, pass "
332-
"the variable to `.sample()` directly, e.g. "
333-
"`posterior.sample((1,), num_chains=5)`.",
334-
stacklevel=2,
335-
)
336-
# The following lines are only for backwards compatibility with sbi v0.17.2 or
337-
# older.
338-
m_p = mcmc_parameters or {} # define to shorten the variable name
339-
method = _maybe_use_dict_entry(method, "mcmc_method", m_p)
340-
thin = _maybe_use_dict_entry(thin, "thin", m_p)
341-
warmup_steps = _maybe_use_dict_entry(warmup_steps, "warmup_steps", m_p)
342-
num_chains = _maybe_use_dict_entry(num_chains, "num_chains", m_p)
343-
init_strategy = _maybe_use_dict_entry(init_strategy, "init_strategy", m_p)
344311
self.potential_ = self._prepare_potential(method) # type: ignore
345312

346313
initial_params = self._get_initial_params(
@@ -415,9 +382,10 @@ def sample_batched(
415382
mp_context: Optional[str] = None,
416383
show_progress_bars: bool = True,
417384
) -> Tensor:
418-
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
419-
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
420-
manner.
385+
r"""Draw samples from the posteriors for a batch of different xs.
386+
387+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
388+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
421389
422390
Check the `__init__()` method for a description of all arguments as well as
423391
their default values.
@@ -1134,24 +1102,6 @@ def _process_thin_default(thin: int) -> int:
11341102
return thin
11351103

11361104

1137-
def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
1138-
"""Returns `default` if `key` is not in the dict and otherwise the dict entry.
1139-
1140-
This method exists only to keep backwards compatibility with `sbi` v0.17.2 or
1141-
older. It allows passing `mcmc_parameters` to `.sample()`.
1142-
1143-
Args:
1144-
default: The default value if `key` is not in `dict_to_check`.
1145-
key: The key for which to check in `dict_to_check`.
1146-
dict_to_check: The dictionary to be checked.
1147-
1148-
Returns:
1149-
The potentially replaced value.
1150-
"""
1151-
attribute = dict_to_check.get(key, default)
1152-
return attribute
1153-
1154-
11551105
def _num_required_args(func):
11561106
"""
11571107
Utility for counting the number of positional args in a function.

sbi/inference/posteriors/rejection_posterior.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,19 +135,26 @@ def sample(
135135
num_samples_to_find_max: Optional[int] = None,
136136
num_iter_to_find_max: Optional[int] = None,
137137
m: Optional[float] = None,
138-
sample_with: Optional[str] = None,
139138
show_progress_bars: bool = True,
140139
reject_outside_prior: bool = True,
141140
max_sampling_time: Optional[float] = None,
142141
):
143-
r"""Return samples from posterior $p(\theta|x)$ via rejection sampling.
142+
r"""Draw samples from the approximate posterior via rejection sampling.
144143
145144
Args:
146145
sample_shape: Desired shape of samples that are drawn from posterior. If
147146
sample_shape is multidimensional we simply draw `sample_shape.numel()`
148147
samples and then reshape into the desired shape.
149-
sample_with: This argument only exists to keep backward-compatibility with
150-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
148+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
149+
set via `.set_default_x()`.
150+
max_sampling_batch_size: Maximum batch size for rejection sampling.
151+
If not provided, uses the value specified at initialization.
152+
num_samples_to_find_max: Number of samples to find the maximum of the
153+
potential function. If not provided, uses the value from initialization.
154+
num_iter_to_find_max: Number of optimization iterations to find the
155+
maximum. If not provided, uses the value from initialization.
156+
m: Multiplier for the proposal distribution. If not provided, uses the
157+
value from initialization.
151158
show_progress_bars: Whether to show sampling progress monitor.
152159
reject_outside_prior: If True (default), rejection sampling is used to
153160
ensure samples lie within the prior support. If False, samples are drawn
@@ -166,12 +173,6 @@ def sample(
166173

167174
potential = partial(self.potential_fn, track_gradients=True)
168175

169-
if sample_with is not None:
170-
raise ValueError(
171-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
172-
f"`sample_with` is no longer supported. You have to rerun "
173-
f"`.build_posterior(sample_with={sample_with}).`"
174-
)
175176
# Replace arguments that were not passed with their default.
176177
max_sampling_batch_size = (
177178
self.max_sampling_batch_size

sbi/inference/posteriors/vi_posterior.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,16 @@ def sample(
332332
self,
333333
sample_shape: Shape = torch.Size(),
334334
x: Optional[Tensor] = None,
335-
**kwargs,
335+
show_progress_bars: bool = True,
336336
) -> Tensor:
337-
"""Samples from the variational posterior distribution.
337+
r"""Draw samples from the variational posterior distribution $p(\theta|x)$.
338338
339339
Args:
340-
sample_shape: Shape of samples
340+
sample_shape: Desired shape of samples that are drawn from the posterior.
341+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
342+
set via `.set_default_x()`.
343+
show_progress_bars: Unused for `VIPosterior` since sampling from the
344+
variational distribution is fast. Included for API consistency.
341345
342346
Returns:
343347
Samples from posterior.

0 commit comments

Comments
 (0)