Skip to content

Commit 5269ac1

Browse files
committed
Merge remote-tracking branch 'upstream/main' into feature/lc2st-mlp-gpu-support
2 parents 7716037 + 937efc2 commit 5269ac1

File tree

15 files changed

+524
-159
lines changed

15 files changed

+524
-159
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
## `sbi`: Simulation-Based Inference
1111

1212
[Getting Started](https://sbi.readthedocs.io/en/latest/tutorials/00_getting_started.html) |
13-
[Documentation](https://sbi.readthedocs.io/en/latest/) | [Discord Server](https://discord.gg/eEeVPSvWKy)
13+
[Documentation](https://sbi.readthedocs.io/en/latest/) | [Discord Server](https://discord.gg/VPkV7XPj7k)
1414

1515
`sbi` is a Python package for simulation-based inference, designed to meet the needs of
1616
both researchers and practitioners. Whether you need fine-grained control or an

docs/how_to_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ Sampling
5757
how_to_guide/09_sampler_interface.ipynb
5858
how_to_guide/10_refine_posterior_with_importance_sampling.ipynb
5959
how_to_guide/11_iid_sampling_with_nle_or_nre.ipynb
60+
how_to_guide/23_using_pyro_with_sbi.ipynb
6061

6162

6263
Diagnostics

docs/how_to_guide/23_using_pyro_with_sbi.ipynb

Lines changed: 289 additions & 0 deletions
Large diffs are not rendered by default.

sbi/inference/posteriors/base_posterior.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,47 @@ def sample(
106106
sample_shape: Shape = torch.Size(),
107107
x: Optional[Tensor] = None,
108108
show_progress_bars: bool = True,
109-
mcmc_method: Optional[str] = None,
110-
mcmc_parameters: Optional[Dict[str, Any]] = None,
109+
**kwargs: Any,
111110
) -> Tensor:
112-
"""See child classes for docstring."""
111+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
112+
113+
Args:
114+
sample_shape: Shape of samples to draw.
115+
x: Conditioning observation $x_o$. If not provided, uses the default
116+
`x` set via `.set_default_x()`.
117+
show_progress_bars: Whether to show a progress bar during sampling.
118+
**kwargs: Additional keyword arguments passed to the specific
119+
posterior's sampling method. See the docstring of the specific
120+
posterior class for available options.
121+
122+
Returns:
123+
Samples from the posterior with shape `(*sample_shape, *theta_shape)`.
124+
"""
113125
pass
114126

115127
@abstractmethod
116128
def sample_batched(
117129
self,
118130
sample_shape: Shape,
119131
x: Tensor,
120-
max_sampling_batch_size: int = 10_000,
121132
show_progress_bars: bool = True,
133+
**kwargs: Any,
122134
) -> Tensor:
123-
"""See child classes for docstring."""
135+
r"""Draw samples from the posteriors for a batch of different xs.
136+
137+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
138+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
139+
140+
Args:
141+
sample_shape: Shape of samples to draw for each observation.
142+
x: Batch of observations with shape `(batch_dim, *event_shape_x)`.
143+
show_progress_bars: Whether to show a progress bar during sampling.
144+
**kwargs: Additional keyword arguments passed to the specific
145+
posterior's sampling method.
146+
147+
Returns:
148+
Samples with shape `(*sample_shape, batch_dim, *theta_shape)`.
149+
"""
124150
pass
125151

126152
@property

sbi/inference/posteriors/direct_posterior.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,20 @@ def sample(
137137
sample_shape: Shape = torch.Size(),
138138
x: Optional[Tensor] = None,
139139
max_sampling_batch_size: int = 10_000,
140-
sample_with: Optional[str] = None,
141140
show_progress_bars: bool = True,
142141
reject_outside_prior: bool = True,
143142
max_sampling_time: Optional[float] = None,
143+
return_partial_on_timeout: bool = False,
144144
) -> Tensor:
145-
r"""Return samples from posterior distribution $p(\theta|x)$.
145+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
146146
147147
Args:
148148
sample_shape: Desired shape of samples that are drawn from posterior. If
149149
sample_shape is multidimensional we simply draw `sample_shape.numel()`
150150
samples and then reshape into the desired shape.
151-
sample_with: This argument only exists to keep backward-compatibility with
152-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
151+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
152+
set via `.set_default_x()`.
153+
max_sampling_batch_size: Maximum batch size for rejection sampling.
153154
show_progress_bars: Whether to show sampling progress monitor.
154155
reject_outside_prior: If True (default), rejection sampling is used to
155156
ensure samples lie within the prior support. If False, samples are drawn
@@ -159,6 +160,10 @@ def sample(
159160
If exceeded, sampling is aborted and a RuntimeError is raised. Only
160161
applies when `reject_outside_prior=True` (no effect otherwise since
161162
direct sampling is fast).
163+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
164+
return the samples collected so far instead of raising a RuntimeError.
165+
A warning will be issued. Only applies when `reject_outside_prior=True`
166+
(default).
162167
"""
163168
num_samples = torch.Size(sample_shape).numel()
164169
x = self._x_else_default_x(x)
@@ -180,13 +185,6 @@ def sample(
180185
else max_sampling_batch_size
181186
)
182187

183-
if sample_with is not None:
184-
raise ValueError(
185-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
186-
f"`sample_with` is no longer supported. You have to rerun "
187-
f"`.build_posterior(sample_with={sample_with}).`"
188-
)
189-
190188
if reject_outside_prior:
191189
# Normal rejection behavior.
192190
samples = rejection.accept_reject_sample(
@@ -198,6 +196,7 @@ def sample(
198196
proposal_sampling_kwargs={"condition": x},
199197
alternative_method="build_posterior(..., sample_with='mcmc')",
200198
max_sampling_time=max_sampling_time,
199+
return_partial_on_timeout=return_partial_on_timeout,
201200
)[0]
202201
else:
203202
# Bypass rejection sampling entirely.
@@ -217,10 +216,12 @@ def sample_batched(
217216
show_progress_bars: bool = True,
218217
reject_outside_prior: bool = True,
219218
max_sampling_time: Optional[float] = None,
219+
return_partial_on_timeout: bool = False,
220220
) -> Tensor:
221-
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
222-
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
223-
manner.
221+
r"""Draw samples from the posteriors for a batch of different xs.
222+
223+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
224+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
224225
225226
Args:
226227
sample_shape: Desired shape of samples that are drawn from the posterior
@@ -236,6 +237,9 @@ def sample_batched(
236237
max_sampling_time: Optional maximum allowed sampling time in seconds.
237238
If exceeded, sampling is aborted and a RuntimeError is raised. Only
238239
applies when `reject_outside_prior=True`.
240+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
241+
return the samples collected so far instead of raising a RuntimeError.
242+
A warning will be issued. Only applies when `reject_outside_prior=True`.
239243
240244
Returns:
241245
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -282,6 +286,7 @@ def sample_batched(
282286
proposal_sampling_kwargs={"condition": x},
283287
alternative_method="build_posterior(..., sample_with='mcmc')",
284288
max_sampling_time=max_sampling_time,
289+
return_partial_on_timeout=return_partial_on_timeout,
285290
)[0]
286291
else:
287292
# Bypass rejection sampling entirely.

sbi/inference/posteriors/importance_posterior.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,14 @@ def sample(
190190
method: Optional[str] = None,
191191
oversampling_factor: int = 32,
192192
max_sampling_batch_size: int = 10_000,
193-
sample_with: Optional[str] = None,
194193
show_progress_bars: bool = False,
195194
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
196-
"""Return samples from the approximate posterior distribution.
195+
"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
197196
198197
Args:
199198
sample_shape: Shape of samples that are drawn from posterior.
200-
x: Observed data.
199+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
200+
set via `.set_default_x()`.
201201
method: Either of [`sir`|`importance`]. This sets the behavior of the
202202
`.sample()` method. With `sir`, approximate posterior samples are
203203
generated with sampling importance resampling (SIR). With
@@ -212,13 +212,6 @@ def sample(
212212

213213
method = self.method if method is None else method
214214

215-
if sample_with is not None:
216-
raise ValueError(
217-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
218-
f"`sample_with` is no longer supported. You have to rerun "
219-
f"`.build_posterior(sample_with={sample_with}).`"
220-
)
221-
222215
self.potential_fn.set_x(self._x_else_default_x(x))
223216

224217
if method == "sir":
@@ -255,8 +248,6 @@ def _importance_sample(
255248
256249
Args:
257250
sample_shape: Desired shape of samples that are drawn from posterior.
258-
sample_with: This argument only exists to keep backward-compatibility with
259-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
260251
show_progress_bars: Whether to show sampling progress monitor.
261252
262253
Returns:
@@ -286,13 +277,10 @@ def _sir_sample(
286277
sample_shape: Desired shape of samples that are drawn from posterior. If
287278
sample_shape is multidimensional we simply draw `sample_shape.numel()`
288279
samples and then reshape into the desired shape.
289-
x: Observed data.
290-
sample_with: This argument only exists to keep backward-compatibility with
291-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
292-
oversampling_factor: Number of proposed samples form which only one is
280+
oversampling_factor: Number of proposed samples from which only one is
293281
selected based on its importance weight.
294-
max_sampling_batch_size: The batchsize of samples being drawn from
295-
the proposal at every iteration. Used only in `sir_sample()`.
282+
max_sampling_batch_size: The batch size of samples being drawn from
283+
the proposal at every iteration.
296284
show_progress_bars: Whether to show sampling progress monitor.
297285
298286
Returns:

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 25 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -257,29 +257,36 @@ def sample(
257257
num_chains: Optional[int] = None,
258258
init_strategy: Optional[str] = None,
259259
init_strategy_parameters: Optional[Dict[str, Any]] = None,
260-
init_strategy_num_candidates: Optional[int] = None,
261-
mcmc_parameters: Optional[Dict] = None,
262-
mcmc_method: Optional[str] = None,
263-
sample_with: Optional[str] = None,
264260
num_workers: Optional[int] = None,
265261
mp_context: Optional[str] = None,
266262
show_progress_bars: bool = True,
267263
) -> Tensor:
268-
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.
269-
270-
Check the `__init__()` method for a description of all arguments as well as
271-
their default values.
264+
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
272265
273266
Args:
274267
sample_shape: Desired shape of samples that are drawn from posterior. If
275268
sample_shape is multidimensional we simply draw `sample_shape.numel()`
276269
samples and then reshape into the desired shape.
277-
mcmc_parameters: Dictionary that is passed only to support the API of
278-
`sbi` v0.17.2 or older.
279-
mcmc_method: This argument only exists to keep backward-compatibility with
280-
`sbi` v0.17.2 or older. Please use `method` instead.
281-
sample_with: This argument only exists to keep backward-compatibility with
282-
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
270+
x: Conditioning observation $x_o$. If not provided, uses the default `x`
271+
set via `.set_default_x()`.
272+
method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`,
273+
`hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`.
274+
If not provided, uses the method specified at initialization.
275+
thin: Thinning factor for the chain. If not provided, uses the value
276+
specified at initialization.
277+
warmup_steps: Number of warmup steps to discard. If not provided, uses
278+
the value specified at initialization.
279+
num_chains: Number of MCMC chains to run. If not provided, uses the
280+
value specified at initialization.
281+
init_strategy: Initialization strategy for chains (`proposal`, `sir`,
282+
or `resample`). If not provided, uses the value specified at
283+
initialization.
284+
init_strategy_parameters: Parameters for the initialization strategy.
285+
If not provided, uses the value specified at initialization.
286+
num_workers: Number of CPU cores for parallelization. If not provided,
287+
uses the value specified at initialization.
288+
mp_context: Multiprocessing context (`fork` or `spawn`). If not provided,
289+
uses the value specified at initialization.
283290
show_progress_bars: Whether to show sampling progress monitor.
284291
285292
Returns:
@@ -301,46 +308,6 @@ def sample(
301308
if init_strategy_parameters is None
302309
else init_strategy_parameters
303310
)
304-
if init_strategy_num_candidates is not None:
305-
warn(
306-
f"Passing `init_strategy_num_candidates` is deprecated as of sbi \
307-
v0.19.0. Instead, use e.g., \
308-
`init_strategy_parameters={'num_candidate_samples': 1000}`",
309-
stacklevel=2,
310-
)
311-
self.init_strategy_parameters["num_candidate_samples"] = (
312-
init_strategy_num_candidates
313-
)
314-
if sample_with is not None:
315-
raise ValueError(
316-
f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
317-
"`sample_with` is no longer supported. You have to rerun "
318-
f"`.build_posterior(sample_with={sample_with}).`"
319-
)
320-
if mcmc_method is not None:
321-
warn(
322-
"You passed `mcmc_method` to `.sample()`. As of sbi v0.18.0, this "
323-
"is deprecated and will be removed in a future release. Use `method` "
324-
"instead of `mcmc_method`.",
325-
stacklevel=2,
326-
)
327-
method = mcmc_method
328-
if mcmc_parameters:
329-
warn(
330-
"You passed `mcmc_parameters` to `.sample()`. As of sbi v0.18.0, this "
331-
"is deprecated and will be removed in a future release. Instead, pass "
332-
"the variable to `.sample()` directly, e.g. "
333-
"`posterior.sample((1,), num_chains=5)`.",
334-
stacklevel=2,
335-
)
336-
# The following lines are only for backwards compatibility with sbi v0.17.2 or
337-
# older.
338-
m_p = mcmc_parameters or {} # define to shorten the variable name
339-
method = _maybe_use_dict_entry(method, "mcmc_method", m_p)
340-
thin = _maybe_use_dict_entry(thin, "thin", m_p)
341-
warmup_steps = _maybe_use_dict_entry(warmup_steps, "warmup_steps", m_p)
342-
num_chains = _maybe_use_dict_entry(num_chains, "num_chains", m_p)
343-
init_strategy = _maybe_use_dict_entry(init_strategy, "init_strategy", m_p)
344311
self.potential_ = self._prepare_potential(method) # type: ignore
345312

346313
initial_params = self._get_initial_params(
@@ -415,9 +382,10 @@ def sample_batched(
415382
mp_context: Optional[str] = None,
416383
show_progress_bars: bool = True,
417384
) -> Tensor:
418-
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
419-
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
420-
manner.
385+
r"""Draw samples from the posteriors for a batch of different xs.
386+
387+
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
388+
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
421389
422390
Check the `__init__()` method for a description of all arguments as well as
423391
their default values.
@@ -1134,24 +1102,6 @@ def _process_thin_default(thin: int) -> int:
11341102
return thin
11351103

11361104

1137-
def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
1138-
"""Returns `default` if `key` is not in the dict and otherwise the dict entry.
1139-
1140-
This method exists only to keep backwards compatibility with `sbi` v0.17.2 or
1141-
older. It allows passing `mcmc_parameters` to `.sample()`.
1142-
1143-
Args:
1144-
default: The default value if `key` is not in `dict_to_check`.
1145-
key: The key for which to check in `dict_to_check`.
1146-
dict_to_check: The dictionary to be checked.
1147-
1148-
Returns:
1149-
The potentially replaced value.
1150-
"""
1151-
attribute = dict_to_check.get(key, default)
1152-
return attribute
1153-
1154-
11551105
def _num_required_args(func):
11561106
"""
11571107
Utility for counting the number of positional args in a function.

0 commit comments

Comments
 (0)